Skip to content

Commit c32de74

Browse files
committed
Add pooling
1 parent e6e98b4 commit c32de74

File tree

1 file changed

+131
-30
lines changed

1 file changed

+131
-30
lines changed

src/enzyme.jl

Lines changed: 131 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!))
55

66
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}
77

8-
@assert !(OutType <: EnzymeCore.Const)
98
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
109
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
1110
end
@@ -22,10 +21,16 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
2221
end
2322

2423
# Cache x if it's overwritten and w is active (and thus required)
25-
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3] && !(typeof(w) <: EnzymeCore.Const) ) ? copy(x.val) : nothing
24+
cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3]
25+
&& !(typeof(w) <: EnzymeCore.Const)
26+
&& !(typeof(y) <: EnzymeCore.Const)
27+
) ? copy(x.val) : nothing
2628

2729
# Cache w if it's overwritten and x is active (and thus required)
28-
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(x) <: EnzymeCore.Const) ) ? copy(w.val) : nothing
30+
cache_w = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
31+
&& !(typeof(x) <: EnzymeCore.Const)
32+
&& !(typeof(y) <: EnzymeCore.Const)
33+
) ? copy(w.val) : nothing
2934

3035
cache = (cache_x, cache_w)
3136

@@ -36,14 +41,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
3641
cache_x, cache_w = cache
3742

3843
# Don't cache x if not overwritten and w is active (and thus required)
39-
if !(typeof(w) <: EnzymeCore.Const)
44+
if !(typeof(w) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
4045
if !EnzymeCore.EnzymeRules.overwritten(config)[3]
4146
cache_x = x.val
4247
end
4348
end
4449

4550
# Don't cache w if not overwritten and x is active (and thus required)
46-
if !(typeof(x) <: EnzymeCore.Const)
51+
if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
4752
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
4853
cache_w = w.val
4954
end
@@ -60,15 +65,19 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
6065
end
6166

6267
# Propagate the output cotangent of every batch lane into the input and
# filter shadows, then clear the output cotangent.
for (dy, dx, dw) in zip(dys, dxs, dws)
    # Nothing to propagate when the output is constant, or when its shadow is
    # the very same array as its primal. The aliasing check must be made
    # against y.val (the output's primal), not against the filter w.val.
    if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

        if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
            # dx += grad wrt x.val
            NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
        end
        if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
            # dw += grad wrt w.val
            NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
        end

        # The output cotangent has been consumed; reset it.
        dy .= 0
    end
end
7382

7483
return (nothing, nothing, nothing, nothing)
@@ -79,7 +88,6 @@ end
7988

8089
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
8190

82-
@assert !(OutType <: EnzymeCore.Const)
8391
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
8492
func.val(dst.val, src.val, idx.val)
8593
end
@@ -96,15 +104,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
96104
end
97105

98106
# Cache idx if it's overwritten
99-
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing
107+
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
108+
&& !(typeof(src) <: EnzymeCore.Const)
109+
&& !(typeof(dst) <: EnzymeCore.Const)
110+
) ? copy(idx.val) : nothing
100111

101112
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
102113
end
103114

104115
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
105116

106117
# Don't cache idx if not overwritten
107-
if !(typeof(src) <: EnzymeCore.Const)
118+
if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
108119
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
109120
cache_idx = idx.val
110121
end
@@ -119,11 +130,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
119130
end
120131

121132
for (ddst, dsrc) in zip(ddsts, dsrcs)
122-
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val &&
123-
!(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
124-
NNlib.scatter!(+, dsrc, ddst, cache_idx)
125-
end
126133
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
134+
135+
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
136+
NNlib.scatter!(+, dsrc, ddst, cache_idx)
137+
end
138+
127139
ddst .= 0
128140
end
129141
end
@@ -152,15 +164,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
152164
end
153165

154166
# Cache idx if it's overwritten
155-
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing
167+
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
168+
&& !(typeof(src) <: EnzymeCore.Const)
169+
&& !(typeof(dst) <: EnzymeCore.Const)
170+
) ? copy(idx.val) : nothing
156171

157172
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
158173
end
159174

160175
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
161176

162177
# Don't cache idx if not overwritten
163-
if !(typeof(src) <: EnzymeCore.Const)
178+
if !(typeof(src) <: EnzymeCore.Const) && !(typeof(dst) <: EnzymeCore.Const)
164179
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
165180
cache_idx = idx.val
166181
end
@@ -175,15 +190,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
175190
end
176191

