@@ -1225,6 +1225,15 @@ end
1225
1225
location= mlir_stacktrace (" top_k" , @__FILE__ , @__LINE__ ),
1226
1226
) where {T,N}
1227
1227
@assert 1 <= dimension <= N
1228
+
1229
+ # XLA codegen for top.k is extremely sub-optimal. For special cases we can bypass that
1230
+ if k isa Integer && k == 1
1231
+ values, indices = argmax (x; dimension, location)
1232
+ return (;
1233
+ values, indices= add (indices, fill (Int64 (1 ), Tuple (size (indices))); location)
1234
+ )
1235
+ end
1236
+
1228
1237
if dimension != N # chlo.top_k performs the operation along the last dimension
1229
1238
pdims = collect (Int64, 1 : N)
1230
1239
pdims[dimension] = N
@@ -1251,13 +1260,41 @@ end
1251
1260
return (; values, indices)
1252
1261
end
1253
1262
1263
"""
    argmax(x::TracedRArray{T,N}; dimension=N, location)

Compute the maximum values of `x` along `dimension` together with the Int64
positions at which they occur, using a variadic `stablehlo.reduce` over `x`
paired with an `iota` of positions. The positions are 0-based, as produced by
`iota` (callers such as `top_k` add 1 to recover Julia indices).

Returns a 2-tuple `(values, indices)`, each reshaped so that the reduced
dimension is kept with size 1.
"""
@noinline function argmax(
    x::TracedRArray{T,N};
    dimension::Integer=N,
    location=mlir_stacktrace("argmax", @__FILE__, @__LINE__),
) where {T,N}
    # Position of every element along `dimension`, traced as Int64.
    positions = iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)

    # Identity elements: typemin(T) loses to any value of `x`; -1 marks
    # "no index seen yet".
    init_vals = TracedRNumber[
        Reactant.TracedUtils.promote_to(TracedRNumber{T}, typemin(T)),
        Reactant.TracedUtils.promote_to(TracedRNumber{Int64}, -1),
    ]

    # Combiner: keep the larger value and its index; ties resolve to the
    # first operand (`≥`).
    combiner = function (v₁, idx₁, v₂, idx₂)
        keep_first = v₁ ≥ v₂
        return ifelse(keep_first, v₁, v₂), ifelse(keep_first, idx₁, idx₂)
    end

    values, indices = reduce(
        TracedRArray[x, positions], init_vals, [dimension], combiner; location
    )

    # Keep the reduced dimension as a singleton so the outputs line up
    # with `x`'s rank.
    out_shape = collect(Int64, size(x))
    out_shape[dimension] = 1
    return (
        Ops.reshape(values, out_shape; location),
        Ops.reshape(indices, out_shape; location),
    )
