You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
35
+
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
33
36
cache_x, cache_w = cache
34
37
35
38
# Don't cache x if not overwritten and w is active (and thus required)
@@ -71,11 +74,13 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
71
74
return (nothing, nothing, nothing, nothing)
72
75
end
73
76
77
+
end
78
+
end
74
79
75
80
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
76
81
77
82
@assert!(OutType <:EnzymeCore.Const)
78
-
if OutType <:EnzymeCore.Duplicated|| OutType <:EnzymeCore.DuplicatedNoNeed
83
+
if OutType <:EnzymeCore.Duplicated|| OutType <:EnzymeCore.BatchDuplicated
79
84
func.val(dst.val, src.val, idx.val)
80
85
end
81
86
@@ -114,14 +119,76 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
0 commit comments