diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index 28dc94ff39..f5b6427eca 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -629,6 +629,48 @@ end test_scatter(dsts, srcs, idxs, res; dims=[0, 1]) end + + @testset "scatter gradient" begin + dst = Float32[ + 3 3 4 4 5 + 5 5 6 6 7 + ] + dst_ca = Reactant.to_rarray(dst) + + src = ones(Float32, 2, 5) + src_ca = Reactant.to_rarray(src) + + idx = [4, 2, 1, 5, 3] + idx_ca = Reactant.to_rarray(idx) + + function test_scatter(dsts, srcs, idxs) + return sum(NNlib.scatter!(+, dsts, srcs, idxs)) + end + + function test_gradient(objective_function, dsts, srcs, idxs) + derivs, val = Enzyme.gradient( + Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), + Const(objective_function), + dsts, + srcs, + idxs, + ) + return derivs, val + end + + test_gradient_compiled = @compile test_gradient( + test_scatter, dst_ca, src_ca, idx_ca + ) + + grads_enz, loss_enz = Enzyme.gradient( + Enzyme.ReverseWithPrimal, Const(test_scatter), dst, src, idx + ) + grads_ca, loss_ca = test_gradient_compiled(test_scatter, dst_ca, src_ca, idx_ca) + + @test grads_enz[1] ≈ Array(grads_ca[1]) + @test grads_enz[2] ≈ Array(grads_ca[2]) + @test loss_enz ≈ loss_ca + end end @testset "∇conv(D = $ndim)" for ndim in 1:3