Skip to content

Commit 9d1fcc4

Browse files
authored
[tritonintelgpu-materialize-block-pointer]: Failure to retrieve make_tensor_ptr operation for load fed by tt.advance in a scf.for loop. (#4344)
The materialize block ptr transformation trips on an assertion because it cannot locate the `make_tensor_ptr` operation that created the block ptr used by a `tt.load` via a `tt.select` -> `tt.advance` -> `tt.load` use-def chain. Fixes issue #4342. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 33f5599 commit 9d1fcc4

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

test/TritonIntelGPU/materialize-block-pointer.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,32 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
106106
tt.return
107107
}
108108
}
109+
110+
// -----
111+
112+
// COM: Ensure load is annotated when its base ptr is the result of a `tt.advance` operation fed by a select operation.
113+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
114+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttig.support_sg_2d_block} {
115+
tt.func public @load_fed_by_advance_op_in_loop(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
116+
%c8_i32 = arith.constant 8 : i32
117+
%c0_i32 = arith.constant 0 : i32
118+
%c128_i32 = arith.constant 128 : i32
119+
%c384_i32 = arith.constant 384 : i32
120+
%c1_i64 = arith.constant 1 : i64
121+
%0 = tt.get_program_id x : i32
122+
%1 = arith.muli %0, %c8_i32 : i32
123+
%2 = arith.extsi %arg2 : i32 to i64
124+
%3 = arith.extsi %arg1 : i32 to i64
125+
%4 = tt.make_tensor_ptr %arg0, [%3, %2], [%2, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x128xf32, #blocked>>
126+
%5 = scf.for %arg3 = %c0_i32 to %arg2 step %c128_i32 iter_args(%arg4 = %4) -> (!tt.ptr<tensor<8x128xf32, #blocked>>) : i32 {
127+
%7 = arith.remsi %arg3, %c384_i32 : i32
128+
%8 = arith.cmpi eq, %7, %c0_i32 : i32
129+
%12 = arith.select %8, %4, %arg4 : !tt.ptr<tensor<8x128xf32, #blocked>>
130+
%14 = tt.advance %12, [%1, %arg3] : <tensor<8x128xf32, #blocked>>
131+
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"}
132+
%15 = tt.load %14 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<8x128xf32, #blocked>>
133+
scf.yield %12 : !tt.ptr<tensor<8x128xf32, #blocked>>
134+
}
135+
tt.return
136+
}
137+
}

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,14 @@ struct TritonIntelGPUMaterializeBlockPointerPass
5555
"Expected 'loadOp' to load a tensor value.");
5656

5757
// Find the make tensor ptr operation that created the base ptr.
58-
tt::MakeTensorPtrOp makeTensorPtrOp = tt::getMakeTensorPtrOp(ptr);
59-
if (!makeTensorPtrOp) {
58+
std::optional<tt::MakeTensorPtrOp> defOp =
59+
tt::intel::findDefiningMakeTensorPtrOp(ptr);
60+
if (!defOp) {
6061
LDBG("Could not find make tensor ptr op for: " << loadOp);
6162
return;
6263
}
64+
65+
tt::MakeTensorPtrOp makeTensorPtrOp = *defOp;
6366
LDBG("Make tensor ptr op: " << makeTensorPtrOp);
6467

6568
Operation::operand_range shape = makeTensorPtrOp.getShape();
@@ -290,9 +293,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass
290293

291294
// Find the make tensor ptr operation that created the base ptr for the load
292295
// operation.
293-
tt::MakeTensorPtrOp makeTensorPtrOp = tt::getMakeTensorPtrOp(ptr);
294-
assert(makeTensorPtrOp && "Expected a make tensor ptr op.");
296+
std::optional<tt::MakeTensorPtrOp> defOp =
297+
tt::intel::findDefiningMakeTensorPtrOp(ptr);
298+
if (!defOp)
299+
return false;
295300

301+
tt::MakeTensorPtrOp makeTensorPtrOp = *defOp;
296302
Operation::operand_range shape = makeTensorPtrOp.getShape();
297303
if (shape.size() == 1)
298304
return false;

0 commit comments

Comments
 (0)