-
-
Notifications
You must be signed in to change notification settings - Fork 36
RFC: add Functors-aware structural gradient #129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -178,3 +178,60 @@ function jacobian(f, x::AbstractVector) | |
| end | ||
|
|
||
| hessian(f, x) = jacobian(x -> gradient(f, x, nest=true)[1], x) | ||
|
|
||
| using Functors: fmap, fmapstructure | ||
| using Optimisers: _trainable, isnumeric | ||
|
|
||
| """ | ||
| withgradient(f, xs...) | ||
|
|
||
| This computes the value `f(xs...)` and the gradient with respect to `xs`. | ||
| However, it differs from `gradient` in several other respects: | ||
| * It will recurse into `xs` using `fmap`, and thus like Zygote's "explicit mode" it | ||
| returns a tree-like gradient matching the shape of a Flux model. | ||
| * Only objects satisfying `Optimisers.isnumeric` are regarded as parameters, | ||
| thus in particular integers are ignored. | ||
| * Returns plain arrays, not tracked. | ||
|
|
||
| # Examples | ||
| ``` | ||
| julia> nt = (vec = [1.0, 2.0], mat = [4.0;;], fun = sin); | ||
|
|
||
| julia> withgradient(nt, 2) do x, p | ||
| sum(abs2, x.vec) ^ p | ||
| end | ||
| (val = 25.0, grad = ((vec = [20.0, 40.0], mat = [0.0;;], fun = nothing), nothing)) | ||
|
|
||
| julia> using Flux | ||
|
|
||
| julia> model = Chain(Dense(2 => 1, tanh), Dense(1 => 1, bias=false)); | ||
|
|
||
| julia> withgradient(model, rand(Float32, 2)) do m, x | ||
| sum(abs2, m(x)) | ||
| end | ||
| (val = 0.035716165f0, grad = ((layers = ((weight = Float32[-0.4241869 -0.16741231], bias = Float32[-0.5529184], σ = nothing), (weight = Float32[-0.04804218;;], bias = nothing, σ = nothing)),), Float32[0.12706584, -0.08858479])) | ||
| ``` | ||
| """ | ||
| function withgradient(f, xs...) | ||
| pxs = fmap(param, xs; exclude = isnumeric) # would ideally apply params only to trainable | ||
|
||
| l = f(pxs...) | ||
| losscheck(l) | ||
| l isa TrackedReal || return (val = l, grad = nothing) | ||
| @interrupts back!(l) | ||
| (val = data(l), grad = rec_grad(pxs)) | ||
| end | ||
|
|
||
| # Easier to write the recursion to extract the gradients without using fmap: | ||
| rec_grad(x::TrackedArray) = grad(x) | ||
| rec_grad(x::TrackedReal) = grad(x) | ||
| rec_grad(x::AbstractArray{<:Number}) = nothing | ||
| rec_grad(x::Number) = nothing | ||
|
|
||
| rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x) | ||
| rec_grad(::Tuple{}) = nothing | ||
| rec_grad(::NamedTuple{(), Tuple{}}) = nothing | ||
| function rec_grad(x::T) where {T} | ||
| F = fieldnames(T) | ||
| isempty(F) && return nothing | ||
| map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F)) | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW [email protected] required Metalhead#master right now. With that, the example from https://fluxml.ai/Optimisers.jl/dev/#Usage-with-[Flux.jl](https://github.com/FluxML/Flux.jl) runs, and has half the TTFG of Zygote:
compared to, for Zygote, this:
But something is wrong, as the final loss differs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking on the bright side, I guess with this it would be fairly easy to add checks to Flux's tests, comparing what Zygote thinks about each layer to what Tracker thinks. Any which disagree are cause for concern.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, and I can already see that being helpful for Metalhead since we see the occasional odd gradient anomaly.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Diffractor, with JuliaDiff/Diffractor.jl#89