177192
for (ddst, dsrc) in zip(ddsts, dsrcs)
178-
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val &&
179-
!(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
180-
181-
if eltype(typeof(op)) == typeof(+)
182-
dsrc .+= NNlib.gather(ddst, cache_idx)
183-
else
184-
@assert eltype(typeof(op)) == typeof(-)
185-
dsrc .-= NNlib.gather(ddst, cache_idx)
193+
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
194+
195+
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
196+
197+
if eltype(typeof(op)) == typeof(+)
198+
dsrc .+= NNlib.gather(ddst, cache_idx)
199+
else
200+
@assert eltype(typeof(op)) == typeof(-)
201+
dsrc .-= NNlib.gather(ddst, cache_idx)
202+
end
186203
end
204+
205+
ddst .= 0
206+
187207
end
188208
end
189209

@@ -192,3 +212,84 @@ end
192212

193213

194214

215+
# Enzyme reverse-mode rules for the in-place pooling kernels. One
# `augmented_primal` / `reverse` pair is generated per pooling operation.
for pool in [:maxpool, :meanpool, :lpnormpool]
    pool! = Symbol(pool, :!)
    # The pullback accumulates into an existing shadow array, so it has to use
    # the in-place gradient kernel (e.g. `∇maxpool!`); the allocating variant
    # does not accept a destination argument.
    ∇pool! = Symbol(:∇, pool, :!)

    @eval begin

        # Forward sweep: run the primal pooling call and stash whatever the
        # reverse sweep will need (the output `y` and the input `x`) if Enzyme
        # reports that those buffers may be overwritten before the reverse pass.
        function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, y::OutType, x, dims; kwargs...) where {OutType, RT}

            if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
                func.val(y.val, x.val, dims.val; kwargs...)
            end

            primal = if EnzymeCore.EnzymeRules.needs_primal(config)
                y.val
            else
                nothing
            end
            shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
                y.dval
            else
                nothing
            end

            # Copies are only required when both x and y are non-constant,
            # i.e. when the reverse pass will actually compute a gradient.
            needs_grad = !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)

            # Cache y if it's overwritten
            cache_y = ( EnzymeCore.EnzymeRules.overwritten(config)[2]
                        && needs_grad
                      ) ? copy(y.val) : nothing

            # Cache x if it's overwritten
            cache_x = ( EnzymeCore.EnzymeRules.overwritten(config)[3]
                        && needs_grad
                      ) ? copy(x.val) : nothing

            cache = (cache_y, cache_x)

            return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
        end

        # Reverse sweep: accumulate the pooling pullback into the shadow of
        # `x`, then zero the shadow of `y`.
        function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof($pool!)}, ::Type{RT}, cache, y, x, dims; kwargs...) where {RT}
            cache_y, cache_x = cache

            # Buffers that were not overwritten were not copied in the forward
            # sweep; read them directly from the arguments instead.
            if !(typeof(x) <: EnzymeCore.Const) && !(typeof(y) <: EnzymeCore.Const)
                if !EnzymeCore.EnzymeRules.overwritten(config)[2]
                    cache_y = y.val
                end
                if !EnzymeCore.EnzymeRules.overwritten(config)[3]
                    cache_x = x.val
                end
            end

            dys = y.dval
            # When x is constant there is no x shadow; reuse dys so that the
            # zip below still has something to iterate (the value is unused).
            dxs = (typeof(x) <: EnzymeCore.Const) ? dys : x.dval

            # Normalise the width-1 case to the tuple-of-shadows layout used
            # for batches.
            if EnzymeCore.EnzymeRules.width(config) == 1
                dys = (dys,)
                dxs = (dxs,)
            end

            # zip over two collections yields pairs, so destructure two names.
            for (dy, dx) in zip(dys, dxs)
                if !(typeof(y) <: EnzymeCore.Const) && dy !== y.val

                    if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
                        # dx += grad wrt x.val; `dims` is an Enzyme wrapper, so
                        # pass the wrapped pooling dimensions via `dims.val`.
                        NNlib.$(∇pool!)(dx, dy, cache_y, cache_x, dims.val; alpha=eltype(dx)(1), beta=eltype(dx)(1), kwargs...)
                    end

                    # The output cotangent has been consumed; reset it.
                    dy .= 0
                end
            end

            return (nothing, nothing, nothing)
        end

    end
end
294+
295+

0 commit comments

Comments
 (0)