@@ -1225,6 +1225,15 @@ end
     location=mlir_stacktrace("top_k", @__FILE__, @__LINE__),
 ) where {T,N}
     @assert 1 <= dimension <= N
+
+    # XLA codegen for top.k is extremely sub-optimal. For special cases we can bypass that
+    if k isa Integer && k == 1
+        values, indices = argmax(x; dimension, location)
+        return (;
+            values, indices=add(indices, fill(Int64(1), Tuple(size(indices))); location)
+        )
+    end
+
     if dimension != N # chlo.top_k performs the operation along the last dimension
         pdims = collect(Int64, 1:N)
         pdims[dimension] = N
@@ -1251,13 +1260,41 @@ end
     return (; values, indices)
 end
 
+@noinline function argmax(
+    x::TracedRArray{T,N};
+    dimension::Integer=N,
+    location=mlir_stacktrace("argmax", @__FILE__, @__LINE__),
+) where {T,N}
+    values, indices = reduce(
+        TracedRArray[
+            x, iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)
+        ],
+        TracedRNumber[
+            Reactant.TracedUtils.promote_to(TracedRNumber{T}, typemin(T)),
+            Reactant.TracedUtils.promote_to(TracedRNumber{Int64}, -1),
+        ],
+        [dimension],
+        function (a₁, i₁, a₂, i₂)
+            cond = a₁ ≥ a₂
+            return ifelse(cond, a₁, a₂), ifelse(cond, i₁, i₂)
+        end;
+        location,
+    )
+    new_shape = collect(Int64, size(x))
+    new_shape[dimension] = 1
+    return (
+        Ops.reshape(values, new_shape; location), Ops.reshape(indices, new_shape; location)
+    )
+end
+
 @noinline function iota(
     T::Type,
     shape::Vector{Int};
     iota_dimension,
     location=mlir_stacktrace("iota", @__FILE__, @__LINE__),
 )
     N = length(shape)
+    @assert 0 < iota_dimension <= N
     output = mlir_type(TracedRArray{T,N}, shape)
     iota_dimension = MLIR.IR.Attribute(iota_dimension - 1)
     res = MLIR.IR.result(stablehlo.iota(; output, iota_dimension, location))
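
The new `argmax` helper expresses an argmax as a single joint reduction over the value array and an index `iota`, carrying the running maximum and its zero-based position together through one `stablehlo.reduce`; the `k == 1` fast path in `top_k` then only has to add one to those positions, since `stablehlo.iota` counts from zero. A plain-Julia sketch of the same pairwise combiner, purely for illustration (the 1-D loop and the `argmax_1d` name are not part of this change):

```julia
# Plain-Julia analogue of the joint (value, index) reduction that the new
# `argmax` builds via `reduce` above; illustrative only, not Reactant code.
function argmax_1d(x::AbstractVector{T}) where {T}
    acc = (typemin(T), -1)                  # same init values as the op above
    for (i, a) in enumerate(x)
        v, idx = acc
        cond = v >= a                       # the `a₁ ≥ a₂` comparator from the MLIR body
        acc = cond ? (v, idx) : (a, i - 1)  # zero-based index, like stablehlo.iota
    end
    return acc
end

argmax_1d([3, 7, 7, 1])  # (7, 1): ties keep the earlier index because of `≥`
```
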
@@ -2631,24 +2668,30 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
     end
 end
 
-function _construct_reduce_function(f::F, ::Type{T}) where {F,T}
+function _construct_reduce_function(f::F, Ts::Type...) where {F}
+    inputs_1 = [Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0) for T in Ts]
+    inputs_2 = [Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0) for T in Ts]
     func =
         Reactant.TracedUtils.make_mlir_fn(
             f,
-            (
-                Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
-                Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
-            ),
+            (inputs_1..., inputs_2...),
             (),
             "reduce_fn" * string(f),
             false;
             args_in_result=:none,
             return_dialect=:stablehlo,
         ).f
+
     @assert MLIR.IR.nregions(func) == 1
     ftype_attr = MLIR.IR.attr(func, "function_type")
     ftype = MLIR.IR.Type(ftype_attr)
-    @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType(Int[], MLIR.IR.Type(T)) "$(fn) return type is not of tensor<$(T)>"
+
+    @assert MLIR.IR.nresults(ftype) == length(Ts)
+    for i in 1:MLIR.IR.nresults(ftype)
+        tType = MLIR.IR.TensorType(Int[], MLIR.IR.Type(Ts[i]))
+        @assert MLIR.IR.result(ftype, i) == tType "$(f) return type $(i) is not of \
+            tensor<$(Ts[i])>"
+    end
 
     fn = MLIR.IR.Region()
     MLIR.API.mlirRegionTakeBody(fn, MLIR.IR.region(func, 1))
@@ -2703,23 +2746,41 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
     x::TracedRArray{T},
     init_values::TracedRNumber{T},
     dimensions::Vector{Int},
-    fn::Function,
+    fn::F;
     location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
-) where {T}
-    reduced_shape = Tuple(deleteat!(collect(Int64, size(x)), dimensions))
+) where {T,F}
+    return only(reduce([x], [init_values], dimensions, fn; location))
+end
 
-    res = MLIR.IR.result(
-        stablehlo.reduce(
-            [x.mlir_data],
-            [init_values.mlir_data];
-            result_0=[mlir_type(TracedRArray{T,length(reduced_shape)}, reduced_shape)],
-            dimensions=MLIR.IR.Attribute(dimensions .- 1),
-            body=_construct_reduce_function(fn, T),
-            location=location,
-        ),
+@noinline function reduce(
+    xs::Vector{<:TracedRArray},
+    init_values::Vector{<:TracedRNumber},
+    dimensions::Vector{Int},
+    fn::F;
+    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
+) where {F}
+    @assert allequal(size.(xs)) "All input arrays must have the same size."
+
+    reduced_shape = Tuple(deleteat!(collect(Int64, size(xs[1])), dimensions))
+
+    op = stablehlo.reduce(
+        [x.mlir_data for x in xs],
+        [init_value.mlir_data for init_value in init_values];
+        result_0=[
+            mlir_type(
+                TracedRArray{unwrapped_eltype(x),length(reduced_shape)}, reduced_shape
+            ) for x in xs
+        ],
+        dimensions=MLIR.IR.Attribute(dimensions .- 1),
+        body=_construct_reduce_function(fn, [unwrapped_eltype(x) for x in xs]...),
+        location,
     )
 
-    return TracedRArray{T,length(reduced_shape)}((), res, reduced_shape)
+    return [
+        TracedRArray{unwrapped_eltype(xs[i]),length(reduced_shape)}(
+            (), MLIR.IR.result(op, i), reduced_shape
+        ) for i in 1:MLIR.IR.nresults(op)
+    ]
 end
 
 @noinline function dynamic_update_slice(
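
The generalized `reduce` accepts parallel vectors of inputs and init values and lowers them to a single multi-result `stablehlo.reduce`, with `_construct_reduce_function` now tracing a region whose scalar arguments are the two groups laid out back to back, as in the `argmax` combiner above. A hedged usage sketch of the new overload, assuming a Reactant tracing context (e.g. a function compiled with `Reactant.@compile`); the `min_and_max` name, the `Float32` element type, and the reduced dimension are illustrative, not part of this diff:

```julia
# Hypothetical usage sketch: per-column minimum and maximum of a matrix
# computed in one multi-result stablehlo.reduce via the vector-based `reduce`.
using Reactant
using Reactant: Ops, TracedRArray, TracedRNumber, TracedUtils

# Assumes `x` is already traced, e.g. this function runs under `Reactant.@compile`.
function min_and_max(x::TracedRArray{Float32,2})
    inits = TracedRNumber[
        TracedUtils.promote_to(TracedRNumber{Float32}, typemax(Float32)),
        TracedUtils.promote_to(TracedRNumber{Float32}, typemin(Float32)),
    ]
    mins, maxs = Ops.reduce(
        TracedRArray[x, x],  # one accumulator per statistic, over the same input
        inits,
        [1],                 # reduce away the first dimension
        # the combiner receives one scalar per input from each side, as in `argmax`
        (lo₁, hi₁, lo₂, hi₂) ->
            (ifelse(lo₁ < lo₂, lo₁, lo₂), ifelse(hi₁ > hi₂, hi₁, hi₂)),
    )
    return mins, maxs  # two 1-D TracedRArrays over the remaining dimension
end
```
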