Skip to content

Commit b8f4493

Browse files
committed
Also add gather
1 parent 9799133 commit b8f4493

File tree

4 files changed

+76
-10
lines changed

4 files changed

+76
-10
lines changed

src/enzyme.jl

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,71 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
5757
end
5858

5959
for (dy, dx, dw) in zip(dys, dxs, dws)
60-
if !(typeof(x) <: EnzymeCore.Const) && dx !== x
61-
# dx += grad wrt x
60+
if !(typeof(x) <: EnzymeCore.Const) && dx !== x.val
61+
# dx += grad wrt x.val
6262
NNlib.∇conv_data!(dx, dy, cache_w, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
6363
end
64-
if !(typeof(w) <: EnzymeCore.Const) && dw !== w
65-
# dw += grad wrt w
64+
if !(typeof(w) <: EnzymeCore.Const) && dw !== w.val
65+
# dw += grad wrt w.val
6666
NNlib.∇conv_filter!(dw, cache_x, dy, cdims.val; alpha=eltype(dw)(1), beta=eltype(dw)(1), kwargs...)
6767
end
6868
dy .= 0
6969
end
7070

71+
return (nothing, nothing, nothing, nothing)
72+
end
73+
74+
75+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
76+
77+
@assert !(OutType <: EnzymeCore.Const)
78+
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed
79+
func.val(dst.val, src.val, idx.val)
80+
end
81+
82+
primal = if EnzymeCore.EnzymeRules.needs_primal(config)
83+
dst.val
84+
else
85+
nothing
86+
end
87+
shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
88+
dst.dval
89+
else
90+
nothing
91+
end
92+
93+
# Cache idx if its overwritten
94+
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] && !(typeof(src) <: EnzymeCore.Const) ) ? copy(idx.val) : nothing
95+
96+
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
97+
end
98+
99+
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, cache_idx, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
100+
101+
# Don't cache idx if not overwritten
102+
if !(typeof(src) <: EnzymeCore.Const)
103+
if !EnzymeCore.EnzymeRules.overwritten(config)[4]
104+
cache_idx = idx.val
105+
end
106+
end
107+
108+
ddsts = dst.dval
109+
dsrcs = src.dval
110+
111+
if EnzymeCore.EnzymeRules.width(config) == 1
112+
ddsts = (ddsts,)
113+
dsrcs = (dsrcs,)
114+
end
115+
116+
for (ddst, dsrc) in zip(ddsts, dsrcs)
117+
if !(typeof(src) <: EnzymeCore.Const) && ddst !== dst.val
118+
src_size = size(src.val)
119+
NNlib.∇gather_src(ddst, src_size, cache_idx)
120+
end
121+
if !(typeof(w) <: EnzymeCore.Const) && dw !== w
122+
ddst .= 0
123+
end
124+
end
125+
71126
return (nothing, nothing, nothing, nothing)
72127
end

test/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ end
870870
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
871871
cdims = DenseConvDims(x, w)
872872
gradtest((x, w) -> conv(x, w, cdims), x, w)
873-
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055
873+
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rrule=true) # https://github.com/FluxML/Flux.jl/issues/1055
874874

875875
y = conv(x, w, cdims)
876876
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)

test/gather.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,17 @@ function gather_testsuite(Backend)
152152
Backend == CPU ?
153153
gradtest_fn(xs -> gather(xs, idx), src) :
154154
gradtest_fn((s, i) -> gather(s, i), src, idx)
155+
156+
if Backend == CPU
157+
for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated),
158+
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
159+
Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)
160+
161+
EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue
162+
163+
EnzymeTestUtils.test_reverse(fun, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))
164+
end
165+
end
155166
end
156167

157168
@static if Test_Enzyme

test/test_utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ function gradtest(
2222
end
2323
if check_enzyme_rrule
2424
if len(xs) == 2
25-
for Tret in (Const, Active),
26-
Tx in (Const, Duplicated, BatchDuplicated),
27-
Ty in (Const, Duplicated, BatchDuplicated)
25+
for Tret in (EnzymeCore.Const, EnzymeCore.Active),
26+
Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
27+
Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)
2828

29-
are_activities_compatible(Tret, Tx, Ty) || continue
29+
EnzymeTestUtils.are_activities_compatible(Tret, Tx, Ty) || continue
3030

31-
test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol)
31+
EnzymeTestUtils.test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol)
3232
end
3333
else
3434
throw(AssertionError("Unsupported arg count for testing"))

0 commit comments

Comments
 (0)