Skip to content

Commit 115ac25

Browse files
feat: NNlib.scatter (#1395)
* First version of NNlib.scatter code & tests * Support for higher scatter dims + refactoring * Added support for mean in NNlib.scatter * refactor: reuse scatter impl * fix: emit better scatter --------- Co-authored-by: Julian Trommer <[email protected]>
1 parent e11d459 commit 115ac25

File tree

6 files changed

+345
-2
lines changed

6 files changed

+345
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ ReactantArrayInterfaceExt = "ArrayInterface"
5252
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
5353
ReactantKernelAbstractionsExt = "KernelAbstractions"
5454
ReactantMPIExt = "MPI"
55-
ReactantNNlibExt = "NNlib"
55+
ReactantNNlibExt = ["NNlib", "Statistics"]
5656
ReactantOffsetArraysExt = "OffsetArrays"
5757
ReactantOneHotArraysExt = "OneHotArrays"
5858
ReactantPythonCallExt = "PythonCall"

ext/ReactantNNlibExt/Implementations.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,3 +487,90 @@ function NNlib.upsample_linear_kernel!(
487487
copyto!(y, upsample_linear(x, size(y)[1:(end - 2)], ratios..., align_corners))
488488
return y
489489
end
490+
491+
# Scatter
492+
# Out-of-place scatter for traced sources.
#
# Builds a destination filled with `init` (or `NNlib.scatter_empty(op, T)` when `init`
# is `nothing`) and forwards to `NNlib.scatter!`. When `dstsize` is `nothing`, the
# destination size is the leading `ndims(src) - ndims(idx)` extents of `src` followed
# by `NNlib.maximum_dims(idx)`.
#
# Throws `ArgumentError` when any extent of the destination size is a `TracedRNumber`,
# since the destination must be allocated with a statically known size.
function NNlib.scatter(
    op::OP, src::AnyTracedRArray{T}, idx::AbstractArray; init=nothing, dstsize=nothing
) where {OP,T}
    # Number of leading `src` dimensions that are not indexed by `idx`.
    dims = ndims(src) - ndims(idx)

    out_size = dstsize
    if out_size === nothing
        out_size = (size(src)[1:dims]..., NNlib.maximum_dims(idx)...)
    end

    # A traced extent cannot be used to allocate the destination.
    for extent in out_size
        extent isa TracedRNumber || continue
        throw(
            ArgumentError(
                "dstsize must be specified when idx is a TracedRArray or contains a TracedRNumber.",
            ),
        )
    end

    fill_value = init === nothing ? NNlib.scatter_empty(op, T) : init
    dst = Ops.fill(fill_value, out_size)

    NNlib.scatter!(op, dst, src, idx)
    return dst
end
514+
515+
# In-place scatter for traced `dst`/`src` with a generic index array.
#
# The indices are first converted with `_stack_indices`, then the scatter is emitted by
# `_nnlib_scatter_impl`; the resulting MLIR value is written back into `dst`, which is
# also the return value.
function NNlib.scatter!(
    op::OP, dst::AnyTracedRArray, src::AnyTracedRArray, idx::AbstractArray
) where {OP}
    window_rank = NNlib.scatter_dims(dst, src, idx)
    stacked_idx = _stack_indices(idx)
    updated = _nnlib_scatter_impl(op, dst, src, stacked_idx, window_rank)
    # Rebind `dst` to the scatter result instead of returning a fresh array.
    set_mlir_data!(dst, get_mlir_data(updated))
    return dst
end
523+
524+
# In-place scatter for traced `dst`/`src` with plain numeric indices.
#
# `idx` is reshaped to carry a leading singleton dimension before it is handed to
# `_nnlib_scatter_impl` (which is called with `index_vector_dim = 1`); the resulting
# MLIR value is written back into `dst`, which is also the return value.
function NNlib.scatter!(
    op::OP, dst::AnyTracedRArray, src::AnyTracedRArray, idx::AbstractArray{<:Number}
) where {OP}
    window_rank = NNlib.scatter_dims(dst, src, idx)
    idx_with_vector_dim = reshape(idx, 1, size(idx)...)
    updated = _nnlib_scatter_impl(op, dst, src, idx_with_vector_dim, window_rank)
    # Rebind `dst` to the scatter result instead of returning a fresh array.
    set_mlir_data!(dst, get_mlir_data(updated))
    return dst
end
532+
533+
# `mean` scatter: two `+` scatters into zeroed copies of `dst` give the per-slot sum and
# the per-slot element count; their `NNlib.safe_div` quotient is added onto `dst`.
#
# The method is generated once per index signature used by the generic `scatter!`
# methods (presumably so dispatch on `typeof(mean)` stays unambiguous — confirm).
for IdxT in (AbstractArray, AbstractArray{<:Number})
    @eval function NNlib.scatter!(
        ::typeof(mean), dst::AnyTracedRArray, src::AnyTracedRArray, idx::$IdxT
    )
        # How many source elements land in each destination slot.
        counts = NNlib.scatter!(+, zero(dst), one.(src), idx)
        # Sum of the source elements landing in each destination slot.
        sums = NNlib.scatter!(+, zero(dst), src, idx)
        averaged = dst .+ NNlib.safe_div.(sums, counts)
        set_mlir_data!(dst, get_mlir_data(averaged))
        return dst
    end
end
544+
545+
# Emit a single `Ops.scatter` combining `src` into `dst` with `op`.
#
# - `idx`: index array whose first dimension holds the index vector
#   (`index_vector_dim = 1`); it is promoted to a `TracedRArray{Int}`.
# - `n_dims`: number of leading window dimensions copied through unchanged.
#
# Returns the first (and only requested) result of `Ops.scatter`; `dst` itself is not
# modified here.
function _nnlib_scatter_impl(
    op::OP,
    dst::AnyTracedRArray{T},
    src::AnyTracedRArray{T},
    idx::AbstractArray,
    n_dims::Int,
) where {OP,T}
    scatter_indices = TracedUtils.promote_to(TracedRArray{Int,ndims(idx)}, idx)
    # Length of each index vector, i.e. how many trailing `dst` dimensions are indexed.
    n_idxs = size(scatter_indices, 1)
    dst_rank = ndims(dst)

    window_dims = collect(Int64, 1:n_dims)
    inserted_dims = collect(Int64, (n_dims + 1):dst_rank)
    operand_dims = collect(Int64, (dst_rank - n_idxs + 1):dst_rank)

    results = Ops.scatter(
        op,
        [dst],
        scatter_indices,
        [src];
        update_window_dims=window_dims,
        inserted_window_dims=inserted_dims,
        input_batching_dims=Int64[],
        scatter_indices_batching_dims=Int64[],
        scatter_dims_to_operand_dims=operand_dims,
        index_vector_dim=Int64(1),
    )
    return results[1]
end
567+
568+
# Integer indices: the destination extent is the largest index, wrapped in a 1-tuple.
NNlib.maximum_dims(dims::AnyTracedRArray{<:Integer}) = (maximum(dims),)
571+
# Tuple indices: one extent per tuple position, each the maximum of that component
# over all entries of `dims`.
function NNlib.maximum_dims(dims::AnyTracedRArray{NTuple{N,T}}) where {N,T}
    return ntuple(N) do i
        maximum(x -> x[i], dims)
    end
end
574+
# `CartesianIndex` indices: one extent per index component, each the maximum of that
# component over all entries of `dims`.
function NNlib.maximum_dims(dims::AnyTracedRArray{CartesianIndex{N}}) where {N}
    return ntuple(N) do i
        maximum(x -> x[i], dims)
    end
end

ext/ReactantNNlibExt/ReactantNNlibExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Reactant.TracedUtils:
1010

1111
using ReactantCore: @trace
1212
using LinearAlgebra: LinearAlgebra, triu
13+
using Statistics: mean
1314

1415
include("Overlay.jl")
1516
include("Ops.jl")

src/Compiler.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,8 @@ function optimization_passes(;
838838
"divide_sqrt_to_multiply_rsqrt<16>",
839839
"associative_binary_op_reordering<1>",
840840
"transpose_broadcast_in_dim_to_broadcast_in_dim<16>",
841-
"scatter_indices_are_unique",
841+
# XXX: needs upstream fix
842+
# "scatter_indices_are_unique",
842843
"replace_neg_add_with_subtract",
843844
"binop_const_simplify",
844845
"not_select_simplify",

src/Ops.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,6 +1734,8 @@ end
17341734
scatter_indices_batching_dims::Vector{Int64},
17351735
scatter_dims_to_operand_dims::Vector{Int64},
17361736
index_vector_dim::Int64,
1737+
unique_indices::Union{Bool,Nothing}=nothing,
1738+
indices_are_sorted::Union{Bool,Nothing}=nothing,
17371739
location=mlir_stacktrace("scatter", @__FILE__, @__LINE__),
17381740
) where {T,N}
17391741
scatter_indices = subtract(
@@ -1768,6 +1770,8 @@ end
17681770
update_computation,
17691771
scatter_dimension_numbers,
17701772
result_0=[mlir_type(TracedRArray{T,N}, size(d)) for d in dest],
1773+
unique_indices,
1774+
indices_are_sorted,
17711775
location,
17721776
)
17731777

test/nn/nnlib.jl

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using NNlib, Reactant, Enzyme
2+
using Statistics
23

34
@testset "Activation Functions" begin
45
sumabs2(f, x) = sum(abs2, f.(x))
@@ -381,6 +382,255 @@ end
381382
end
382383
end
383384

385+
# Adapted from https://github.com/FluxML/NNlib.jl/blob/1468582c4db5f18149cc8fff6fb4633c5debe5c5/test/testsuite/scatter.jl#L108
386+
@testset "NNlib scatter" begin
387+
# Exercise `NNlib.scatter!` and `NNlib.scatter` under `@jit` for every reduction
# operator and compare against precomputed expectations.
#
# - `dsts`: `dim => destination array`.
# - `srcs`: `(dim, flag) => source array`; `flag == true` feeds the in-place
#   `scatter!` checks, `flag == false` the out-of-place `scatter` checks.
# - `idxs`: collection of index arrays; only its values are iterated.
# - `res`: `(op, dim, flag) => expected result`, keyed like `srcs`.
# - `dims`: the `dim` keys to run.
#
# Results are compared with `≈` because they are computed in Float32.
function test_scatter(dsts, srcs, idxs, res; dims)
    @testset "scatter Float32 $op" for op in (+, -, max, min, *, /, mean)
        for idx in values(idxs), dim in dims
            # --- in-place scatter!, host indices -------------------------------
            dst = copy(dsts[dim])
            target_y = res[(op, dim, true)]
            src = srcs[(dim, true)]
            if op == /
                src = src .* 2.0f0
            end

            y1 = @jit(
                NNlib.scatter!(
                    op, Reactant.to_rarray(dst), Reactant.to_rarray(src), idx
                )
            )
            @test y1 ≈ target_y
            @test y1 isa ConcreteRArray{Float32,ndims(dst)}
            @test size(y1) == size(dsts[dim])

            # --- in-place scatter!, device indices -----------------------------
            dst = copy(dsts[dim])
            y2 = @jit(
                NNlib.scatter!(
                    op,
                    Reactant.to_rarray(dst),
                    Reactant.to_rarray(src),
                    Reactant.to_rarray(idx),
                )
            )
            @test y2 ≈ target_y
            @test y2 isa ConcreteRArray{Float32,ndims(dst)}
            @test size(y2) == size(dsts[dim])

            # --- out-of-place scatter, host indices ----------------------------
            target_y = res[(op, dim, false)]
            src = srcs[(dim, false)]
            if op == /
                src = src .* 2.0f0
            end

            y3 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), idx))
            @test y3 ≈ target_y
            @test y3 isa ConcreteRArray{Float32,ndims(dst)}
            @test size(y3) == size(dsts[dim])

            # --- out-of-place scatter, device indices with explicit dstsize ----
            y4 = @jit(
                NNlib.scatter(
                    op,
                    Reactant.to_rarray(src),
                    Reactant.to_rarray(idx);
                    dstsize=size(dsts[dim]),
                )
            )
            @test y4 ≈ target_y
            @test y4 isa ConcreteRArray{Float32,ndims(dst)}
            @test size(y4) == size(dsts[dim])

            # --- out-of-place scatter, device indices without dstsize ----------
            # With concrete (device) indices the destination size would be traced,
            # so the call is expected to throw; otherwise it must match as usual.
            ridx = Reactant.to_rarray(idx)
            if ridx isa Reactant.AbstractConcreteArray
                @test_throws ArgumentError @jit(
                    NNlib.scatter(op, Reactant.to_rarray(src), ridx)
                )
            else
                y5 = @jit(NNlib.scatter(op, Reactant.to_rarray(src), ridx))
                @test y5 ≈ target_y
                @test y5 isa ConcreteRArray{Float32,ndims(dst)}
                @test size(y5) == size(dsts[dim])
            end
        end
    end
