Skip to content

Commit 734e33c

Browse files
authored
Use block load attribute to remove duplicate logic from MaterializeBlockPointer pass (#2420)
The logic in `shouldRemove` in the `RewriteTensorPointer` pass duplicates the same logic in `MaterializeBlockPointer`: https://github.com/intel/intel-xpu-backend-for-triton/blob/main/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp#L50 This duplication is necessary for the Matrix transpose multiplication case because the block pointer is defined outside a `scf.for` loop, but the load is inside the loop. The previous logic in `RewriteTensorPointer` could not "see" into the `scf.for` loop block and decided to remove the tensor pointer even though its result was used by a block load. This commit changes the algorithm: First, walk the tree and look for MakeTensorPtr ops. For each MakeTensorPtr op, we do a search to find load/store users of the op. If we have a store associated with DPAS layout, or a block load, then we do not mark the MakeTensorPtr op for removal. Otherwise, we mark it for removal. Next, we make a pass through all the ops again and make sure we removal all MakeTensorPtr-related ops for each MakeTensorPtr marked for removal (tt.advance, rewrite the loads, etc).   Close #2380
1 parent a2a3100 commit 734e33c

File tree

4 files changed

+140
-81
lines changed

4 files changed

+140
-81
lines changed

python/test/unit/language/test_core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4286,8 +4286,9 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con
42864286
actual = torch.zeros(expected.shape, dtype=torch.int32, device=device)
42874287

42884288
k = kernel[(1, )](input, actual, shape[0], shape[1])
4289-
assert k.asm['ttgir'].count(
4290-
'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"
4289+
if not is_xpu():
4290+
assert k.asm['ttgir'].count(
4291+
'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"
42914292

42924293
np.testing.assert_equal(to_numpy(expected), to_numpy(actual))
42934294

test/TritonIntelGPU/rewrite-tensor-pointer.mlir

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
4444
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
4545
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
4646
%23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args(%arg11 = %cst, %arg12 = %18, %arg13 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
47-
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
48-
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
49-
%28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
50-
%29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
47+
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
48+
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
49+
%28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #dot0>>
50+
%29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #dot1>>
5151
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
5252
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
5353
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
@@ -59,8 +59,8 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
5959
%25 = arith.extsi %arg9 : i32 to i64
6060
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #[[DPAS]]>>
6161
%26 = tt.make_tensor_ptr %arg3, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #dpas>>
62-
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
63-
%27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #dpas>>
62+
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
63+
%27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x256xf32, #dpas>>
6464
%28 = arith.addf %23#0, %27 : tensor<256x256xf32, #dpas>
6565
%29 = arith.truncf %28 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
6666

@@ -125,10 +125,10 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
125125
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
126126
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
127127
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
128-
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
129-
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
130-
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
131-
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
128+
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
129+
// CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
130+
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #dot0>>
131+
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #dot1>>
132132
// CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
133133
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
134134
// CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
@@ -335,3 +335,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
335335
tt.return
336336
}
337337
}
338+
339+
// -----
340+
341+
// COM: Case 5:
342+
// COM: Check that a make tensor ptr with no loads is properly removed
343+
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
344+
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
345+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
346+
tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
347+
// CHECK: @matmul_kernel_with_block_pointers
348+
%c4_i32 = arith.constant 4 : i32
349+
%c256_i32 = arith.constant 256 : i32
350+
%c1024_i64 = arith.constant 1024 : i64
351+
%c5120_i64 = arith.constant 5120 : i64
352+
%c1_i64 = arith.constant 1 : i64
353+
%c0_i32 = arith.constant 0 : i32
354+
%c4096_i64 = arith.constant 4096 : i64
355+
%c32_i32 = arith.constant 32 : i32
356+
%c64_i32 = arith.constant 64 : i32
357+
%c5120_i32 = arith.constant 5120 : i32
358+
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
359+
%0 = tt.get_program_id x : i32
360+
%1 = arith.divsi %0, %c64_i32 : i32
361+
%2 = arith.muli %1, %c4_i32 : i32
362+
%3 = arith.subi %c4_i32, %2 : i32
363+
%4 = arith.minsi %3, %c4_i32 : i32
364+
%5 = arith.remsi %0, %4 : i32
365+
%6 = arith.addi %2, %5 : i32
366+
%7 = arith.remsi %0, %c64_i32 : i32
367+
%8 = arith.divsi %7, %4 : i32
368+
%9 = arith.muli %6, %c256_i32 : i32
369+
// CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
370+
%10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
371+
%11 = arith.muli %8, %c256_i32 : i32
372+
// CHECK-NOT: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
373+
%12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
374+
%13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>) : i32 {
375+
// CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
376+
// CHECK-NOT: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
377+
%19 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
378+
%20 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
379+
scf.yield %arg4, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
380+
}
381+
%14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
382+
%15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
383+
tt.return
384+
}
385+
}

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ struct TritonIntelGPUMaterializeBlockPointerPass
7171
return;
7272
}
7373

