Skip to content

Commit 20fe426

Browse files
committed
Additional functions, tests, and fixes
1 parent 7419ea4 commit 20fe426

File tree

5 files changed

+133
-34
lines changed

5 files changed

+133
-34
lines changed

src/enzyme.jl

Lines changed: 76 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import EnzymeCore
22

3-
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}
3+
for name in (typeof(NNlib.conv!), typeof(NNlib.depthwiseconv!))
4+
@eval begin
5+
6+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{$name}, ::Type{RT}, y::OutType, x, w, cdims; kwargs...) where {OutType, RT}
47

58
@assert !(OutType <: EnzymeCore.Const)
6-
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed
9+
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
710
func.val(y.val, x.val, w.val, cdims.val; kwargs...)
811
end
912

@@ -29,7 +32,7 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{
2932
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache)
3033
end
3134

32-
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.conv!)}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
35+
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{$name}, ::Type{RT}, cache, y, x, w, cdims; kwargs...) where {RT}
3336
cache_x, cache_w = cache
3437

3538
# Don't cache x if not overwritten and w is active (and thus required)
@@ -71,11 +74,13 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
7174
return (nothing, nothing, nothing, nothing)
7275
end
7376

77+
end
78+
end
7479

7580
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.gather!)}, ::Type{RT}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}
7681

7782
@assert !(OutType <: EnzymeCore.Const)
78-
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.DuplicatedNoNeed
83+
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
7984
func.val(dst.val, src.val, idx.val)
8085
end
8186

@@ -114,14 +119,76 @@ function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NN
114119
end
115120

116121
for (ddst, dsrc) in zip(ddsts, dsrcs)
117-
if !(typeof(src) <: EnzymeCore.Const) && ddst !== dst.val
118-
src_size = size(src.val)
119-
NNlib.∇gather_src(ddst, src_size, cache_idx)
122+
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val &&
123+
!(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
124+
NNlib.scatter!(+, dsrc, ddst, cache_idx)
120125
end
121-
if !(typeof(w) <: EnzymeCore.Const) && dw !== w
126+
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
122127
ddst .= 0
123128
end
124129
end
125130

131+
return (nothing, nothing, nothing)
132+
end
133+
134+
135+
136+
# Enzyme augmented-primal (forward sweep) rule for `NNlib.scatter!(op, dst, src, idx)`.
#
# Runs the primal `scatter!` when `dst` is `Duplicated` or `BatchDuplicated`, then returns
# an `AugmentedReturn` carrying `dst.val` / `dst.dval` when the config requests them, plus
# the tape entry for the reverse pass: a copy of `idx.val`, or `nothing`.
# `op` and `idx` must be `Const`; `dst` must not be `Const` (asserted below).
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, op::EnzymeCore.Const, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

    @assert !(OutType <: EnzymeCore.Const)
    if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
        func.val(op.val, dst.val, src.val, idx.val)
    end

    primal = if EnzymeCore.EnzymeRules.needs_primal(config)
        dst.val
    else
        nothing
    end
    shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
        dst.dval
    else
        nothing
    end

    # `overwritten(config)` has one entry per argument *including the function itself*:
    # (func, op, dst, src, idx). `idx` is therefore entry 5; entry 4 is `src`.
    # Cache idx if it's overwritten and the reverse pass needs it (src is active).
    idx_overwritten = EnzymeCore.EnzymeRules.overwritten(config)[5]
    cache_idx = (idx_overwritten && !(typeof(src) <: EnzymeCore.Const)) ? copy(idx.val) : nothing

    return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, cache_idx)
end

# Enzyme reverse rule for `NNlib.scatter!(op, dst, src, idx)` with `op` being `+` or `-`.
#
# Accumulates into each shadow of `src` the values gathered from the matching shadow of
# `dst` at `idx` (added for `+`, subtracted for `-`). The shadow of `dst` is left untouched.
# `cache_idx` is the tape entry produced by the augmented primal above.
# Returns one `nothing` per argument `(op, dst, src, idx)`.
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib.scatter!)}, ::Type{RT}, cache_idx, op::Union{EnzymeCore.Const{typeof(+)},EnzymeCore.Const{typeof(-)}}, dst::OutType, src, idx::EnzymeCore.Const) where {OutType, RT}

    # A `Const` src has no shadow (`src.dval` does not exist), so there is nothing to
    # accumulate into.
    if typeof(src) <: EnzymeCore.Const
        return (nothing, nothing, nothing, nothing)
    end

    # idx was not cached when it is not overwritten (entry 5, see the augmented primal);
    # in that case the live value is still valid.
    if !EnzymeCore.EnzymeRules.overwritten(config)[5]
        cache_idx = idx.val
    end

    ddsts = dst.dval
    dsrcs = src.dval

    # Width-1 shadows are bare arrays; wrap them so the loop below also covers batches.
    if EnzymeCore.EnzymeRules.width(config) == 1
        ddsts = (ddsts,)
        dsrcs = (dsrcs,)
    end

    for (ddst, dsrc) in zip(ddsts, dsrcs)
        # Skip shadows that are the primal array itself.
        if dsrc !== src.val &&
           !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val

            if eltype(typeof(op)) == typeof(+)
                dsrc .+= NNlib.gather(ddst, cache_idx)
            else
                @assert eltype(typeof(op)) == typeof(-)
                dsrc .-= NNlib.gather(ddst, cache_idx)
            end
        end
    end

    return (nothing, nothing, nothing, nothing)
