Skip to content

Commit be3c943

Browse files
authored
make the example run (#37)
1 parent 5b6b380 commit be3c943

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

docs/src/index.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,19 @@ to adjust the model:
4040

4141
using Flux, Metalhead, Optimisers
4242

43-
o = Optimisers.ADAM() # define an ADAM optimiser with default settings
44-
st = Optimisers.setup(o, m) # initialize the optimiser before using it
43+
model = Metalhead.ResNet18() # define a model to train on
44+
image = rand(Float32, 224, 224, 3, 1); # dummy data
45+
@show sum(model(image)); # dummy loss function
4546

46-
model = ResNet18() # define a model to train on
47-
ip = rand(Float32, 224, 224, 3, 1) # dummy data
47+
o = Optimisers.ADAM() # define an ADAM optimiser with default settings
48+
st = Optimisers.setup(o, model); # initialize the optimiser before using it
4849

49-
m̄, _ = gradient(model, ip) do m, x # calculate the gradients
50-
sum(m(x)) # dummy loss function
51-
end
50+
m̄, _ = gradient(model, image) do m, x # calculate the gradients
51+
sum(m(x))
52+
end;
5253

53-
st, mnew = Optimisers.update(st, m, m̄)
54+
st, model = Optimisers.update(st, model, m̄);
55+
@show sum(model(image));
5456

5557
```
5658

0 commit comments

Comments
 (0)