Skip to content
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/lib/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,46 @@ function Zygote._pullback(ctx::Zygote.AContext, ::typeof(checkpointed), f, xs...
return y, pullback_checkpointed
end



"""

eager_update!(state, model, update!)

Eagerly updates the model parameters, discarding the updated gradients to save memory.
`model` stores the parameters to be updated, `state` is the optimization state (eg. from Optimisers.jl) matching your model component, and
`update!` is the function that updates the parameters (eg. from `Optimisers.jl`), usually called as `update!(state, model, grads)`.

If `f` is a function that takes a single layer, called as `h = f(model.layers[i], h, other_args...)` then we can eagerly update with:

```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```
```julia

h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
```

or combine this with gradient checkpointing (for additional memory saving at the cost of increased execution time) with:

```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```
```julia

h = Zygote.checkpointed(f, eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
```

If `model.layers[i]` itself is callable, we can use the above by first wrapping it:

```
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```
```julia

f(model, xs...) = model(xs...)
h = f(Zygote.eager_update!(state.layers[i], model.layers[i], Optimisers.update!), h, other_args...)
```

!!! warning
If different layers share trainable parameters, then `eager_update!` will likely give wrong results.
"""
function eager_update!(state, model, update!)
function update_hook(dmodel)
update!(state, model, dmodel)
return nothing
end
return Zygote.hook(update_hook, model)
end

"""
hessian(f, x)

Expand Down
Loading