Skip to content

Commit 330c2b1

Browse files
committed
attempt fix
1 parent 8bc816d commit 330c2b1

File tree

3 files changed

+19
-7
lines changed

3 files changed

+19
-7
lines changed

src/NNlib.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,6 @@ include("impl/depthwiseconv_im2col.jl")
123123
include("impl/pooling_direct.jl")
124124
include("deprecations.jl")
125125

126-
@init @static if !isdefined(Base, :get_extension)
127-
@require EnzymeCore="f151be2c-9106-41f4-ab19-57ee4f262869" begin
128-
include("../ext/NNlibEnzymeCoreExt/NNlibEnzymeCoresExt.jl")
129-
end
130-
end
126+
include("enzyme.jl")
131127

132128
end # module NNlib

test/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ end
870870
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
871871
cdims = DenseConvDims(x, w)
872872
gradtest((x, w) -> conv(x, w, cdims), x, w)
873-
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w) # https://github.com/FluxML/Flux.jl/issues/1055
873+
gradtest((x, w) -> sum(conv(x, w, cdims)), x, w; check_enzyme_rule=true) # https://github.com/FluxML/Flux.jl/issues/1055
874874

875875
y = conv(x, w, cdims)
876876
gradtest((y, w) -> ∇conv_data(y, w, cdims), y, w)

test/test_utils.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,30 @@ Applies also `ChainRulesTestUtils.test_rrule` if the rrule for `f` is explicitly
1212
"""
1313
function gradtest(
1414
f, xs...; atol = 1e-6, rtol = 1e-6, fkwargs = NamedTuple(),
15-
check_rrule = false, fdm = :central, check_broadcast = false,
15+
check_rrule = false, check_enzyme_rrule = false, fdm = :central, check_broadcast = false,
1616
skip = false, broken = false,
1717
)
1818
# TODO: revamp when https://github.com/JuliaDiff/ChainRulesTestUtils.jl/pull/166
1919
# is merged
2020
if check_rrule
2121
test_rrule(f, xs...; fkwargs = fkwargs)
2222
end
23+
if check_enzyme_rrule
24+
if len(xs) == 2
25+
for Tret in (Const, Active),
26+
Tx in (Const, Duplicated, BatchDuplicated),
27+
Ty in (Const, Duplicated, BatchDuplicated)
28+
29+
are_activities_compatible(Tret, Tx, Ty) || continue
30+
31+
test_reverse(fun, Tret, (xs[1], Tx), (ys[1], Ty); atol, rtol)
32+
end
33+
else
34+
throw(AssertionError("Unsupported arg count for testing"))
35+
end
36+
37+
EnzymeTestUtils.test_rrule(f, xs...; fkwargs = fkwargs)
38+
end
2339

2440
if check_broadcast
2541
length(fkwargs) > 0 && @warn("CHECK_BROADCAST: dropping keywords args")

0 commit comments

Comments
 (0)