Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,5 @@ It is defined in Functors.jl and re-exported by Optimisers.jl here for convenien
Functors.KeyPath
Functors.haskeypath
Functors.getkeypath
Functors.setkeypath!
```
6 changes: 6 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ function update(tree, model, grad, higher...)
update!(t′, x′, grad, higher...)
end

function update!(::AbstractRule, model, grad, higher...)
error("""update! must be called with an optimiser state, not a rule.
Call `opt_state = setup(rule, model)` first, then `update!(opt_state, model, grad)`.
""")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
error("""update! must be called with an optimiser state, not a rule.
Call `opt_state = setup(rule, model)` first, then `update!(opt_state, model, grad)`.
""")
throw(ArgumentError("""update! must be called with an optimiser state tree, not a rule.
Call `opt_state = setup(rule, model)` first, then `update!(opt_state, model, grad)`.
"""))

To better match language used elsewhere in the docs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test needs to change to match, I think.

end

function update!(tree, model, grad, higher...)
# First walk is to accumulate the gradient. This recursion visits every copy of
# shared leaves, but stops when branches are absent from the gradient:
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ end
@test isnan(m3n.γ[3])
end

@testset "friendly error when using rule instead of state" begin
@test_throws ErrorException Optimisers.update!(Adam(), rand(2), rand(2))
end

@testset "Dict support" begin
@testset "simple dict" begin
d = Dict(:a => [1.0,2.0], :b => [3.0,4.0], :c => 1)
Expand Down
Loading