-
-
Notifications
You must be signed in to change notification settings - Fork 25
Fix #62 #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Fix #62 #70
Changes from 5 commits
57a3fa2
81e41fe
aa97f7f
6a0ba9f
927e095
ba909d2
abf8738
415b597
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -53,16 +53,20 @@ end | |
| Base.show(io::IO, re::Restructure{T}) where T = print(io, "Restructure(", T.name.name, ", ..., ", re.length, ")") | ||
| Base.length(re::Restructure) = re.length | ||
|
|
||
| struct Offset | ||
| i::Int | ||
| end | ||
|
|
||
| # This flattens a model, and returns a web of offsets for later use: | ||
| function _flatten(x) | ||
| isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case | ||
| isnumeric(x) && return vcat(_vec(x)), Offset(0), length(x) # trivial case | ||
| arrays = AbstractVector[] | ||
| len = Ref(0) | ||
| off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y | ||
| push!(arrays, _vec(y)) | ||
| o = len[] | ||
| len[] = o + length(y) | ||
| o | ||
| Offset(o) | ||
| end | ||
| reduce(vcat, arrays), off, len[] | ||
| end | ||
|
|
@@ -85,16 +89,18 @@ function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trai | |
| end | ||
| end | ||
|
|
||
| _getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1]) | ||
| _getat(y::AbstractArray, o::Int, flat::AbstractVector) = | ||
| ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes | ||
| _getat(y::Number, off::Offset, flat::AbstractVector) = ProjectTo(y)(flat[off.i + 1]) | ||
| _getat(y::AbstractArray, off::Offset, flat::AbstractVector) = | ||
| ProjectTo(y)(reshape(flat[off.i .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes | ||
|
|
||
| function _trainable_biwalk(f, x, aux) | ||
| ch, re = functor(typeof(x), x) | ||
| au, _ = functor(typeof(x), aux) | ||
| au = _aux_children(aux) | ||
| _trainmap(f, ch, _trainable(x), au) |> re | ||
| end | ||
|
|
||
| _aux_children(off) = functor(off)[1] | ||
|
|
||
| function _trainmap(f, ch, tr, aux) | ||
| map(ch, tr, aux) do c, t, a # isnothing(t) indicates non-trainable field, safe given isnumeric(c) | ||
| isnothing(t) ? c : f(t, a) | ||
|
|
@@ -103,13 +109,14 @@ end | |
|
|
||
| function _Tangent_biwalk(f, x, aux) # use with prune = NoT | ||
| ch, re = functor(typeof(x), x) | ||
| au, _ = functor(typeof(x), aux) | ||
| au = _aux_children(aux) | ||
| y = _trainmap(f, ch, _trainable(x), au) | ||
| y isa Tuple{} && return NoT | ||
| p = ProjectTo(x) | ||
| if p isa ProjectTo # e.g. Array, NamedTuple | ||
| p(y) | ||
| else # p === identity for unknown structs | ||
| y = backing(re(y)) # extract NamedTuple backing from re(y); required if x has children which aren't its own fields | ||
|
||
| Tangent{typeof(x), typeof(y)}(y) | ||
| end | ||
| end | ||
|
|
@@ -126,23 +133,23 @@ ChainRulesCore.@non_differentiable _zero(x) | |
| function _grad!(x, dx, off, flat::AbstractVector) | ||
| x′, _ = functor(typeof(x), x) | ||
| dx′, _ = functor(typeof(x), base(dx)) | ||
| off′, _ = functor(typeof(x), off) | ||
| off′ = _aux_children(off) | ||
| for (xᵢ, dxᵢ, oᵢ) in zip(x′, dx′, off′) | ||
| flat = _grad!(xᵢ, dxᵢ, oᵢ, flat) | ||
| end | ||
| flat | ||
| end | ||
| function _grad!(x, dx, off::Integer, flat::AbstractVector{T}) where T | ||
| function _grad!(x, dx, off::Offset, flat::AbstractVector{T}) where T | ||
| dx_un = unthunk(dx) | ||
| T2 = promote_type(T, eltype(dx_un)) | ||
| if T != T2 # then we must widen the type | ||
| flat = copyto!(similar(flat, T2), flat) | ||
| end | ||
| @views flat[off .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes | ||
| @views flat[off.i .+ (1:length(x))] .+= vec(dx_un) # must visit all tied nodes | ||
| flat | ||
| end | ||
| _grad!(x, dx::Zero, off, flat::AbstractVector) = flat | ||
| _grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity | ||
| _grad!(x, dx::Zero, off::Offset, flat::AbstractVector) = flat # ambiguity | ||
|
|
||
| # These are only needed for 2nd derivatives: | ||
| function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.