Commit ef8ded8
Generalize Intel coalescing pass to handle users of scf.for with coalesced load (#2856)
Ensure that loads of a `scf.for`-yielded value with block pointer type can be coalesced.

Signed-off-by: Tiotto, Ettore <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
1 parent 61e33eb commit ef8ded8
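
The pattern this enables, in brief: a block (tensor) pointer created by tt.make_tensor_ptr, carried through scf.for iter_args, and loaded only after the loop. A minimal sketch of that IR shape (illustrative only; simplified from the test added below, with hypothetical shapes and value names):

// A block pointer threaded through a loop and loaded afterwards.
%ptr = tt.make_tensor_ptr %base, ... : <tensor<32x128xf32, #blocked>>
%out = scf.for %i = %lb to %ub step %c1
    iter_args(%p = %ptr) -> (!tt.ptr<tensor<32x128xf32, #blocked>>) : i32 {
  scf.yield %p : !tt.ptr<tensor<32x128xf32, #blocked>>
}
// The load consumes a result of the scf.for; before this change the pass
// could not trace %out back to the defining tt.make_tensor_ptr, so this
// load was not coalesced.
%v = tt.load %out : !tt.ptr<tensor<32x128xf32, #blocked>>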

2 files changed: +42 -4 lines

test/TritonIntelGPU/coalesce.mlir

Lines changed: 26 additions & 0 deletions
@@ -382,4 +382,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, "ttg.th
     }) : (tensor<32x128xf32, #blocked>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #blocked}>>
     tt.return
   }
+
+  // CHECK: @issue_2762
+  tt.func public @issue_2762(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
+    %c128_i32 = arith.constant 128 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %c262144_i64 = arith.constant 262144 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c512_i64 = arith.constant 512 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c512_i32 = arith.constant 512 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c32_i32 : i32
+    %4 = arith.divsi %1, %c512_i32 : i32
+    %5 = arith.remsi %1, %c512_i32 : i32
+    // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg0, {{.*}} : <tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
+    %y = tt.make_tensor_ptr %arg0, [%c512_i64, %c512_i64, %c512_i64], [%c1_i64, %c512_i64, %c262144_i64], [%4, %5, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x128xf32, #blocked1>>
+    // CHECK: [[RES:%.*]] = scf.for {{.*}} iter_args([[ARG1:%.*]] = [[PTR1]]) -> (!tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>)
+    %8:1 = scf.for %arg5 = %c0_i32 to %c512_i32 step %c128_i32 iter_args(%arg7 = %y) -> (!tt.ptr<tensor<1x32x128xf32, #blocked1>>) : i32 {
+      // CHECK: scf.yield [[ARG1]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
+      scf.yield %arg7 : !tt.ptr<tensor<1x32x128xf32, #blocked1>>
+    }
+    // CHECK: [[LOAD_RES:%.*]] = tt.load [[RES]] : !tt.ptr<tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]>>
+    // CHECK: ttg.convert_layout [[LOAD_RES]] : tensor<1x32x128xf32, [[BLOCKED_LAYOUT1]]> -> tensor<1x32x128xf32, [[BLOCKED_LAYOUT2]]>
+    %res = tt.load %8#0 : !tt.ptr<tensor<1x32x128xf32, #blocked1>>
+    tt.return
+  }
 }
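
A note on reading the test: this is a lit/FileCheck test, and [[BLOCKED_LAYOUT1]] and [[BLOCKED_LAYOUT2]] are FileCheck variables bound by CHECK lines earlier in the file, outside this hunk. The RUN line at the top of the file is likewise not part of the diff; it presumably looks roughly like the line below (the exact triton-opt flag is an assumption inferred from the pass name, not taken from the commit):

// RUN: triton-opt %s -split-input-file --tritonintelgpu-coalesce | FileCheck %s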

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 16 additions & 4 deletions
@@ -122,6 +122,8 @@ struct CoalescePass
   // Find the defining makeTensorPtrOp operation of the given value.
   static std::optional<tt::MakeTensorPtrOp>
   findDefiningMakeTensorPtrOp(Value val) {
+    LDBG("Attempting to find `makeTensorPtrOp` defining: " << val);
+
     if (auto arg = dyn_cast<BlockArgument>(val)) {
       Operation *parentOp = val.getParentBlock()->getParentOp();
       assert(isa<scf::ForOp>(parentOp) && "Expected a scf::ForOp");
@@ -134,6 +136,14 @@ struct CoalescePass
       return findDefiningMakeTensorPtrOp(advanceOp.getPtr());
     if (auto makePtrOp = val.getDefiningOp<tt::MakeTensorPtrOp>())
       return makePtrOp;
+    if (auto opRes = dyn_cast<OpResult>(val)) {
+      Operation *defOp = opRes.getOwner();
+      if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
+        Value val = forOp.getYieldedValues()[opRes.getResultNumber()];
+        return findDefiningMakeTensorPtrOp(val);
+      }
+      assert(false && "unhandled operation");
+    }

     return std::nullopt;
   }
@@ -369,12 +379,14 @@ struct CoalescePass
     });

     LLVM_DEBUG({
-      DBGS() << "\nlayoutMap:\n";
+      DBGS() << "layoutMap:\n";
+      if (layoutMap.empty())
+        DBGS() << "\t<empty>";
       for (auto [op, encoding] : layoutMap) {
-        DBGS() << "op: " << *op << "\n";
-        DBGS() << "encoding: " << encoding << "\n\n";
+        DBGS() << "\top: " << *op << "\n";
+        DBGS() << "\tencoding: " << encoding << "\n";
       }
-      llvm::errs() << "\n\n";
+      llvm::errs() << "\n";
     });

     // For each memory op that has a layout L1:
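
How the new OpResult case resolves the test above: findDefiningMakeTensorPtrOp now looks through a scf.for result by recursing on the value yielded at the same result index, and the pre-existing BlockArgument case then (presumably; its body is truncated in this hunk) steps from the loop-carried argument to the corresponding loop operand. Annotated against the value names in the test (an illustrative trace, not pass output):

%res = tt.load %8#0   // %8#0: OpResult of the scf.for
                      //   -> recurse on the yielded value %arg7
                      // %arg7: BlockArgument (iter_arg of the loop)
                      //   -> recurse on the loop init operand %y
                      // %y: defined by tt.make_tensor_ptr -> found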
