Skip to content

Commit f8ce49d

Browse files
authored
[TritonIntelRemoveMasks] Correctly correlate ForOp results with IfOp results (#4823)
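The problem was that, after loop versioning, the `StoreOp` used the first result of the `ifOp` instead of the third, even though it had originally used the third result of the `ForOp`. The old replacement loop advanced the `ifOp` result index only for `ForOp` results that had users, so a used result preceded by unused ones was remapped to the wrong position; results are now matched by their original index. --------- Signed-off-by: Anatoly Myachev <[email protected]>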
1 parent 9431243 commit f8ce49d

File tree

2 files changed

+34
-5
lines changed

2 files changed

+34
-5
lines changed

test/Triton/Intel/RemoveMasks/loop-invariant-masks.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,34 @@ module {
107107
// CHECK: tt.return
108108
// CHECK: }
109109
}
110+
111+
// -----
112+
113+
module {
114+
// COM: From Liger-Kernel
115+
// COM: For details: https://github.com/intel/intel-xpu-backend-for-triton/issues/4796
116+
// CHECK-LABEL: _error_repro_kernel
117+
tt.func public @_error_repro_kernel(%input_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %input_row_stride: i32 {tt.divisibility = 16 : i32}, %temp_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %temp_row_stride: i32 {tt.divisibility = 16 : i32}, %output_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n_rows: i32, %n_cols: i32 {tt.divisibility = 16 : i32}) {
118+
%c1_i32 = arith.constant 1 : i32
119+
%c0_i32 = arith.constant 0 : i32
120+
%cst = arith.constant dense<0.000000e+00> : tensor<64xf32>
121+
%col_offsets = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
122+
%mask = tt.splat %n_cols : i32 -> tensor<64xi32>
123+
%mask_0 = arith.cmpi slt, %col_offsets, %mask : tensor<64xi32>
124+
// CHECK: %[[IF:.*]]:3 = scf.if %{{.*}} -> (!tt.ptr<f32>, !tt.ptr<f32>, tensor<64xf32>) {
125+
%output_row:3 = scf.for %_ = %c0_i32 to %n_rows step %c1_i32 iter_args(%input_ptr_1 = %input_ptr, %temp_ptr_2 = %temp_ptr, %output_row_3 = %cst) -> (!tt.ptr<f32>, !tt.ptr<f32>, tensor<64xf32>) : i32 {
126+
%input_row = tt.splat %input_ptr_1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
127+
%input_row_4 = tt.addptr %input_row, %col_offsets : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
128+
%input_row_5 = tt.load %input_row_4, %mask_0, %cst : tensor<64x!tt.ptr<f32>>
129+
%temp_ptr_6 = tt.addptr %temp_ptr_2, %temp_row_stride : !tt.ptr<f32>, i32
130+
%output_row_7 = arith.addf %output_row_3, %input_row_5 : tensor<64xf32>
131+
%input_ptr_8 = tt.addptr %input_ptr_1, %input_row_stride : !tt.ptr<f32>, i32
132+
scf.yield %input_ptr_8, %temp_ptr_6, %output_row_7 : !tt.ptr<f32>, !tt.ptr<f32>, tensor<64xf32>
133+
}
134+
%0 = tt.splat %output_ptr : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
135+
%1 = tt.addptr %0, %col_offsets : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
136+
// CHECK: tt.store %{{.*}}, %[[IF]]#2, %{{.*}} : tensor<64x!tt.ptr<f32>>
137+
tt.store %1, %output_row#2, %mask_0 : tensor<64x!tt.ptr<f32>>
138+
tt.return
139+
}
140+
}

third_party/intel/lib/Dialect/Triton/Transforms/RemoveMasks.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -497,11 +497,9 @@ class LoopVersioner {
497497
}
498498

499499
// Replace the uses of the original loop results.
500-
unsigned idx = 0;
501-
for (Value res : forOp.getResults()) {
502-
if (!res.getUsers().empty())
503-
res.replaceAllUsesWith(ifOp->getResult(idx++));
504-
}
500+
for (const auto &[i, v] : llvm::enumerate(forOp.getResults()))
501+
if (!v.getUsers().empty())
502+
v.replaceAllUsesWith(ifOp->getResult(i));
505503

506504
forOp.erase();
507505
return true;

0 commit comments

Comments
 (0)