-
-
Notifications
You must be signed in to change notification settings - Fork 25
Fix #62 #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Fix #62 #70
Conversation
…`s of offsets (also simplifying `_aux_children`); fix broken test for issue FluxML#62
src/destructure.jl
Outdated
| if p isa ProjectTo # e.g. Array, NamedTuple | ||
| p(y) | ||
| else # p === identity for unknown structs | ||
| y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note to self, this I need to think about. Some of this complication was working around things that are now fixed in CRC.jl, if I remember right.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, admittedly this line took some trial and error and is a little bit above my pay-grade. I managed to convince myself, but perhaps there's something cleaner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I think I finally understand what's going on. Sorry it took a while.
re constructs another Skip containing the gradient, and backing turns that into a NamedTuple with the same field names, which is what Tangent wants.
The only way I can see this failing is this: If the primal type's constructor is fussy about what types it can accept, then it may not be happy to accept something which is valid as its gradient. E.g. if there is only Skip(::AbstractLayer), and re tries to make one with a Tangent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No worries! Yes, I struggled with that edge case too. Unfortunately I think it's quite tricky to work around. For example, suppose you have a user-defined functor(m::MyModel) = (m.w,), w -> .... Then:
- In general there's no way to reconstruct
MyModel(or even aNamedTupleof fields/values) withoutre, as you do not know the corresponding field name given only(m.w,), but - As you say, if the primal constructor isn't sufficiently generic then it won't be able to store
Tangent/Nothing/etc. values in it's fields and will error beforebackingcan unpack it again
Avoiding re would be ideal, but I think that would require functor to always return NamedTuples on custom structs. I noticed that this is the default in @functor, though, so maybe it's not such a painful requirement? In the mean time I can at least add a branch that would avoid re for structs that are functored to NamedTuples.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In fact there's another problem I didn't spot before, what a mess:
julia> ac = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0]); # from tests: a,c are functor-ed, and only a is trainable
julia> v2, re2 = destructure(ac)
([1.0, 2.0], Restructure(TwoThirds, ..., 2))
julia> gradient(ac) do x # with Tangent{typeof(x), typeof(y)}(y)
w2, _ = destructure(x)
w2[2]^2
end
((a = [0.0, 4.0], b = nothing, c = [4.0, 5.0]),)
# Same, with z = backing(re(y)) :
julia> gradient(ac) do x
w2, _ = destructure(x)
w2[2]^2
end
┌ Info: last case
│ x = TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])
│ y = (a = [0.0, 4.0], c = [4.0, 5.0])
└ z = NamedTuple{(:a, :b, :c), Tuple{Any, Any, Any}}(([0.0, 4.0], [3.0], [4.0, 5.0]))
((a = [0.0, 4.0], b = [3.0], c = [4.0, 5.0]),)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yikes. That's a good example, hits all the pain points at once. If I'm understanding correctly, the gradient should be ((a = [0.0, 4.0], b = nothing, c = nothing),), right?
I think the problem is the _trainmap above; it populates the nothing values from _trainable (non-trainable fields) with the primal values, when they should be NoT. That's how the b and/or c values get back in there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I think _trainmap needs to do something isnothing(t) ? NoT : f(t, a) here. That's where c = [4.0, 5.0] is coming from.
But b = [3.0] is coming from this PR's trick of calling the reconstructor made by @functor:
julia> ch, re = Functors.functor(ac)
((a = [1.0, 2.0], c = [4.0, 5.0]), var"#1#2"{TwoThirds}(TwoThirds([1.0, 2.0], [3.0], [4.0, 5.0])))
julia> re((a = [10, 20], c = nothing))
TwoThirds([10, 20], [3.0], nothing)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha. So on top of the modified _trainmap to fix c, one would still have to filter backing(re(y)) to replace repopulated primal values which aren't functor-ed with NoT in order to fix b.
EDIT: But, based on the output of Tangent{typeof(x), typeof(y)}(y), maybe the modified _trainmap alone would be enough and backing(re(y)) isn't needed after all, as Tangent will assign NoT to omitted fields in y automatically.
EDIT 2: Never mind, that would still fail for children which aren't fields, like Skip.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alright pushed something that works for both Skip and your TwoThirds example (modified _trainmap + filtering backing(re(y))). But since it uses re it would still fail for fussy constructors.
…h are not `trainable`; filter primal values from `backing(re(y))`
This adds a couple small changes on top of this draft PR in order to fix #62:
Offsetto fix the issue mentioned in Attempt to fix #62 #63 for array of arrays. For example, the offset structure forx = [[1.0, 2.0]]is now something likeo = [Offset(4)]which is not leaflike, compared too = [4]previously. This also opens the door to storing more information in this wrapper struct (original array size? eltype?), but that doesn't seem necessary at this timey = backing(re(y))allows forfunctor(x)to return children which aren't its own fields:yis first restructured to match the structure ofx, and then theNamedTuplebacking forre(y)is extracted and passed toTangent. It has the added benefit of adding some symmetry with_trainable_biwalkwhich naturally restructures the output of_trainmap, whereas_Tangent_biwalkpreviously did notCloses #63 (replaces).