@@ -5,7 +5,6 @@ for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!))
55
66function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{$name} , :: Type{RT} , y:: OutType , x, w, cdims; kwargs... ) where {OutType, RT}
77
8- @assert ! (OutType <: EnzymeCore.Const )
98 if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
109 func. val (y. val, x. val, w. val, cdims. val; kwargs... )
1110 end
@@ -22,10 +21,16 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
2221 end
2322
2423 # Cache x if its overwritten and w is active (and thus required)
25- cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ] && ! (typeof (w) <: EnzymeCore.Const ) ) ? copy (x. val) : nothing
24+ cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ]
25+ && ! (typeof (w) <: EnzymeCore.Const )
26+ && ! (typeof (y) <: EnzymeCore.Const )
27+ ) ? copy (x. val) : nothing
2628
2729 # Cache w if its overwritten and x is active (and thus required)
28- cache_w = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (x) <: EnzymeCore.Const ) ) ? copy (w. val) : nothing
30+ cache_w = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ]
31+ && ! (typeof (x) <: EnzymeCore.Const )
32+ && ! (typeof (y) <: EnzymeCore.Const )
33+ ) ? copy (w. val) : nothing
2934
3035 cache = (cache_x, cache_w)
3136
@@ -36,14 +41,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
3641 cache_x, cache_w = cache
3742
3843 # Don't cache x if not overwritten and w is active (and thus required)
39- if ! (typeof (w) <: EnzymeCore.Const )
44+ if ! (typeof (w) <: EnzymeCore.Const ) && ! ( typeof (y) <: EnzymeCore.Const )
4045 if ! EnzymeCore. EnzymeRules. overwritten (config)[3 ]
4146 cache_x = x. val
4247 end
4348 end
4449
4550 # Don't cache w if not overwritten and x is active (and thus required)
46- if ! (typeof (x) <: EnzymeCore.Const )
51+ if ! (typeof (x) <: EnzymeCore.Const ) && ! ( typeof (y) <: EnzymeCore.Const )
4752 if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
4853 cache_w = w. val
4954 end
@@ -60,15 +65,19 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
6065 end
6166
6267 for (dy, dx, dw) in zip (dys, dxs, dws)
63- if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x. val
64- # dx += grad wrt x.val
65- NNlib.∇conv_data! (dx, dy, cache_w, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
66- end
67- if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w. val
68- # dw += grad wrt w.val
69- NNlib.∇conv_filter! (dw, cache_x, dy, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
68+ if ! (typeof (y) <: EnzymeCore.Const ) && dy != = w. val
69+
70+ if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x. val
71+ # dx += grad wrt x.val
72+ NNlib.∇conv_data! (dx, dy, cache_w, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
73+ end
74+ if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w. val
75+ # dw += grad wrt w.val
76+ NNlib.∇conv_filter! (dw, cache_x, dy, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
77+ end
78+
79+ dy .= 0
7080 end
71- dy .= 0
7281 end
7382
7483 return (nothing , nothing , nothing , nothing )
7988
8089function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{typeof(NNlib.gather!)} , :: Type{RT} , dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
8190
82- @assert ! (OutType <: EnzymeCore.Const )
8391 if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
8492 func. val (dst. val, src. val, idx. val)
8593 end
@@ -96,15 +104,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
96104 end
97105
98106 # Cache idx if its overwritten
99- cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (src) <: EnzymeCore.Const ) ) ? copy (idx. val) : nothing
107+ cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ]
108+ && ! (typeof (src) <: EnzymeCore.Const )
109+ && ! (typeof (dst) <: EnzymeCore.Const )
110+ ) ? copy (idx. val) : nothing
100111
101112 return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache_idx)
102113end
103114
104115function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof(NNlib.gather!)} , :: Type{RT} , cache_idx, dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
105116
106117 # Don't cache idx if not overwritten
107- if ! (typeof (src) <: EnzymeCore.Const )
118+ if ! (typeof (src) <: EnzymeCore.Const ) && ! ( typeof (dst) <: EnzymeCore.Const )
108119 if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
109120 cache_idx = idx. val
110121 end
@@ -119,11 +130,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
119130 end
120131
121132 for (ddst, dsrc) in zip (ddsts, dsrcs)
122- if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val &&
123- ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
124- NNlib. scatter! (+ , dsrc, ddst, cache_idx)
125- end
126133 if ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
134+
135+ if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val
136+ NNlib. scatter! (+ , dsrc, ddst, cache_idx)
137+ end
138+
127139 ddst .= 0
128140 end
129141 end
@@ -152,15 +164,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
152164 end
153165
154166 # Cache idx if its overwritten
155- cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (src) <: EnzymeCore.Const ) ) ? copy (idx. val) : nothing
167+ cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ]
168+ && ! (typeof (src) <: EnzymeCore.Const )
169+ && ! (typeof (dst) <: EnzymeCore.Const )
170+ ) ? copy (idx. val) : nothing
156171
157172 return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache_idx)
158173end
159174
160175function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof(NNlib.scatter!)} , :: Type{RT} , cache_idx, op:: Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}} , dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
161176
162177 # Don't cache idx if not overwritten
163- if ! (typeof (src) <: EnzymeCore.Const )
178+ if ! (typeof (src) <: EnzymeCore.Const ) && ! ( typeof (dst) <: EnzymeCore.Const )
164179 if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
165180 cache_idx = idx. val
166181 end
@@ -175,15 +190,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
175190 end
176191
177192 for (ddst, dsrc) in zip (ddsts, dsrcs)
178- if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val &&
179- ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
180-
181- if eltype (typeof (op)) == typeof (+ )
182- dsrc .+ = NNlib. gather (ddst, cache_idx)
183- else
184- @assert eltype (typeof (op)) == typeof (- )
185- dsrc .- = NNlib. gather (ddst, cache_idx)
193+ if ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
194+
195+ if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val
196+
197+ if eltype (typeof (op)) == typeof (+ )
198+ dsrc .+ = NNlib. gather (ddst, cache_idx)
199+ else
200+ @assert eltype (typeof (op)) == typeof (- )
201+ dsrc .- = NNlib. gather (ddst, cache_idx)
202+ end
186203 end
204+
205+ ddst .= 0
206+
187207 end
188208 end
189209
192212
193213
194214
215+ for pool in [:maxpool , :meanpool , :lpnormpool ]
216+ pool! = Symbol (pool, :! )
217+ ∇pool = Symbol (:∇ , pool)
218+
219+ @eval begin
220+
221+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{typeof($pool!)} , :: Type{RT} , y:: OutType , x, dims; kwargs... ) where {OutType, RT}
222+
223+ if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
224+ func. val (y. val, x. val, dims. val; kwargs... )
225+ end
226+
227+ primal = if EnzymeCore. EnzymeRules. needs_primal (config)
228+ y. val
229+ else
230+ nothing
231+ end
232+ shadow = if EnzymeCore. EnzymeRules. needs_shadow (config)
233+ y. dval
234+ else
235+ nothing
236+ end
237+
238+ cache_y = ( EnzymeCore. EnzymeRules. overwritten (config)[2 ]
239+ && ! (typeof (x) <: EnzymeCore.Const )
240+ && ! (typeof (y) <: EnzymeCore.Const )
241+ ) ? copy (y. val) : nothing
242+
243+ cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ]
244+ && ! (typeof (x) <: EnzymeCore.Const )
245+ && ! (typeof (y) <: EnzymeCore.Const )
246+ ) ? copy (x. val) : nothing
247+
248+ cache = (cache_y, cache_x)
249+
250+ return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache)
251+ end
252+
253+ function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof($pool!)} , :: Type{RT} , cache, y, x, dims; kwargs... ) where {RT}
254+ cache_y, cache_x = cache
255+
256+ # Don't cache y if not overwritten
257+ if ! (typeof (x) <: EnzymeCore.Const ) && ! (typeof (y) <: EnzymeCore.Const )
258+ if ! EnzymeCore. EnzymeRules. overwritten (config)[2 ]
259+ cache_y = y. val
260+ end
261+ end
262+
263+ # Don't cache x if not overwritten
264+ if ! (typeof (x) <: EnzymeCore.Const ) && ! (typeof (y) <: EnzymeCore.Const )
265+ if ! EnzymeCore. EnzymeRules. overwritten (config)[3 ]
266+ cache_x = x. val
267+ end
268+ end
269+
270+ dys = y. dval
271+ dxs = (typeof (x) <: EnzymeCore.Const ) ? dys : x. dval
272+
273+ if EnzymeCore. EnzymeRules. width (config) == 1
274+ dys = (dys,)
275+ dxs = (dxs,)
276+ end
277+
278+ for (dy, dx, dw) in zip (dys, dxs)
279+ if ! (typeof (y) <: EnzymeCore.Const ) && dy != = y. val
280+
281+ if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x. val
282+ NNlib.$ (∇pool)(dx, dy, cache_y, cache_x, dims; alpha= eltype (dx)(1 ), beta= eltype (dx)(1 ), kwargs... )
283+ end
284+
285+ dy .= 0
286+ end
287+ end
288+
289+ return (nothing , nothing , nothing )
290+ end
291+
292+ end
293+ end
294+
295+
0 commit comments