1
+ import EnzymeCore
2
+
3
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{typeof(NNlib.conv!)} , :: Type{RT} , y:: OutType , x, w, cdims; kwargs... ) where {OutType, RT}
4
+
5
+ @assert ! (OutType <: EnzymeCore.Const )
6
+ if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed
7
+ func. val (y. val, x. val, w. val, cdims. val; kwargs... )
8
+ end
9
+
10
+ primal = if EnzymeCore. EnzymeRules. needs_primal (config)
11
+ y. val
12
+ else
13
+ nothing
14
+ end
15
+ shadow = if EnzymeCore. EnzymeRules. needs_shadow (config)
16
+ y. dval
17
+ else
18
+ nothing
19
+ end
20
+
21
+ # Cache x if its overwritten and w is active (and thus required)
22
+ cache_x = ( EnzymeCore. EnzymeRules. overwritten (config)[3 ] && ! (typeof (w) <: EnzymeCore.Const ) ) ? copy (x. val) : nothing
23
+
24
+ # Cache w if its overwritten and x is active (and thus required)
25
+ cache_w = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (x) <: EnzymeCore.Const ) ) ? copy (w. val) : nothing
26
+
27
+ cache = (cache_x, cache_w)
28
+
29
+ return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache)
30
+ end
31
+
32
+ function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof(NNlib.conv!)} , :: Type{RT} , cache, y, x, w, cdims; kwargs... ) where {RT}
33
+ cache_x, cache_w = cache
34
+
35
+ # Don't cache x if not overwritten and w is active (and thus required)
36
+ if ! (typeof (w) <: EnzymeCore.Const )
37
+ if ! EnzymeCore. EnzymeRules. overwritten (config)[3 ]
38
+ cache_x = x. val
39
+ end
40
+ end
41
+
42
+ # Don't cache w if not overwritten and x is active (and thus required)
43
+ if ! (typeof (x) <: EnzymeCore.Const )
44
+ if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
45
+ cache_w = w. val
46
+ end
47
+ end
48
+
49
+ dys = y. dval
50
+ dxs = (typeof (x) <: EnzymeCore.Const ) ? dys : x. dval
51
+ dws = (typeof (w) <: EnzymeCore.Const ) ? dys : w. dval
52
+
53
+ if EnzymeCore. EnzymeRules. width (config) == 1
54
+ dys = (dys,)
55
+ dxs = (dxs,)
56
+ dws = (dws,)
57
+ end
58
+
59
+ for (dy, dx, dw) in zip (dys, dxs, dws)
60
+ if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x
61
+ # dx += grad wrt x
62
+ NNlib.∇conv_data! (dx, dy, cache_w, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
63
+ end
64
+ if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w
65
+ # dw += grad wrt w
66
+ NNlib.∇conv_filter! (dw, cache_x, dy, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
67
+ end
68
+ dy .= 0
69
+ end
70
+
71
+ return (nothing , nothing , nothing , nothing )
72
+ end
0 commit comments