@@ -5,7 +5,6 @@ for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!))
5
5
6
6
function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{$name} , :: Type{RT} , y:: OutType , x, w, cdims; kwargs... ) where {OutType, RT}
7
7
8
- @assert ! (OutType <: EnzymeCore.Const )
9
8
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
10
9
func. val (y. val, x. val, w. val, cdims. val; kwargs... )
11
10
end
@@ -22,10 +21,16 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
22
21
end
23
22
24
23
# Cache x if its overwritten and w is active (and thus required)
25
- cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ] && ! (typeof (w) <: EnzymeCore.Const ) ) ? copy (x. val) : nothing
24
+ cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ]
25
+ && ! (typeof (w) <: EnzymeCore.Const )
26
+ && ! (typeof (y) <: EnzymeCore.Const )
27
+ ) ? copy (x. val) : nothing
26
28
27
29
# Cache w if its overwritten and x is active (and thus required)
28
- cache_w = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (x) <: EnzymeCore.Const ) ) ? copy (w. val) : nothing
30
+ cache_w = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ]
31
+ && ! (typeof (x) <: EnzymeCore.Const )
32
+ && ! (typeof (y) <: EnzymeCore.Const )
33
+ ) ? copy (w. val) : nothing
29
34
30
35
cache = (cache_x, cache_w)
31
36
@@ -36,14 +41,14 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
36
41
cache_x, cache_w = cache
37
42
38
43
# Don't cache x if not overwritten and w is active (and thus required)
39
- if ! (typeof (w) <: EnzymeCore.Const )
44
+ if ! (typeof (w) <: EnzymeCore.Const ) && ! ( typeof (y) <: EnzymeCore.Const )
40
45
if ! EnzymeCore. EnzymeRules. overwritten (config)[3 ]
41
46
cache_x = x. val
42
47
end
43
48
end
44
49
45
50
# Don't cache w if not overwritten and x is active (and thus required)
46
- if ! (typeof (x) <: EnzymeCore.Const )
51
+ if ! (typeof (x) <: EnzymeCore.Const ) && ! ( typeof (y) <: EnzymeCore.Const )
47
52
if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
48
53
cache_w = w. val
49
54
end
@@ -60,15 +65,19 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, :
60
65
end
61
66
62
67
for (dy, dx, dw) in zip (dys, dxs, dws)
63
- if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x. val
64
- # dx += grad wrt x.val
65
- NNlib.∇conv_data! (dx, dy, cache_w, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
66
- end
67
- if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w. val
68
- # dw += grad wrt w.val
69
- NNlib.∇conv_filter! (dw, cache_x, dy, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
68
+ if ! (typeof (y) <: EnzymeCore.Const ) && dy != = w. val
69
+
70
+ if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x. val
71
+ # dx += grad wrt x.val
72
+ NNlib.∇conv_data! (dx, dy, cache_w, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
73
+ end
74
+ if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w. val
75
+ # dw += grad wrt w.val
76
+ NNlib.∇conv_filter! (dw, cache_x, dy, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
77
+ end
78
+
79
+ dy .= 0
70
80
end
71
- dy .= 0
72
81
end
73
82
74
83
return (nothing , nothing , nothing , nothing )
79
88
80
89
function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{typeof(NNlib.gather!)} , :: Type{RT} , dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
81
90
82
- @assert ! (OutType <: EnzymeCore.Const )
83
91
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
84
92
func. val (dst. val, src. val, idx. val)
85
93
end
@@ -96,15 +104,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
96
104
end
97
105
98
106
# Cache idx if its overwritten
99
- cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (src) <: EnzymeCore.Const ) ) ? copy (idx. val) : nothing
107
+ cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ]
108
+ && ! (typeof (src) <: EnzymeCore.Const )
109
+ && ! (typeof (dst) <: EnzymeCore.Const )
110
+ ) ? copy (idx. val) : nothing
100
111
101
112
return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache_idx)
102
113
end
103
114
104
115
function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof(NNlib.gather!)} , :: Type{RT} , cache_idx, dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
105
116
106
117
# Don't cache idx if not overwritten
107
- if ! (typeof (src) <: EnzymeCore.Const )
118
+ if ! (typeof (src) <: EnzymeCore.Const ) && ! ( typeof (dst) <: EnzymeCore.Const )
108
119
if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
109
120
cache_idx = idx. val
110
121
end
@@ -119,11 +130,12 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
119
130
end
120
131
121
132
for (ddst, dsrc) in zip (ddsts, dsrcs)
122
- if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val &&
123
- ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
124
- NNlib. scatter! (+ , dsrc, ddst, cache_idx)
125
- end
126
133
if ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
134
+
135
+ if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val
136
+ NNlib. scatter! (+ , dsrc, ddst, cache_idx)
137
+ end
138
+
127
139
ddst .= 0
128
140
end
129
141
end
@@ -152,15 +164,18 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
152
164
end
153
165
154
166
# Cache idx if its overwritten
155
- cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (src) <: EnzymeCore.Const ) ) ? copy (idx. val) : nothing
167
+ cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ]
168
+ && ! (typeof (src) <: EnzymeCore.Const )
169
+ && ! (typeof (dst) <: EnzymeCore.Const )
170
+ ) ? copy (idx. val) : nothing
156
171
157
172
return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache_idx)
158
173
end
159
174
160
175
function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof(NNlib.scatter!)} , :: Type{RT} , cache_idx, op:: Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}} , dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
161
176
162
177
# Don't cache idx if not overwritten
163
- if ! (typeof (src) <: EnzymeCore.Const )
178
+ if ! (typeof (src) <: EnzymeCore.Const ) && ! ( typeof (dst) <: EnzymeCore.Const )
164
179
if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
165
180
cache_idx = idx. val
166
181
end
@@ -175,15 +190,20 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
175
190
end
176
191
177
192
for (ddst, dsrc) in zip (ddsts, dsrcs)
178
- if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val &&
179
- ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
180
-
181
- if eltype (typeof (op)) == typeof (+ )
182
- dsrc .+ = NNlib. gather (ddst, cache_idx)
183
- else
184
- @assert eltype (typeof (op)) == typeof (- )
185
- dsrc .- = NNlib. gather (ddst, cache_idx)
193
+ if ! (typeof (dst) <: EnzymeCore.Const ) && ddst != = dst. val
194
+
195
+ if ! (typeof (src) <: EnzymeCore.Const ) && dsrc != = src. val
196
+
197
+ if eltype (typeof (op)) == typeof (+ )
198
+ dsrc .+ = NNlib. gather (ddst, cache_idx)
199
+ else
200
+ @assert eltype (typeof (op)) == typeof (- )
201
+ dsrc .- = NNlib. gather (ddst, cache_idx)
202
+ end
186
203
end
204
+
205
+ ddst .= 0
206
+
187
207
end
188
208
end
189
209
192
212
193
213
194
214
215
+ for pool in [:maxpool , :meanpool , :lpnormpool ]
216
+ pool! = Symbol (pool, :! )
217
+ ∇pool = Symbol (:∇ , pool)
218
+
219
+ @eval begin
220
+
221
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{typeof($pool!)} , :: Type{RT} , y:: OutType , x, dims; kwargs... ) where {OutType, RT}
222
+
223
+ if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
224
+ func. val (y. val, x. val, dims. val; kwargs... )
225
+ end
226
+
227
+ primal = if EnzymeCore. EnzymeRules. needs_primal (config)
228
+ y. val
229
+ else
230
+ nothing
231
+ end
232
+ shadow = if EnzymeCore. EnzymeRules. needs_shadow (config)
233
+ y. dval
234
+ else
235
+ nothing
236
+ end
237
+
238
+ cache_y = ( EnzymeCore. EnzymeRules. overwritten (config)[2 ]
239
+ && ! (typeof (x) <: EnzymeCore.Const )
240
+ && ! (typeof (y) <: EnzymeCore.Const )
241
+ ) ? copy (y. val) : nothing
242
+
243
+ cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ]
244
+ && ! (typeof (x) <: EnzymeCore.Const )
245
+ && ! (typeof (y) <: EnzymeCore.Const )
246
+ ) ? copy (x. val) : nothing
247
+
248
+ cache = (cache_y, cache_x)
249
+
250
+ return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache)
251
+ end
252
+
253
+ function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof($pool!)} , :: Type{RT} , cache, y, x, dims; kwargs... ) where {RT}
254
+ cache_y, cache_x = cache
255
+
256
+ # Don't cache y if not overwritten
257
+ if ! (typeof (x) <: EnzymeCore.Const ) && ! (typeof (y) <: EnzymeCore.Const )
258
+ if ! EnzymeCore. EnzymeRules. overwritten (config)[2 ]
259
+ cache_y = y. val
260
+ end
261
+ end
262
+
263
+ # Don't cache x if not overwritten
264
+ if ! (typeof (x) <: EnzymeCore.Const ) && ! (typeof (y) <: EnzymeCore.Const )
265
+ if ! EnzymeCore. EnzymeRules. overwritten (config)[3 ]
266
+ cache_x = x. val
267
+ end
268
+ end
269
+
270
+ dys = y. dval
271
+ dxs = (typeof (x) <: EnzymeCore.Const ) ? dys : x. dval
272
+
273
+ if EnzymeCore. EnzymeRules. width (config) == 1
274
+ dys = (dys,)
275
+ dxs = (dxs,)
276
+ end
277
+
278
+ for (dy, dx, dw) in zip (dys, dxs)
279
+ if ! (typeof (y) <: EnzymeCore.Const ) && dy != = y. val
280
+
281
+ if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x. val
282
+ NNlib.$ (∇pool)(dx, dy, cache_y, cache_x, dims; alpha= eltype (dx)(1 ), beta= eltype (dx)(1 ), kwargs... )
283
+ end
284
+
285
+ dy .= 0
286
+ end
287
+ end
288
+
289
+ return (nothing , nothing , nothing )
290
+ end
291
+
292
+ end
293
+ end
294
+
295
+
0 commit comments