Skip to content

Commit 827072c

Browse files
committed
Training model
1 parent ad48977 commit 827072c

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ ProtoStructs = "437b6fc4-8e8e-11e9-3fa1-ad391e66c018"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
1616
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
17+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

1819
[compat]
1920
ADTypes = "1.11.0"
@@ -26,6 +27,7 @@ ProtoStructs = "1.2.1"
2627
Random = "1.11.0"
2728
SymbolicUtils = "3.7.2"
2829
Symbolics = "6.22.0"
30+
Zygote = "0.6.73"
2931
julia = "1.10"
3032

3133
[extras]

scripts/sample.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,23 @@ model = HybridModel(
2424

2525
model(rand(Float32, 5), (ps, [1.2f0]), st)
2626
model(rand(Float32, 5,5), (ps, [1.2f0]), st)
27+
28+
# "Training"
29+
using Optimisers
30+
using NNlib
31+
using ADTypes
32+
using Zygote
33+
34+
X = rand(Float32, 5, 5)
35+
y = rand(Float32, 5)
36+
37+
loss(X, y, ps, globals, st) = sum((y .- model(X, (ps, globals), st)) .^ 2)
38+
39+
globals_init = [1.0f0]
40+
41+
grads = gradient((ps, globals) -> loss(X, y, ps, globals, st), ps, globals_init)
42+
43+
opt = Optimisers.setup(AdamW(), (ps, globals_init))
44+
45+
Optimisers.update!(opt, (ps, globals_init), grads)
46+

0 commit comments

Comments
 (0)