74+
if (fastChangeDim == rank - 2 &&
75+
tensorType.getElementTypeBitWidth() == 8) {
76+
// TODO: column major layout w/ fp8 has performance regression
77+
return;
78+
}
79+
7480
if (fastChangeDim >= (rank - 2)) {
7581
// HW 2D block read instruction only supports contiguous access.
7682
Value fastChangeStride = strides[fastChangeDim];

third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

Lines changed: 73 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ namespace {
3333
/// - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
3434
/// - its pitch is not divisible by Qword bitwidth
3535
/// - it is not contiguous in memory
36-
bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
36+
bool shouldRemove(tt::MakeTensorPtrOp &op, const bool isUsedByStoreOp,
37+
const bool isUsedByBlockLoadOp) {
3738
LDBG("Considering removal of: " << op);
3839
if (!op->getParentOfType<ModuleOp>()->hasAttr(
3940
ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
@@ -45,61 +46,19 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
4546
LDBG("Op ptr type: " << ptrType);
4647
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
4748
LDBG("Op tensor type: " << tensorType);
48-
49-
if (!ttgi::hasDotDpasEncoding(tensorType) &&
50-
!(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
51-
LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
52-
"by load or store op with DPAS layout");
53-
return true;
54-
}
55-
56-
TypedValue<triton::PointerType> base = op.getBase();
57-
Operation::operand_range shape = op.getShape();
58-
unsigned rank = shape.size();
59-
assert(rank > 1 && "Expecting tensor with rank > 1");
60-
Operation::operand_range strides = op.getStrides();
61-
Operation::operand_range offsets = op.getOffsets();
62-
ArrayRef<int32_t> order = op.getOrder();
63-
ArrayRef<int64_t> tensorShape = tensorType.getShape();
64-
65-
int fastChangeDim = -1;
66-
for (size_t i = 0; i < strides.size(); ++i) {
67-
if (ttgi::isConstant(strides[i], 1)) {
68-
fastChangeDim = i;
69-
break;
70-
}
71-
}
72-
73-
LDBG("fastChangeDim: " << fastChangeDim);
74-
if (fastChangeDim < 0) {
75-
LDBG("Marked for removal: fast changing dimension not found");
76-
return true;
77-
}
78-
79-
LDBG("Tensor type element type bit width: "
80-
<< tensorType.getElementTypeBitWidth());
81-
if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
82-
// TODO: column major layout w/ fp8 has performance regression
83-
LDBG("Marked for removal: column major layout with fp8 element type");
84-
return true;
85-
}
86-
87-
// HW 2D block read instruction has restriction on pitch divisibility
88-
if (fastChangeDim >= (rank - 2)) {
89-
auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
90-
LDBG("Pitch: " << pitch);
91-
// Across Intel platforms, the strictest pitch restriction is to be a
92-
// multiple of OWord(128 bits).
93-
if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
94-
LDBG("Marked for removal: cannot use block read/write instructions");
95-
return true;
96-
}
97-
49+
LDBG("Used by store op? " << isUsedByStoreOp);
50+
LDBG("Used by block load op? " << isUsedByBlockLoadOp);
51+
52+
LDBG("hasDpasEncoding: " << ttgi::hasDpasEncoding(tensorType));
53+
if (isUsedByBlockLoadOp ||
54+
(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType))) {
55+
LDBG("Tensor has DPAS layout or is used by load/store op with DPAS layout, "
56+
"skipping removal");
9857
return false;
9958
}
10059

101-
LDBG("Marked for removal: fall-trough");
102-
60+
LDBG("Marked for removal: make tensor ptr op is not used by block load op or "
61+
"by store op with DPAS layout");
10362
return true;
10463
}
10564

@@ -715,28 +674,73 @@ class TritonIntelGPURewriteTensorPointerPass
715674
void runOnOperation() override {
716675
ModuleOp mod = getOperation();
717676

718-
auto usedByLoadOrStoreOp = [](Value val) {
719-
return llvm::any_of(val.getUsers(), [](Operation *user) {
720-
return isa<tt::LoadOp, tt::StoreOp>(user);
721-
});
722-
};
677+
DenseSet<Operation *> tensorPointersToRemove;
678+
mod.walk([&](tt::MakeTensorPtrOp makeTensorPtrOp) {
679+
tensorPointersToRemove.insert(makeTensorPtrOp);
680+
DenseSet<Operation *> workingSet;
723681

724-
auto markTensorPointerForRemoval =
725-
[this](Value val, bool isUsedByLoadOrStoreOp = false) {
726-
if (tt::isTensorPointerType(val.getType())) {
727-
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
728-
if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp))
729-
valueToRemove.insert(val);
682+
LDBG("Considering: " << makeTensorPtrOp);
683+
Value result = makeTensorPtrOp.getResult();
684+
for (auto user : result.getUsers()) {
685+
workingSet.insert(user);
686+
}
687+
while (!workingSet.empty()) {
688+
auto crtOpItr = workingSet.begin();
689+
auto crtOp = *crtOpItr;
690+
LDBG("Processing op: " << *crtOp);
691+
if (isa<tt::LoadOp, tt::StoreOp>(crtOp)) {
692+
if (!shouldRemove(
693+
makeTensorPtrOp,
694+
/*isUsedByStoreOp=*/isa<tt::StoreOp>(crtOp),
695+
/*isBlockLoad=*/
696+
isa<tt::LoadOp>(crtOp) &&
697+
crtOp->hasAttr(
698+
ttgi::TritonIntelGPUDialect::getBlockIOAttrName()))) {
699+
tensorPointersToRemove.erase(makeTensorPtrOp);
700+
return WalkResult::advance();
730701
}
731-
};
702+
} else if (auto forOp = dyn_cast<scf::ForOp>(crtOp)) {
703+
for (auto [arg, blockArg] :
704+
llvm::zip(forOp.getInitArgs(),
705+
forOp.getBody()->getArguments().drop_front(
706+
forOp.getNumInductionVars()))) {
707+
if (arg == makeTensorPtrOp) {
708+
// add users of block arg
709+
for (auto user : blockArg.getUsers()) {
710+
workingSet.insert(user);
711+
}
712+
}
713+
}
714+
} else if (crtOp->getNumResults() > 0) {
715+
// TODO: should we handle more than one result?
716+
auto crtOpResult = crtOp->getResult(0);
717+
LDBG("Not a load store and not a loop, adding users to working "
718+
"set.");
719+
for (auto user : crtOpResult.getUsers()) {
720+
workingSet.insert(user);
721+
}
722+
}
723+
workingSet.erase(crtOpItr);
724+
}
725+
return WalkResult::advance();
726+
});
727+
728+
auto markTensorPointerForRemoval = [this,
729+
&tensorPointersToRemove](Value val) {
730+
if (tt::isTensorPointerType(val.getType())) {
731+
tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
732+
if (tensorPointersToRemove.count(makeTensorPtrOp)) {
733+
valueToRemove.insert(val);
734+
}
735+
}
736+
};
732737

733738
mod.walk([&](Operation *op) {
734739
if (isa<tt::MakeTensorPtrOp>(op)) {
735740
Value result = op->getResult(0);
736-
markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
741+
markTensorPointerForRemoval(result);
737742
} else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
738-
markTensorPointerForRemoval(op->getOperand(0),
739-
isa<tt::LoadOp, tt::StoreOp>(op));
743+
markTensorPointerForRemoval(op->getOperand(0));
740744
} else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
741745
for (auto arg : forOp.getInitArgs())
742746
markTensorPointerForRemoval(arg);
@@ -752,7 +756,7 @@ class TritonIntelGPURewriteTensorPointerPass
752756
else {
753757
DBGS() << "Values to remove: ";
754758
for (auto val : valueToRemove)
755-
DBGS() << val;
759+
DBGS() << val << "\n";
756760
}
757761
});
758762

0 commit comments

Comments
 (0)