
Commit 3ac8c93

feat: better codegen with gather/scatter (#1283)
* feat: don't flatten indices for getindex
* fix: repeats
* fix: indexing consistency
* fix: avoid reshape
* feat: better codegen for NNlib.gather
* revert: reshape change
* feat: strided indexing
* fix: better diag codegen
* fix: missing singletons
* fix: more checks
* fix: lower to dynamic_slice
* fix: temporarily disable slice_elementwise
* fix: one-elem cartesian index
* feat: setindex as scatter
* chore: bump jll
1 parent f5b3cc2 commit 3ac8c93

File tree: 6 files changed, +292 / -98 lines


ext/ReactantNNlibExt/Implementations.jl

Lines changed: 31 additions & 52 deletions
@@ -430,64 +430,43 @@ function NNlib.pad_constant(
 end
 
 # Gather
-# XXX: reevaluate this manual optimization once
-# https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled
-function NNlib.gather!(
-    dst::AnyTracedRArray{T1,2},
-    src::AnyTracedRArray{T2,2},
-    idxs::Union{AbstractUnitRange{<:Number}},
-) where {T1,T2}
-    set_mlir_data!(dst, get_mlir_data(src[:, idxs]))
+function NNlib.gather!(dst::AnyTracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
+    n_dims = NNlib.scatter_dims(src, dst, idxs)
+    res = _nnlib_gather_impl(src, _stack_indices(idxs), n_dims)
+    set_mlir_data!(dst, get_mlir_data(res))
     return dst
 end
 
 function NNlib.gather!(
-    dst::AnyTracedRArray{T1,2}, src::AnyTracedRArray{T2,2}, idxs::AbstractVector{<:Number}
-) where {T1,T2}
-    dims = NNlib.scatter_dims(src, dst, idxs)
-    @assert dims == 1 # scatter_dims lets us do some size checks so we call that function
-    idxs = get_mlir_data(TracedUtils.promote_to(TracedRArray{Int,1}, idxs) .- 1)
-    slice_sizes = get_mlir_data(
-        TracedUtils.promote_to(TracedRArray{Int,1}, [size(src, 1), 1])
-    )
-
-    #! format: off
-    dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
-        MLIR.IR.context(),
-        Int64(1), Int64[0],
-        Int64(1), Int64[1],
-        Int64(0), Int64[],
-        Int64(0), Int64[],
-        Int64(1), Int64[1],
-        Int64(1)
-    )
-    #! format: on
-
-    res = MLIR.IR.result(
-        Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
-            get_mlir_data(src), idxs, slice_sizes; dimension_numbers
-        ),
-        1,
-    )
-    set_mlir_data!(dst, res)
+    dst::AnyTracedRArray, src::AnyTracedRArray, idxs::AbstractArray{<:Number}
+)
+    n_dims = NNlib.scatter_dims(src, dst, idxs)
+    res = _nnlib_gather_impl(src, reshape(idxs, 1, size(idxs)...), n_dims)
+    set_mlir_data!(dst, get_mlir_data(res))
     return dst
 end
 
-# XXX: For performance to use `stablehlo.dynamic_gather` or atleast use traced loop
-#      instead of unrolling the loop (the case for AbstractArray can just use
-#      `stablehlo.gather`). See above for the special case implementation that is optimized.
-function NNlib.gather!(dst::AnyTracedRArray, src::AnyTracedRArray, idxs::AbstractArray)
-    @warn "Using fallback implementation of `gather!` for using `stablehlo.dynamic_slice`. \
-           This case is not optimized and will be slow." maxlog = 1
-    dims = NNlib.scatter_dims(src, dst, idxs)
-    colons = ntuple(Returns(Colon()), dims)
-    start_sizes = ntuple(Base.Fix1(size, src), dims)
-    results = map(CartesianIndices(idxs)) do k
-        res = @allowscalar src[colons..., Tuple(idxs[k])...]
-        res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,)))
-        return reshape(res, start_sizes..., :)
+_stack_indices(idxs::AbstractArray) = stack(idxs)
+function _stack_indices(idxs::AbstractArray{<:CartesianIndex})
+    stacked_idxs = similar(idxs, Int, length(first(idxs)), size(idxs)...)
+    for k in CartesianIndices(idxs)
+        stacked_idxs[:, k.I...] .= idxs[k].I
     end
-    res = reshape(cat(results...; dims=(dims + 1)), size(dst))
-    set_mlir_data!(dst, get_mlir_data(res))
-    return dst
+    return stacked_idxs
+end
+
+function _nnlib_gather_impl(src::AnyTracedRArray, idxs::AbstractArray, n_dims::Int)
+    idxs = TracedUtils.promote_to(TracedRArray{Int,ndims(idxs)}, idxs)
+    n_idxs = size(idxs, 1)
+    return Ops.gather(
+        src,
+        idxs;
+        offset_dims=collect(Int64, 1:n_dims),
+        collapsed_slice_dims=collect(Int64, (n_dims + 1):ndims(src)),
+        operand_batching_dims=Int64[],
+        start_indices_batching_dims=Int64[],
+        start_index_map=collect(Int64, (ndims(src) - n_idxs + 1):ndims(src)),
+        index_vector_dim=1,
+        slice_sizes=Int64[size(src)[1:n_dims]..., ones(Int64, ndims(src) - n_dims)...],
+    )
 end
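
With this change, NNlib.gather on a traced array lowers to a single stablehlo.gather via Ops.gather instead of the old unrolled slice loop. A usage sketch (hypothetical shapes and values; Reactant.to_rarray and @jit used as in the Reactant docs):

    using NNlib, Reactant

    src = Reactant.to_rarray(rand(Float32, 16, 8))   # 16 features x 8 columns
    idx = [3, 5, 5, 1]                               # plain Julia column indices

    gathered = @jit NNlib.gather(src, idx)           # traces to one stablehlo.gather
    size(gathered) == (16, 4)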

src/Ops.jl

Lines changed: 35 additions & 13 deletions
@@ -532,16 +532,24 @@ end
 
 @noinline function slice(
     x::TracedRArray{T,N},
-    start_indices,
-    limit_indices;
-    strides=nothing,
+    start_indices::Vector{<:Integer},
+    limit_indices::Vector{<:Integer};
+    strides::Union{Nothing,Vector{<:Integer}}=nothing,
     location=mlir_stacktrace("slice", @__FILE__, @__LINE__),
 ) where {T,N}
     start_indices = start_indices .- 1
     limit_indices = limit_indices
-    rsize = limit_indices .- start_indices
-    @assert all(rsize .> 0) "Invalid slice dimensions"
+    @assert all(Base.Fix2(≥, 0), start_indices) "Invalid start indices: $(start_indices)"
+    @assert all(s < l for (s, l) in zip(start_indices, limit_indices)) "Invalid slice indices: $(start_indices), $(limit_indices)"
+
     strides = isnothing(strides) ? ones(Int64, N) : strides
+    @assert all(s > 0 for s in strides) "Invalid strides: $(strides)"
+    rsize = [
+        length((start + 1):st:stop) for
+        (start, stop, st) in zip(start_indices, limit_indices, strides)
+    ]
+    @assert all(rsize .> 0) "Invalid slice dimensions"
+
     res = MLIR.IR.result(
         stablehlo.slice(
             x.mlir_data;
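
The new rsize computation counts the elements a strided Julia range would visit rather than assuming stride 1. A worked sketch in plain Julia (values illustrative, assuming the 1-based inclusive start/limit convention the surrounding code uses):

    start_indices = [3] .- 1   # shifted to 0-based, as in the function above
    limit_indices = [9]
    strides       = [2]

    rsize = [
        length((start + 1):st:stop) for
        (start, stop, st) in zip(start_indices, limit_indices, strides)
    ]
    # rsize == [4]   (elements 3, 5, 7, 9 of the original axis)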
@@ -1732,18 +1740,18 @@ instead.
         [updates];
         update_computation,
         update_window_dims=Int64[],
-        inserted_window_dims=collect(Int64, 0:(N - 1)),
+        inserted_window_dims=collect(Int64, 1:N),
         input_batching_dims=Int64[],
         scatter_indices_batching_dims=Int64[],
-        scatter_dims_to_operand_dims=collect(Int64, 0:(N - 1)),
-        index_vector_dim=Int64(1),
+        scatter_dims_to_operand_dims=collect(Int64, 1:N),
+        index_vector_dim=Int64(2),
         location,
     )[1]
 end
 
 @noinline function scatter(
     dest::Vector{TracedRArray{T,N}},
-    scatter_indices::TracedRArray{Int64,2},
+    scatter_indices::TracedRArray{Int64},
     updates::Vector{<:TracedRArray{T}};
     update_computation::MLIR.IR.Region,
     update_window_dims::Vector{Int64},
@@ -1758,6 +1766,13 @@ end
         scatter_indices, fill(Int64(1), size(scatter_indices)); location
     )
 
+    update_window_dims = update_window_dims .- 1
+    inserted_window_dims = inserted_window_dims .- 1
+    input_batching_dims = input_batching_dims .- 1
+    scatter_indices_batching_dims = scatter_indices_batching_dims .- 1
+    scatter_dims_to_operand_dims = scatter_dims_to_operand_dims .- 1
+    index_vector_dim -= 1
+
     #! format: off
     scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
         MLIR.IR.context(),
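
The dimension numbers are now given 1-based (Julia convention) and shifted to the 0-based values StableHLO expects only when the attribute is built. For the element-wise scatter configured above (update_window_dims=Int64[], inserted_window_dims=1:N, scatter_dims_to_operand_dims=1:N, index_vector_dim=2), the behavior corresponds to this plain-Julia loop; a semantic sketch only, with one 1-based index tuple per row of idxs and a hypothetical helper name:

    # dest   :: AbstractArray{T,N}
    # idxs   :: AbstractMatrix{Int}, size (K, N)
    # values :: AbstractVector{T}, length K
    function scatter_setindex_reference!(dest, idxs, values)
        for k in axes(idxs, 1)
            dest[idxs[k, :]...] = values[k]   # overwrite: the update replaces the old value
        end
        return dest
    end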
@@ -1813,11 +1828,11 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
         src,
         gather_indices;
         offset_dims=Int64[1],
-        collapsed_slice_dims=collect(Int64, 0:(N - 2)),
+        collapsed_slice_dims=collect(Int64, 1:(N - 1)),
         operand_batching_dims=Int64[],
         start_indices_batching_dims=Int64[],
-        start_index_map=collect(Int64, 0:(N - 1)),
-        index_vector_dim=Int64(1),
+        start_index_map=collect(Int64, 1:N),
+        index_vector_dim=Int64(2),
         slice_sizes=ones(Int64, N),
         indices_are_sorted=false,
         location,
@@ -1828,7 +1843,7 @@ end
 
 @noinline function gather(
     src::TracedRArray{T,N},
-    gather_indices::TracedRArray{Int64,2};
+    gather_indices::TracedRArray{Int64};
     offset_dims::Vector{Int64},
     collapsed_slice_dims::Vector{Int64},
     operand_batching_dims::Vector{Int64},
@@ -1843,6 +1858,13 @@ end
         gather_indices, fill(Int64(1), size(gather_indices)); location
     )
 
+    offset_dims = offset_dims .- 1
+    start_indices_batching_dims = start_indices_batching_dims .- 1
+    start_index_map = start_index_map .- 1
+    operand_batching_dims = operand_batching_dims .- 1
+    collapsed_slice_dims = collapsed_slice_dims .- 1
+    index_vector_dim -= 1
+
     #! format: off
     dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
         MLIR.IR.context(),
src/TracedRArray.jl

Lines changed: 99 additions & 30 deletions
@@ -140,20 +140,28 @@ function Base.getindex(a::TracedRArray{T,1}, indices::CartesianIndex{1}) where {
     return _getindex_cartesian(a, indices)
 end
 
-_isone(x) = isone(x)
-_isone(::CartesianIndex) = false
-
-__contiguous_indices(::Base.LogicalIndex) = false
-__contiguous_indices(x) = all(_isone, diff(x))
-
 function _getindex_linear(a::TracedRArray{T,N}, indices::AbstractArray) where {T,N}
-    if !(indices isa Reactant.TracedType) && __contiguous_indices(vec(indices))
-        a_flat = materialize_traced_array(vec(a))
-        indices_flat = vec(indices)
-        return Ops.reshape(
-            Ops.dynamic_slice(a_flat, [first(indices_flat)], [length(indices_flat)]),
-            collect(size(indices)),
-        )
+    if !(indices isa Reactant.TracedType)
+        if length(indices) == 1 && first(indices) isa CartesianIndex
+            # fast-path else we will end up with a gather
+            return TracedUtils.broadcast_to_size(
+                @allowscalar(_getindex_cartesian(a, first(indices))), (1,)
+            )
+        end
+        stride = TracedUtils._get_slice_stride(vec(indices))
+        if stride > 0
+            a_flat = materialize_traced_array(vec(a))
+            indices_flat = vec(indices)
+            return Ops.reshape(
+                Ops.slice(
+                    a_flat,
+                    Int64[first(indices_flat)],
+                    Int64[last(indices_flat)];
+                    strides=Int64[stride],
+                ),
+                collect(Int64, size(indices)),
+            )
+        end
     end
 
     if !(indices isa TracedRArray)
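
In practice a literal strided range no longer falls back to a gather: the stride is detected by TracedUtils._get_slice_stride and the read lowers to one strided stablehlo.slice plus a reshape. A usage sketch (illustrative shapes):

    using Reactant

    x = Reactant.to_rarray(collect(Float32, 1:100))

    f(v) = v[3:2:41]      # constant stride 2
    y = @jit f(x)         # one stablehlo.slice with strides=[2], no gather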
@@ -180,14 +188,21 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
     indices = Base.to_indices(a, indices)
 
     use_gather_getindex = false
+    use_dynamic_slice = false
+    strides = Int64[]
     for idxs in indices
-        idxs isa Number && continue
+        if idxs isa Number
+            idxs isa TracedRNumber && (use_dynamic_slice = true)
+            push!(strides, 1)
+            continue
+        end
         if idxs isa Reactant.TracedType
             use_gather_getindex = true
             break
         end
-        contiguous = __contiguous_indices(vec(idxs))
-        if typeof(contiguous) <: Bool && !contiguous
+        stride = TracedUtils._get_slice_stride(vec(idxs))
+        push!(strides, stride)
+        if stride ≤ 0 || (use_dynamic_slice && stride != 1)
             use_gather_getindex = true
             break
         end
@@ -200,18 +215,37 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
             error("Boolean indexing with TracedRArrays isn't fully supported yet.")
         end
 
-        indices, integer_indices, result_size, preddim_result_size, _ = TracedUtils.traced_indices(
-            indices...
-        )
-        res = Ops.reshape(
-            Ops.gather_getindex(a, generate_index_list(indices...)), preddim_result_size
+        gather_dims = TracedUtils.indices_to_gather_dims(indices...)
+
+        return Ops.reshape(
+            Ops.transpose(
+                Ops.reshape(
+                    Ops.gather(
+                        a,
+                        gather_dims.start_indices;
+                        gather_dims.offset_dims,
+                        gather_dims.collapsed_slice_dims,
+                        operand_batching_dims=Int64[],
+                        start_indices_batching_dims=Int64[],
+                        gather_dims.start_index_map,
+                        gather_dims.index_vector_dim,
+                        gather_dims.slice_sizes,
+                    ),
+                    gather_dims.gather_reshape_shape,
+                ),
+                gather_dims.permutation,
+            ),
+            gather_dims.result_shape,
         )
-        isempty(integer_indices) ||
-            (res = materialize_traced_array(dropdims(res; dims=integer_indices)))
-        return Ops.reshape(res, result_size)
     end
 
-    x = Ops.dynamic_slice(a, [first.(indices)...], [length.(indices)...])
+    if use_dynamic_slice
+        @assert all(isone, strides) "This should not happen, please report a bug"
+        x = Ops.dynamic_slice(a, [first.(indices)...], [length.(indices)...])
+    else
+        x = Ops.slice(a, [first.(indices)...], [last.(indices)...]; strides)
+    end
+
     ddims = findall(indices) do idx
         return idx isa Integer || idx isa TracedRNumber{<:Integer}
     end
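
The per-dimension bookkeeping above selects between three lowerings: traced scalar indices with otherwise unit strides keep dynamic_slice, literal strided ranges become a single strided slice, and everything else goes through the gather/reshape/transpose pipeline. A sketch (illustrative shapes; i is only known at run time):

    using Reactant

    A = Reactant.to_rarray(rand(Float32, 32, 16))
    i = Reactant.ConcreteRNumber(5)

    g(A, i) = A[:, i]          # traced scalar index   -> dynamic_slice
    h(A)    = A[1:2:31, 1:4]   # literal strided range  -> strided slice
    p(A)    = A[[4, 1, 9], :]  # arbitrary index vector -> gather pipeline

    @jit g(A, i); @jit h(A); @jit p(A)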
@@ -313,7 +347,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, index::CartesianIndex{N}) where
 end
 
 function _setindex_linear!(a::TracedRArray{T,N}, v, indices::AbstractArray) where {T,N}
-    if !(indices isa Reactant.TracedType) && __contiguous_indices(vec(indices))
+    if !(indices isa Reactant.TracedType) && TracedUtils.__contiguous_indices(vec(indices))
         res = Ops.reshape(
             Ops.dynamic_update_slice(
                 materialize_traced_array(vec(a)),
@@ -371,7 +405,7 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
             use_scatter_setindex = true
             break
         end
-        contiguous = __contiguous_indices(idxs)
+        contiguous = TracedUtils.__contiguous_indices(idxs)
         if typeof(contiguous) <: Bool && !contiguous
             use_scatter_setindex = true
             break
@@ -384,9 +418,44 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
         if any(i -> unwrapped_eltype(i) <: Bool, indices)
             error("Boolean indexing with TracedRArrays isn't fully supported yet.")
         end
-        indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
-        indices_list = generate_index_list(indices_list...)
-        res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v)))
+
+        gather_dims = TracedUtils.indices_to_gather_dims(indices...)
+
+        v = Ops.convert(
+            TracedRArray{T,ndims(v)},
+            TracedUtils.promote_to(TracedRArray{unwrapped_eltype(v),ndims(v)}, v),
+        )
+
+        updates = Ops.transpose(v, invperm(gather_dims.permutation))
+        n_collapsed = length(gather_dims.collapsed_slice_dims)
+        updates_shape = Int64[
+            prod(size(updates)[1:n_collapsed]), size(updates)[(n_collapsed + 1):end]...
+        ]
+        updates = Ops.reshape(updates, updates_shape)
+
+        # simply set the 2nd block argument as a result
+        update_computation = MLIR.IR.Region()
+        block = MLIR.IR.Block(
+            [Ops.mlir_type(TracedRNumber{T}), Ops.mlir_type(TracedRNumber{T})],
+            [MLIR.IR.Location(), MLIR.IR.Location()],
+        )
+        return_op = MLIR.Dialects.stablehlo.return_([MLIR.IR.argument(block, 2)])
+        MLIR.IR.rmfromparent!(return_op)
+        push!(block, return_op)
+        pushfirst!(update_computation, block)
+
+        res = Ops.scatter(
+            [a],
+            gather_dims.start_indices,
+            [updates];
+            update_computation,
+            update_window_dims=gather_dims.offset_dims,
+            inserted_window_dims=gather_dims.collapsed_slice_dims,
+            input_batching_dims=Int64[],
+            scatter_indices_batching_dims=Int64[],
+            scatter_dims_to_operand_dims=gather_dims.start_index_map,
+            index_vector_dim=gather_dims.index_vector_dim,
+        )[1]
         set_mlir_data!(a, get_mlir_data(res))
         return v
     end
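
Altogether, writes through non-contiguous indices now lower to a single stablehlo.scatter whose update region simply returns the incoming value (overwrite), mirroring the gather path used for reads. A usage sketch (illustrative shapes; the traced argument is updated in place):

    using Reactant

    A = Reactant.to_rarray(zeros(Float32, 10, 10))
    v = Reactant.to_rarray(ones(Float32, 5, 3))

    function set_block!(A, v)
        A[1:2:9, [1, 4, 5]] = v   # strided rows x arbitrary columns -> scatter-based setindex!
        return A
    end

    B = @jit set_block!(A, v)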
