@@ -57,16 +57,71 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
57
57
end
58
58
59
59
for (dy, dx, dw) in zip (dys, dxs, dws)
60
- if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x
61
- # dx += grad wrt x
60
+ if ! (typeof (x) <: EnzymeCore.Const ) && dx != = x. val
61
+ # dx += grad wrt x.val
62
62
NNlib.∇conv_data! (dx, dy, cache_w, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
63
63
end
64
- if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w
65
- # dw += grad wrt w
64
+ if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w. val
65
+ # dw += grad wrt w.val
66
66
NNlib.∇conv_filter! (dw, cache_x, dy, cdims. val; alpha= eltype (dw)(1 ), beta= eltype (dw)(1 ), kwargs... )
67
67
end
68
68
dy .= 0
69
69
end
70
70
71
+ return (nothing , nothing , nothing , nothing )
72
+ end
73
+
74
+
75
+ function EnzymeCore. EnzymeRules. augmented_primal (config, func:: EnzymeCore.Const{typeof(NNlib.gather!)} , :: Type{RT} , dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
76
+
77
+ @assert ! (OutType <: EnzymeCore.Const )
78
+ if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed
79
+ func. val (dst. val, src. val, idx. val)
80
+ end
81
+
82
+ primal = if EnzymeCore. EnzymeRules. needs_primal (config)
83
+ dst. val
84
+ else
85
+ nothing
86
+ end
87
+ shadow = if EnzymeCore. EnzymeRules. needs_shadow (config)
88
+ dst. dval
89
+ else
90
+ nothing
91
+ end
92
+
93
+ # Cache idx if its overwritten
94
+ cache_idx = ( EnzymeCore. EnzymeRules. overwritten (config)[4 ] && ! (typeof (src) <: EnzymeCore.Const ) ) ? copy (idx. val) : nothing
95
+
96
+ return EnzymeCore. EnzymeRules. AugmentedReturn (primal, shadow, cache_idx)
97
+ end
98
+
99
+ function EnzymeCore. EnzymeRules. reverse (config, func:: EnzymeCore.Const{typeof(NNlib.gather!)} , :: Type{RT} , cache_idx, dst:: OutType , src, idx:: EnzymeCore.Const ) where {OutType, RT}
100
+
101
+ # Don't cache idx if not overwritten
102
+ if ! (typeof (src) <: EnzymeCore.Const )
103
+ if ! EnzymeCore. EnzymeRules. overwritten (config)[4 ]
104
+ cache_idx = idx. val
105
+ end
106
+ end
107
+
108
+ ddsts = dst. dval
109
+ dsrcs = src. dval
110
+
111
+ if EnzymeCore. EnzymeRules. width (config) == 1
112
+ ddsts = (ddsts,)
113
+ dsrcs = (dsrcs,)
114
+ end
115
+
116
+ for (ddst, dsrc) in zip (ddsts, dsrcs)
117
+ if ! (typeof (src) <: EnzymeCore.Const ) && ddst != = dst. val
118
+ src_size = size (src. val)
119
+ NNlib.∇gather_src (ddst, src_size, cache_idx)
120
+ end
121
+ if ! (typeof (w) <: EnzymeCore.Const ) && dw != = w
122
+ ddst .= 0
123
+ end
124
+ end
125
+
71
126
return (nothing , nothing , nothing , nothing )
72
127
end
0 commit comments