Fix #62 (#70)
Changes from 5 commits
```diff
@@ -53,16 +53,20 @@ end
 Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")")
 Base.length(re::Restructure) = re.length
 
+struct Offset
+  i::Int
+end
+
 # This flattens a model, and returns a web of offsets for later use:
 function _flatten(x)
-  isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
+  isnumeric(x) && return vcat(_vec(x)), Offset(0), length(x) # trivial case
   arrays = AbstractVector[]
   len = Ref(0)
   off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
     push!(arrays, _vec(y))
     o = len[]
     len[] = o + length(y)
-    o
+    Offset(o)
   end
   reduce(vcat, arrays), off, len[]
 end
```
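For intuition, here is a self-contained toy version of what the new `_flatten` returns (a sketch only: `toy_flatten` is a made-up name, and the real function walks arbitrarily nested models via `fmap` rather than a single NamedTuple):

```julia
# Toy sketch of the offset-tree idea: walk a NamedTuple of arrays,
# concatenate the leaves, and record where each leaf starts in the
# flat vector, wrapping the raw integer as the PR's Offset struct does.
struct Offset
    i::Int
end

function toy_flatten(x::NamedTuple)
    arrays = AbstractVector[]
    len = Ref(0)
    off = map(x) do y
        push!(arrays, vec(y))
        o = len[]
        len[] = o + length(y)
        Offset(o)   # wrapped, so an offset leaf can't be mistaken for data
    end
    reduce(vcat, arrays), off, len[]
end

model = (w = [1.0 2.0; 3.0 4.0], b = [5.0, 6.0])
flat, off, n = toy_flatten(model)
# flat == [1.0, 3.0, 2.0, 4.0, 5.0, 6.0]  (column-major vec of w, then b)
# off  == (w = Offset(0), b = Offset(4)), and n == 6
```

The point of the wrapper is dispatch: a bare `Int` leaf in the offset tree is indistinguishable from an integer leaf of the model itself, whereas the methods below can now dispatch on `::Offset` instead of `::Integer`.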
```diff
@@ -85,16 +89,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trai
   end
 end
 
-_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
-_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
-  ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
+_getat(y::Number, off::Offset, flat::AbstractVector) = ProjectTo(y)(flat[off.i + 1])
+_getat(y::AbstractArray, off::Offset, flat::AbstractVector) =
+  ProjectTo(y)(reshape(flat[off.i .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
 
 function _trainable_biwalk(f, x, aux)
   ch, re = functor(typeof(x), x)
-  au, _ = functor(typeof(x), aux)
+  au = _aux_children(aux)
   _trainmap(f, ch, _trainable(x), au) |> re
 end
 
+_aux_children(off) = functor(off)[1]
+
 function _trainmap(f, ch, tr, aux)
   map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c)
     isnothing(t) ? c : f(t, a)
```
```diff
@@ -103,13 +109,14 @@ end
 
 function _Tangent_biwalk(f, x, aux) # use with prune = NoT
   ch, re = functor(typeof(x), x)
-  au, _ = functor(typeof(x), aux)
+  au = _aux_children(aux)
   y = _trainmap(f, ch, _trainable(x), au)
   y isa Tuple{} && return NoT
   p = ProjectTo(x)
   if p isa ProjectTo # e.g. Array, NamedTuple
     p(y)
   else # p === identity for unknown structs
+    y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields
```
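The review thread below hangs on the `y = backing(re(y))` line. For context, `backing` (assumed here to be `ChainRulesCore.backing`) turns a struct or `Tangent` into the `NamedTuple` of its fields, e.g.:

```julia
using ChainRulesCore

struct Point   # hypothetical example type
    x
    y
end

ChainRulesCore.backing(Point(1.0, 2.0))  # (x = 1.0, y = 2.0)
```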
**Reviewer:** Note to self, this I need to think about. Some of this complication was working around things that are now fixed in CRC.jl, if I remember right.

**Author:** Yeah, admittedly this line took some trial and error and is a little bit above my pay-grade. I managed to convince myself, but perhaps there's something cleaner.

**Reviewer:** Ok, I think I finally understand what's going on. Sorry it took a while. The only way I can see this failing is this: if the primal type's constructor is fussy about what types it can accept, then it may not be happy to accept something which is valid as its gradient. E.g. if there is only …

**Author:** No worries! Yes, I struggled with that edge case too. Unfortunately I think it's quite tricky to work around. For example, suppose you have a user-defined … Avoiding …

**Reviewer:** In fact there's another problem I didn't spot before, what a mess:

```julia
julia> ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]);  # from tests: a, c are functor-ed, and only a is trainable

julia> v2, re2 = destructure(ac)
([1.0, 2.0], Restructure(TwoThirds, ..., 2))

julia> gradient(ac) do x  # with Tangent{typeof(x), typeof(y)}(y)
         w2, _ = destructure(x)
         w2[2]^2
       end
((a = [0.0, 4.0], b = nothing, c = [4.0, 5.0]),)

# Same, with z = backing(re(y)):

julia> gradient(ac) do x
         w2, _ = destructure(x)
         w2[2]^2
       end
┌ Info: last case
│   x = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
│   y = (a = [0.0, 4.0], c = [4.0, 5.0])
└   z = NamedTuple{(:a, :b, :c), Tuple{Any, Any, Any}}(([0.0, 4.0], [3.0], [4.0, 5.0]))
((a = [0.0, 4.0], b = [3.0], c = [4.0, 5.0]),)
```

**Author:** Oh yikes. That's a good example, hits all the pain points at once. If I'm understanding correctly, the gradient should be … I think the problem is the …

**Reviewer:** Yes, I think … But …

**Author:** Gotcha. So on top of the modified … EDIT: But, based on the output of … EDIT 2: Never mind, that would still fail for children which aren't fields, like …

**Author:** Alright, pushed something that works for both.
```diff
     Tangent{typeof(x), typeof(y)}(y)
   end
 end
```
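The `TwoThirds` type in the session above comes from the package tests. A sketch of how such a type is declared (the exact test definition may differ; assuming the usual Functors.jl and Optimisers.jl overload APIs), with all three fields stored, `a` and `c` functor-ed, and only `a` trainable:

```julia
using Functors, Optimisers

struct TwoThirds
    a
    b
    c
end
Functors.@functor TwoThirds (a, c)               # only a and c are functor children
Optimisers.trainable(x::TwoThirds) = (a = x.a,)  # only a is trainable

ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
v2, re2 = destructure(ac)
# v2 == [1.0, 2.0]: only the trainable field ends up in the flat vector,
# which is why Restructure reports length 2 in the session above.
```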
```diff
@@ -126,23 +133,23 @@ ChainRulesCore.@non_differentiable _zero(x)
 function _grad!(x, dx, off, flat::AbstractVector)
   x′, _ = functor(typeof(x), x)
   dx′, _ = functor(typeof(x), base(dx))
-  off′, _ = functor(typeof(x), off)
+  off′ = _aux_children(off)
   for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′)
     flat = _grad!(xᵢ, dxᵢ, oᵢ, flat)
   end
   flat
 end
-function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T
+function _grad!(x, dx, off::Offset, flat::AbstractVector{T}) where T
   dx_un = unthunk(dx)
   T2 = promote_type(T, eltype(dx_un))
   if T != T2 # then we must widen the type
     flat = copyto!(similar(flat, T2), flat)
   end
-  @views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
+  @views flat[off.i .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes
   flat
 end
 _grad!(x, dx::Zero, off, flat::AbstractVector) = flat
-_grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity
+_grad!(x, dx::Zero, off::Offset, flat::AbstractVector) = flat # ambiguity
 
 # These are only needed for 2nd derivatives:
 function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
```
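The eltype-widening in `_grad!` above handles gradients whose eltype is wider than the flat vector's (say, a `Float64` gradient landing in a `Float32` buffer). A toy standalone version of the same pattern (hypothetical name, not the package's method):

```julia
# Toy version of the widening accumulation: if the incoming gradient's
# eltype doesn't fit in `flat`, copy into a wider vector first, then
# accumulate the slice in place (.+= rather than =, since tied parameters
# can write to the same region more than once).
function accumulate_at!(flat::AbstractVector{T}, dx::AbstractArray, off::Int) where T
    T2 = promote_type(T, eltype(dx))
    if T != T2                # must widen, e.g. Float32 -> Float64
        flat = copyto!(similar(flat, T2), flat)
    end
    @views flat[off .+ (1:length(dx))] .+= vec(dx)
    flat
end

flat = accumulate_at!(zeros(Float32, 4), [1.0, 2.0], 1)
# eltype(flat) == Float64; flat == [0.0, 1.0, 2.0, 0.0]
```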