end
454+
455+
@testset "scatter 1d src, 1d index => 1d output" begin
456+
#! format: off
457+
dsts = Dict(
458+
0 => Float32[3, 4, 5, 6, 7]
459+
)
460+
461+
srcs = Dict(
462+
(0, true) => ones(Float32, 5),
463+
(0, false) => collect(Float32, 1:5),
464+
)
465+
466+
idxs = Dict(
467+
:int => [4, 2, 1, 5, 3],
468+
:tup => [(4,), (2,), (1,), (5,), (3,)],
469+
:car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]),
470+
)
471+
472+
res = Dict(
473+
(+, 0, true) => Float32[4, 5, 6, 7, 8],
474+
(+, 0, false) => Float32[3, 2, 5, 1, 4],
475+
476+
(-, 0, true) => Float32[2, 3, 4, 5, 6],
477+
(-, 0, false) => Float32[-3, -2, -5, -1, -4],
478+
479+
(max, 0, true) => Float32[3, 4, 5, 6, 7],
480+
(max, 0, false) => Float32[3, 2, 5, 1, 4],
481+
482+
(min, 0, true) => Float32[1, 1, 1, 1, 1],
483+
(min, 0, false) => Float32[3, 2, 5, 1, 4],
484+
485+
(*, 0, true) => Float32[3, 4, 5, 6, 7],
486+
(*, 0, false) => Float32[3, 2, 5, 1, 4],
487+
488+
(/, 0, true) => Float32[1.5, 2.0, 2.5, 3.0, 3.5],
489+
(/, 0, false) => Float32[1//6, 1//4, 1//10, 1//2, 1//8],
490+
491+
(mean, 0, true) => Float32[4, 5, 6, 7, 8],
492+
(mean, 0, false) => Float32[3, 2, 5, 1, 4],
493+
)
494+
#! format: on
495+
test_scatter(dsts, srcs, idxs, res; dims=[0])
496+
end
497+
498+
@testset "scatter 2d src, 1d index => 2d output" begin
499+
#! format: off
500+
dsts = Dict(
501+
0 => Float32[3 3 4 4 5
502+
5 5 6 6 7]
503+
)
504+
505+
srcs = Dict(
506+
(0, true) => ones(Float32, 2, 5),
507+
(0, false) => ones(Float32, 2) * collect(1:5)',
508+
)
509+
510+
idxs = Dict(
511+
:int => [4, 2, 1, 5, 3],
512+
:tup => [(4,), (2,), (1,), (5,), (3,)],
513+
:car => CartesianIndex.([(4,), (2,), (1,), (5,), (3,)]),
514+
)
515+
516+
res = Dict(
517+
(+, 0, true) => Float32[4 4 5 5 6;
518+
6 6 7 7 8],
519+
(+, 0, false) => Float32[3 2 5 1 4;
520+
3 2 5 1 4],
521+
522+
(-, 0, true) => Float32[2 2 3 3 4;
523+
4 4 5 5 6],
524+
(-, 0, false) => Float32[-3 -2 -5 -1 -4;
525+
-3 -2 -5 -1 -4],
526+
527+
(max, 0, true) => Float32[3 3 4 4 5;
528+
5 5 6 6 7],
529+
(max, 0, false) => Float32[3 2 5 1 4;
530+
3 2 5 1 4],
531+
532+
(min, 0, true) => Float32[1 1 1 1 1;
533+
1 1 1 1 1],
534+
(min, 0, false) => Float32[3 2 5 1 4;
535+
3 2 5 1 4],
536+
537+
(*, 0, true) => Float32[3 3 4 4 5;
538+
5 5 6 6 7],
539+
(*, 0, false) => Float32[3 2 5 1 4;
540+
3 2 5 1 4],
541+
542+
(/, 0, true) => Float32[1.5 1.5 2.0 2.0 2.5;
543+
2.5 2.5 3.0 3.0 3.5],
544+
(/, 0, false) => Float32[1//6 1//4 1//10 1//2 1//8;
545+
1//6 1//4 1//10 1//2 1//8],
546+
547+
(mean, 0, true) => Float32[4 4 5 5 6;
548+
6 6 7 7 8],
549+
(mean, 0, false) => Float32[3 2 5 1 4;
550+
3 2 5 1 4],
551+
)
552+
#! format: on
553+
test_scatter(dsts, srcs, idxs, res; dims=[0])
554+
end
555+
556+
@testset "scatter 2d+3d src, 2d index => 1d+2d output" begin
557+
#! format: off
558+
dsts = Dict(
559+
0 => Float32[3, 4, 5, 6, 7],
560+
1 => Float32[3 3 4 4 5;
561+
5 5 6 6 7],
562+
)
563+
564+
srcs = Dict(
565+
(0, true) => ones(Float32, 3, 4),
566+
(0, false) => ones(Float32, 3) * collect(1:4)',
567+
(1, true) => ones(Float32, 2, 3, 4),
568+
(1, false) => Float32[1, 2] .* reshape(ones(Float32, 3) * collect(1:4)', 1,3,4),
569+
)
570+
571+
idxs = Dict(
572+
:int => [1 2 3 4;
573+
4 2 1 3;
574+
3 5 5 3],
575+
:tup => [(1,) (2,) (3,) (4,);
576+
(4,) (2,) (1,) (3,);
577+
(3,) (5,) (5,) (3,)],
578+
:car => CartesianIndex.(
579+
[(1,) (2,) (3,) (4,);
580+
(4,) (2,) (1,) (3,);
581+
(3,) (5,) (5,) (3,)]),
582+
)
583+
584+
res = Dict(
585+
(+, 0, true) => Float32[5, 6, 9, 8, 9],
586+
(+, 1, true) => Float32[5 5 8 6 7;
587+
7 7 10 8 9],
588+
(+, 0, false) => Float32[4, 4, 12, 5, 5],
589+
(+, 1, false) => Float32[4 4 12 5 5;
590+
8 8 24 10 10],
591+
(-, 0, true) => Float32[1, 2, 1, 4, 5],
592+
(-, 1, true) => Float32[1 1 0 2 3;
593+
3 3 2 4 5],
594+
(-, 0, false) => Float32[-4, -4, -12, -5, -5],
595+
(-, 1, false) => Float32[-4 -4 -12 -5 -5;
596+
-8 -8 -24 -10 -10],
597+
(max, 0, true) => Float32[3, 4, 5, 6, 7],
598+
(max, 1, true) => Float32[3 3 4 4 5;
599+
5 5 6 6 7],
600+
(max, 0, false) => Float32[3, 2, 4, 4, 3],
601+
(max, 1, false) => Float32[3 2 4 4 3;
602+
6 4 8 8 6],
603+
(min, 0, true) => Float32[1, 1, 1, 1, 1],
604+
(min, 1, true) => Float32[1 1 1 1 1;
605+
1 1 1 1 1],
606+
(min, 0, false) => Float32[1, 2, 1, 1, 2],
607+
(min, 1, false) => Float32[1 2 1 1 2;
608+
2 4 2 2 4],
609+
(*, 0, true) => Float32[3, 4, 5, 6, 7],
610+
(*, 1, true) => Float32[3 3 4 4 5;
611+
5 5 6 6 7],
612+
(*, 0, false) => Float32[3, 4, 48, 4, 6],
613+
(*, 1, false) => Float32[3 4 48 4 6;
614+
12 16 768 16 24],
615+
(/, 0, true) => Float32[0.75, 1., 0.3125, 1.5, 1.75],
616+
(/, 1, true) => Float32[0.75 0.75 0.25 1. 1.25;
617+
1.25 1.25 0.375 1.5 1.75],
618+
(/, 0, false) => Float32[1//12, 1//16, 1//768, 1//16, 1//24],
619+
(/, 1, false) => Float32[1//12 1//16 1//768 1//16 1//24;
620+
1//48 1//64 1//12288 1//64 1//96],
621+
(mean, 0, true) => Float32[4., 5., 6., 7., 8.],
622+
(mean, 1, true) => Float32[4. 4. 5. 5. 6.;
623+
6. 6. 7. 7. 8.],
624+
(mean, 0, false) => Float32[2, 2, 3, 2.5, 2.5],
625+
(mean, 1, false) => Float32[2. 2. 3. 2.5 2.5;
626+
4. 4. 6. 5. 5.],
627+
)
628+
#! format: on
629+
630+
test_scatter(dsts, srcs, idxs, res; dims=[0, 1])
631+
end
632+
end
633+
384634
@testset "∇conv(D = $ndim)" for ndim in 1:3
385635
x_spatial_dim = 4
386636
batch_size = 2

0 commit comments

Comments
 (0)