Skip to content

Commit cd817ec

Browse files
fix anonymous walk deprecation (#125)
* fix anonymous walk deprecation
* integrate review comments
1 parent 9c1ce24 commit cd817ec

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

src/Optimisers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module Optimisers
22

3-
using Functors: functor, fmap, isleaf, @functor, fmapstructure, children
3+
using Functors: functor, fmap, isleaf, @functor, fmapstructure, children, AbstractWalk
44
using LinearAlgebra
55

66
include("interface.jl")

src/destructure.jl

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function _flatten(x)
6666
isnumeric(x) && return vcat(_vec(x)), 0, length(x) # trivial case
6767
arrays = AbstractVector[]
6868
len = Ref(0)
69-
off = fmap(x; exclude = isnumeric, walk = (f, z) -> map(f, _trainable(z))) do y
69+
off = fmap(x; exclude = isnumeric, walk = _TrainableStructWalk()) do y
7070
push!(arrays, _vec(y))
7171
o = len[]
7272
len[] = o + length(y)
@@ -76,18 +76,22 @@ function _flatten(x)
7676
reduce(vcat, arrays), off, len[]
7777
end
7878

79+
struct _TrainableStructWalk <: AbstractWalk end
80+
81+
(::_TrainableStructWalk)(recurse, x) = map(recurse, _trainable(x))
82+
7983
_vec(x::Number) = LinRange(x,x,1)
8084
_vec(x::AbstractArray) = vec(x)
8185

8286
function ChainRulesCore.rrule(::typeof(_flatten), x)
8387
flat, off, len = _flatten(x)
8488
_maybewarn()
85-
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk, prune = NoT))
89+
_flatten_back((dflat, _, _)) = (NoT, _rebuild(x, off, unthunk(dflat), len; walk = _Tangent_biwalk(), prune = NoT))
8690
(flat, off, len), _flatten_back
8791
end
8892

8993
# This reconstructs either a model like x, or a gradient for it:
90-
function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _trainable_biwalk, kw...)
94+
function _rebuild(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...)
9195
len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))"))
9296
fmap(x, off; exclude = isnumeric, walk, kw...) do y, o
9397
_getat(y, o, flat)
@@ -98,7 +102,9 @@ _getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1])
98102
_getat(y::AbstractArray, o::Int, flat::AbstractVector) =
99103
ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes
100104

101-
function _trainable_biwalk(f, x, aux)
105+
struct _Trainable_biwalk <: AbstractWalk end
106+
107+
function (::_Trainable_biwalk)(f, x, aux)
102108
ch, re = functor(typeof(x), x)
103109
au, _ = functor(typeof(x), aux)
104110
_trainmap(f, ch, _trainable(x), au) |> re
@@ -110,7 +116,9 @@ function _trainmap(f, ch, tr, aux)
110116
end
111117
end
112118

113-
function _Tangent_biwalk(f, x, aux) # use with prune = NoT
119+
struct _Tangent_biwalk <: AbstractWalk end
120+
121+
function (::_Tangent_biwalk)(f, x, aux) # use with prune = NoT
114122
ch, re = functor(typeof(x), x)
115123
au, _ = functor(typeof(x), aux)
116124
y = _trainmap(f, ch, _trainable(x), au)
@@ -156,7 +164,7 @@ _grad!(x, dx::Zero, off::Integer, flat::AbstractVector) = flat # ambiguity
156164
# These are only needed for 2nd derivatives:
157165
function ChainRulesCore.rrule(::typeof(_grad!), x, dx, off, flat)
158166
@warn "second derivatives of Restructure may not work yet, sorry!" maxlog=3
159-
_grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk, prune = NoT), NoT, NoT)
167+
_grad_back(dflat) = (NoT, NoT, _rebuild(x, off, unthunk(dflat); walk = _Tangent_biwalk(), prune = NoT), NoT, NoT)
160168
_grad!(x, dx, off, flat), _grad_back
161169
end
162170
base(dx::Tangent{<:Tangent}) = backing(dx).backing # might be needed for gradient(gradient(destructure))

0 commit comments

Comments (0)