Skip to content

Commit be5f0a3

Browse files
perf: special case top_k with k=1 to use reduction (#1439)
* perf: special case top_k with k=1 to use reduction * Update src/Ops.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix: handle the equality case --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d63e4c0 commit be5f0a3

File tree

2 files changed

+81
-20
lines changed

2 files changed

+81
-20
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.141"
4+
version = "0.2.142"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/Ops.jl

Lines changed: 80 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,15 @@ end
12251225
location=mlir_stacktrace("top_k", @__FILE__, @__LINE__),
12261226
) where {T,N}
12271227
@assert 1 <= dimension <= N
1228+
1229+
# XLA codegen for top.k is extremely sub-optimal. For special cases we can bypass that
1230+
if k isa Integer && k == 1
1231+
values, indices = argmax(x; dimension, location)
1232+
return (;
1233+
values, indices=add(indices, fill(Int64(1), Tuple(size(indices))); location)
1234+
)
1235+
end
1236+
12281237
if dimension != N # chlo.top_k performs the operation along the last dimension
12291238
pdims = collect(Int64, 1:N)
12301239
pdims[dimension] = N
@@ -1251,13 +1260,41 @@ end
12511260
return (; values, indices)
12521261
end
12531262

1263+
@noinline function argmax(
1264+
x::TracedRArray{T,N};
1265+
dimension::Integer=N,
1266+
location=mlir_stacktrace("argmax", @__FILE__, @__LINE__),
1267+
) where {T,N}
1268+
values, indices = reduce(
1269+
TracedRArray[
1270+
x, iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)
1271+
],
1272+
TracedRNumber[
1273+
Reactant.TracedUtils.promote_to(TracedRNumber{T}, typemin(T)),
1274+
Reactant.TracedUtils.promote_to(TracedRNumber{Int64}, -1),
1275+
],
1276+
[dimension],
1277+
function (a₁, i₁, a₂, i₂)
1278+
cond = a₁ a₂
1279+
return ifelse(cond, a₁, a₂), ifelse(cond, i₁, i₂)
1280+
end;
1281+
location,
1282+
)
1283+
new_shape = collect(Int64, size(x))
1284+
new_shape[dimension] = 1
1285+
return (
1286+
Ops.reshape(values, new_shape; location), Ops.reshape(indices, new_shape; location)
1287+
)
1288+
end
1289+
12541290
@noinline function iota(
12551291
T::Type,
12561292
shape::Vector{Int};
12571293
iota_dimension,
12581294
location=mlir_stacktrace("iota", @__FILE__, @__LINE__),
12591295
)
12601296
N = length(shape)
1297+
@assert 0 < iota_dimension <= N
12611298
output = mlir_type(TracedRArray{T,N}, shape)
12621299
iota_dimension = MLIR.IR.Attribute(iota_dimension - 1)
12631300
res = MLIR.IR.result(stablehlo.iota(; output, iota_dimension, location))
@@ -2631,24 +2668,30 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
26312668
end
26322669
end
26332670

2634-
function _construct_reduce_function(f::F, ::Type{T}) where {F,T}
2671+
function _construct_reduce_function(f::F, Ts::Type...) where {F}
2672+
inputs_1 = [Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0) for T in Ts]
2673+
inputs_2 = [Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0) for T in Ts]
26352674
func =
26362675
Reactant.TracedUtils.make_mlir_fn(
26372676
f,
2638-
(
2639-
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2640-
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
2641-
),
2677+
(inputs_1..., inputs_2...),
26422678
(),
26432679
"reduce_fn" * string(f),
26442680
false;
26452681
args_in_result=:none,
26462682
return_dialect=:stablehlo,
26472683
).f
2684+
26482685
@assert MLIR.IR.nregions(func) == 1
26492686
ftype_attr = MLIR.IR.attr(func, "function_type")
26502687
ftype = MLIR.IR.Type(ftype_attr)
2651-
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType(Int[], MLIR.IR.Type(T)) "$(fn) return type is not of tensor<$(T)>"
2688+
2689+
@assert MLIR.IR.nresults(ftype) == length(Ts)
2690+
for i in 1:MLIR.IR.nresults(ftype)
2691+
tType = MLIR.IR.TensorType(Int[], MLIR.IR.Type(Ts[i]))
2692+
@assert MLIR.IR.result(ftype, i) == tType "$(f) return type $(i) is not of \
2693+
tensor<$(Ts[i])>"
2694+
end
26522695

26532696
fn = MLIR.IR.Region()
26542697
MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
@@ -2703,23 +2746,41 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
27032746
x::TracedRArray{T},
27042747
init_values::TracedRNumber{T},
27052748
dimensions::Vector{Int},
2706-
fn::Function,
2749+
fn::F;
27072750
location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
2708-
) where {T}
2709-
reduced_shape = Tuple(deleteat!(collect(Int64, size(x)), dimensions))
2751+
) where {T,F}
2752+
return only(reduce([x], [init_values], dimensions, fn; location))
2753+
end
27102754

2711-
res = MLIR.IR.result(
2712-
stablehlo.reduce(
2713-
[x.mlir_data],
2714-
[init_values.mlir_data];
2715-
result_0=[mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)],
2716-
dimensions=MLIR.IR.Attribute(dimensions .- 1),
2717-
body=_construct_reduce_function(fn, T),
2718-
location=location,
2719-
),
2755+
@noinline function reduce(
2756+
xs::Vector{<:TracedRArray},
2757+
init_values::Vector{<:TracedRNumber},
2758+
dimensions::Vector{Int},
2759+
fn::F;
2760+
location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
2761+
) where {F}
2762+
@assert allequal(size.(xs)) "All input arrays must have the same size."
2763+
2764+
reduced_shape = Tuple(deleteat!(collect(Int64, size(xs[1])), dimensions))
2765+
2766+
op = stablehlo.reduce(
2767+
[x.mlir_data for x in xs],
2768+
[init_value.mlir_data for init_value in init_values];
2769+
result_0=[
2770+
mlir_type(
2771+
TracedRArray{unwrapped_eltype(x),length(reduced_shape)}, reduced_shape
2772+
) for x in xs
2773+
],
2774+
dimensions=MLIR.IR.Attribute(dimensions .- 1),
2775+
body=_construct_reduce_function(fn, [unwrapped_eltype(x) for x in xs]...),
2776+
location,
27202777
)
27212778

2722-
return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
2779+
return [
2780+
TracedRArray{unwrapped_eltype(xs[i]),length(reduced_shape)}(
2781+
(), MLIR.IR.result(op, i), reduced_shape
2782+
) for i in 1:MLIR.IR.nresults(op)
2783+
]
27232784
end
27242785

27252786
@noinline function dynamic_update_slice(

0 commit comments

Comments
 (0)