end
1289
+
1254
1290
@noinline function iota (
1255
1291
T:: Type ,
1256
1292
shape:: Vector{Int} ;
1257
1293
iota_dimension,
1258
1294
location= mlir_stacktrace (" iota" , @__FILE__ , @__LINE__ ),
1259
1295
)
1260
1296
N = length (shape)
1297
+ @assert 0 < iota_dimension <= N
1261
1298
output = mlir_type (TracedRArray{T,N}, shape)
1262
1299
iota_dimension = MLIR. IR. Attribute (iota_dimension - 1 )
1263
1300
res = MLIR. IR. result (stablehlo. iota (; output, iota_dimension, location))
@@ -2631,24 +2668,30 @@ Produces a [`Reactant.MLIR.Dialects.sdy.sharding_constraint`](@ref) operation wi
2631
2668
end
2632
2669
end
2633
2670
2634
- function _construct_reduce_function (f:: F , :: Type{T} ) where {F,T}
2671
+ function _construct_reduce_function (f:: F , Ts:: Type... ) where {F}
2672
+ inputs_1 = [Reactant. TracedUtils. promote_to (TracedRNumber{T}, 0 ) for T in Ts]
2673
+ inputs_2 = [Reactant. TracedUtils. promote_to (TracedRNumber{T}, 0 ) for T in Ts]
2635
2674
func =
2636
2675
Reactant. TracedUtils. make_mlir_fn (
2637
2676
f,
2638
- (
2639
- Reactant. TracedUtils. promote_to (TracedRNumber{T}, 0 ),
2640
- Reactant. TracedUtils. promote_to (TracedRNumber{T}, 0 ),
2641
- ),
2677
+ (inputs_1... , inputs_2... ),
2642
2678
(),
2643
2679
" reduce_fn" * string (f),
2644
2680
false ;
2645
2681
args_in_result= :none ,
2646
2682
return_dialect= :stablehlo ,
2647
2683
). f
2684
+
2648
2685
@assert MLIR. IR. nregions (func) == 1
2649
2686
ftype_attr = MLIR. IR. attr (func, " function_type" )
2650
2687
ftype = MLIR. IR. Type (ftype_attr)
2651
- @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType (Int[], MLIR. IR. Type (T)) " $(fn) return type is not of tensor<$(T) >"
2688
+
2689
+ @assert MLIR. IR. nresults (ftype) == length (Ts)
2690
+ for i in 1 : MLIR. IR. nresults (ftype)
2691
+ tType = MLIR. IR. TensorType (Int[], MLIR. IR. Type (Ts[i]))
2692
+ @assert MLIR. IR. result (ftype, i) == tType " $(f) return type $(i) is not of \
2693
+ tensor<$(Ts[i]) >"
2694
+ end
2652
2695
2653
2696
fn = MLIR. IR. Region ()
2654
2697
MLIR. API. mlirRegionTakeBody (fn, MLIR. IR. region (func, 1 ))
@@ -2703,23 +2746,41 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
2703
2746
x:: TracedRArray{T} ,
2704
2747
init_values:: TracedRNumber{T} ,
2705
2748
dimensions:: Vector{Int} ,
2706
- fn:: Function ,
2749
+ fn:: F ;
2707
2750
location= mlir_stacktrace (" reduce" , @__FILE__ , @__LINE__ ),
2708
- ) where {T}
2709
- reduced_shape = Tuple (deleteat! (collect (Int64, size (x)), dimensions))
2751
+ ) where {T,F}
2752
+ return only (reduce ([x], [init_values], dimensions, fn; location))
2753
+ end
2710
2754
2711
- res = MLIR. IR. result (
2712
- stablehlo. reduce (
2713
- [x. mlir_data],
2714
- [init_values. mlir_data];
2715
- result_0= [mlir_type (TracedRArray{T,length (reduced_shape)}, reduced_shape)],
2716
- dimensions= MLIR. IR. Attribute (dimensions .- 1 ),
2717
- body= _construct_reduce_function (fn, T),
2718
- location= location,
2719
- ),
2755
"""
    reduce(xs, init_values, dimensions, fn; location)

Variadic `stablehlo.reduce`: reduce the arrays `xs` simultaneously along
`dimensions` with the combiner `fn`, starting from `init_values`.

`fn` receives `2 * length(xs)` scalars — the running accumulators followed by
the incoming elements — and must return `length(xs)` scalars whose element
types match those of `xs` (checked in `_construct_reduce_function`).

Returns a `Vector` of `TracedRArray`s, one per input array, with the reduced
dimensions removed from the shape.
"""
@noinline function reduce(
    xs::Vector{<:TracedRArray},
    init_values::Vector{<:TracedRNumber},
    dimensions::Vector{Int},
    fn::F;
    location=mlir_stacktrace("reduce", @__FILE__, @__LINE__),
) where {F}
    @assert !isempty(xs) "Expected at least one input array."
    @assert allequal(size.(xs)) "All input arrays must have the same size."
    # One init value per input array, otherwise stablehlo.reduce is malformed.
    @assert length(xs) == length(init_values) "Number of init values must match the number of input arrays."

    # `deleteat!` requires sorted, unique indices; normalize so callers may
    # pass `dimensions` in any order.
    reduced_shape = Tuple(
        deleteat!(collect(Int64, size(xs[1])), sort(unique(dimensions)))
    )

    op = stablehlo.reduce(
        [x.mlir_data for x in xs],
        [init_value.mlir_data for init_value in init_values];
        result_0=[
            mlir_type(
                TracedRArray{unwrapped_eltype(x),length(reduced_shape)}, reduced_shape
            ) for x in xs
        ],
        # StableHLO dimension numbers are 0-based.
        dimensions=MLIR.IR.Attribute(dimensions .- 1),
        body=_construct_reduce_function(fn, [unwrapped_eltype(x) for x in xs]...),
        location,
    )

    # One traced result per input array, in the same order.
    return [
        TracedRArray{unwrapped_eltype(xs[i]),length(reduced_shape)}(
            (), MLIR.IR.result(op, i), reduced_shape
        ) for i in 1:MLIR.IR.nresults(op)
    ]
end
2724
2785
2725
2786
@noinline function dynamic_update_slice (
0 commit comments