end
192+
193+
194+

test/conv.jl

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -861,7 +861,7 @@ end
861861
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
862862
cdims = DenseConvDims(x, w)
863863
gradtest((x, w) -> conv(x, w, cdims), x, w)
864-
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rrule=true) # https://github.com/FluxML/Flux.jl/issues/1055
864+
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055
865865

866866
y = conv(x, w, cdims)
867867
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)
@@ -877,3 +877,25 @@ end
877877
gradtest((y, w) -> ∇depthwiseconv_data(y, w, dcdims), y, w)
878878
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
879879
end
880+
881+
# Checks the Enzyme reverse-mode rules for `conv!` and `depthwiseconv!` against
# EnzymeTestUtils for every compatible combination of return/argument activities.
@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
    x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
    w = rand(rng, repeat([3], spatial_rank)..., 3, 3)

    # Each convolution is paired with its in-place variant and its own dims type.
    # NOTE(review): depthwiseconv is assumed to require DepthwiseConvDims (as the
    # `dcdims` tests earlier in this file use) rather than DenseConvDims — confirm.
    for (curconv, curconv!, dimstype) in (
        (conv, conv!, DenseConvDims),
        (depthwiseconv, depthwiseconv!, DepthwiseConvDims),
    )
        cdims = dimstype(x, w)
        dst = curconv(x, w, cdims)

        for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
            Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
            Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
            Tw in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

            # Compatibility must be checked over the activities actually iterated here.
            EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tx, Tw) || continue

            # Arguments follow the `conv!(dst, x, w, cdims)` order; cdims is never active.
            EnzymeTestUtils.test_reverse(curconv!, Tret, (dst, Tdst), (x, Tx), (w, Tw), (cdims, EnzymeCore.Const))
        end
    end
end

test/gather.jl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,23 @@ function gather_testsuite(Backend)
152152
Backend == CPU ?
153153
gradtest_fn(xs -> gather(xs, idx), src) :
154154
gradtest_fn((s, i) -> gather(s, i), src, idx)
155+
end
155156

156-
if Backend == CPU
157-
for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated),
158-
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
159-
Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)
160157

161-
EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue
158+
@testset "EnzymeRules: gather! gradient for scalar index" begin
159+
src = device(Float64[3, 4, 5, 6, 7])
160+
idx = device([
161+
1 2 3 4;
162+
4 2 1 3;
163+
3 5 5 3])
164+
dst = gather(src, idx)
165+
for Tret in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
166+
Tdst in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
167+
Tsrc in (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)
168+
169+
EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc) || continue
162170

163-
EnzymeTestUtils.test_reverse(fun, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))
164-
end
171+
EnzymeTestUtils.test_reverse(gather!, Tret, (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))
165172
end
166173
end
167174

test/scatter.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,5 +207,24 @@ function scatter_testsuite(Backend)
207207
gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx)
208208
end
209209
end
210+
211+
# Exercises the Enzyme reverse rule of `scatter!` for `+` and `-` over every compatible
# combination of return, destination and source activities. `idx` is always `Const`.
@testset "EnzymeRules" begin
    idx = device([2, 2, 3, 4, 4])
    src = device(ones(T, 3, 5))

    ret_activities = (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)
    arg_activities = (EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)

    for op in (+, -)
        dst = scatter(op, src, idx)

        for Tret in ret_activities, Tdst in arg_activities, Tsrc in arg_activities
            if EnzymeTestUtils.are_activities_compatible(Tret, Tdst, Tsrc)
                EnzymeTestUtils.test_reverse(scatter!, Tret, (op, EnzymeCore.Const), (dst, Tdst), (src, Tsrc), (idx, EnzymeCore.Const))
            end
        end
    end
end
210229
end
211230
end

test/test_utils.jl

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,14 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly
1212
"""
1313
function gradtest(
1414
f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(),
15-
check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false,
15+
check_rrule = false, fdm = :central, check_broadcast = false,
1616
skip = false, broken = false,
1717
)
1818
# TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166
1919
# is merged
2020
if check_rrule
2121
test_rrule(f, xs...; fkwargs = fkwargs)
2222
end
23-
if check_enzyme_rrule
24-
if len(xs) == 2
25-
for Tret in (EnzymeCore.Const, EnzymeCore.Active),
26-
Tx in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated),
27-
Ty in (EnzymeCore.Const, EnzymeCore.Duplicated, EnzymeCore.BatchDuplicated)
28-
29-
EnzymeTestUtils.are_activities_compatible(Tret, Tx, Ty) || continue
30-
31-
EnzymeTestUtils.test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol)
32-
end
33-
else
34-
throw(AssertionError("Unsupported arg count for testing"))
35-
end
36-
37-
EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs)
38-
end
3923

4024
if check_broadcast
4125
length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args")

0 commit comments

Comments
 (0)