Skip to content

Commit ce49a59

Browse files
authored
[tritonintelgpu-remove-layout-conversions]: Failure to find make_tensor_ptr operation for tt.store within while loop. (#4330)
This PR introduces a centralized helper to trace back to the defining MakeTensorPtrOp for a tensor pointer, and updates two GPU transformation passes to use it. It fixes issue #4336. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 23efa0d commit ce49a59

File tree

5 files changed

+109
-68
lines changed

5 files changed

+109
-68
lines changed

test/TritonIntelGPU/combine.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2564,3 +2564,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
25642564
tt.return
25652565
}
25662566
}
2567+
2568+
// -----
2569+
2570+
// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
2571+
// CHECK-NOT: #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
2572+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
2573+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
2574+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32, ttig.support_sg_2d_block} {
2575+
// CHECK-LABEL: while_using_advanced_ptr
2576+
tt.func public @while_using_advanced_ptr(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
2577+
%c1_i64 = arith.constant 1 : i64
2578+
%cst = arith.constant dense<5.000000e+00> : tensor<8x128xf32, #blocked>
2579+
%c0_i32 = arith.constant 0 : i32
2580+
%0 = tt.get_program_id x : i32
2581+
%2 = arith.extsi %arg2 : i32 to i64
2582+
%3 = arith.extsi %arg1 : i32 to i64
2583+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<8x128xf32, #[[$BLOCKED]]>>
2584+
%4 = tt.make_tensor_ptr %arg0, [%3, %2], [%2, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<8x128xf32, #blocked1>>
2585+
%8 = arith.cmpi eq, %0, %c0_i32 : i32
2586+
%6 = scf.while (%arg3 = %4) : (!tt.ptr<tensor<8x128xf32, #blocked1>>) -> (!tt.ptr<tensor<8x128xf32, #blocked1>>) {
2587+
%7 = arith.cmpi slt, %0, %c0_i32 : i32
2588+
scf.condition(%7) %arg3 : !tt.ptr<tensor<8x128xf32, #blocked1>>
2589+
} do {
2590+
^bb0(%arg3: !tt.ptr<tensor<8x128xf32, #blocked1>>):
2591+
// CHECK-NOT: ttg.convert_layout
2592+
// CHECK: [[SEL:%.*]] = arith.select {{.*}} : !tt.ptr<tensor<8x128xf32, #[[$BLOCKED]]>>
2593+
// CHECK: [[PTR1:%.*]] = tt.advance [[SEL]], {{.*}} : <tensor<8x128xf32, #[[$BLOCKED]]>>
2594+
// CHECK: tt.store [[PTR1]], {{.*}} : !tt.ptr<tensor<8x128xf32, #[[$BLOCKED]]>>
2595+
%12 = arith.select %8, %4, %arg3 : !tt.ptr<tensor<8x128xf32, #blocked1>>
2596+
%14 = tt.advance %12, [%0, %0] : <tensor<8x128xf32, #blocked1>>
2597+
%18 = ttg.convert_layout %cst : tensor<8x128xf32, #blocked> -> tensor<8x128xf32, #blocked1>
2598+
tt.store %14, %18 : !tt.ptr<tensor<8x128xf32, #blocked1>>
2599+
scf.yield %12 : !tt.ptr<tensor<8x128xf32, #blocked1>>
2600+
}
2601+
tt.return
2602+
}
2603+
}

third_party/intel/include/Utils/Utility.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
#include "mlir/IR/Builders.h"
55
#include "mlir/IR/Value.h"
66

7+
namespace mlir::triton {
8+
class MakeTensorPtrOp;
9+
}
710
namespace mlir::triton::intel {
811

912
// Look up an integer constant with the given value and bitwidth in the
@@ -12,6 +15,10 @@ namespace mlir::triton::intel {
1215
Value findOrCreateIntConstant(Location loc, int val, unsigned bitWidth,
1316
OpBuilder &builder);
1417

18+
// Find the MakeTensorPtrOp that defines the given value, or std::nullopt if none can be identified.
19+
std::optional<mlir::triton::MakeTensorPtrOp>
20+
findDefiningMakeTensorPtrOp(Value val);
21+
1522
// This function folds the `op` operation and returns the constant value if it
1623
// has successfully folded to a constant. Otherwise, it returns `std::nullopt`.
1724
std::optional<int64_t> getFoldedConstantValue(Operation *op);

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 6 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
#include "intel/include/Dialect/TritonIntelGPU/IR/Utils.h"
33
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
44
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
5-
#include "mlir/Dialect/Arith/IR/Arith.h"
5+
#include "intel/include/Utils/Utility.h"
66
#include "mlir/Dialect/SCF/IR/SCF.h"
77
#include "mlir/IR/Operation.h"
88
#include "mlir/IR/Value.h"
99
#include "mlir/IR/Verifier.h"
1010
#include "mlir/Support/LLVM.h"
11+
#include "triton/Dialect/Triton/IR/Types.h"
1112
#include "triton/Dialect/Triton/IR/Utility.h"
1213
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
1314
#include "triton/Tools/StrUtil.h"
@@ -126,62 +127,6 @@ struct CoalescePass
126127
tensorType.getElementType(), encoding);
127128
}
128129

129-
// Find the defining makeTensorPtrOp operation of the given value.
130-
static std::optional<tt::MakeTensorPtrOp>
131-
findDefiningMakeTensorPtrOp(Value val) {
132-
LLVM_DEBUG({
133-
llvm::dbgs() << "[" DEBUG_TYPE "]: \t"
134-
<< "Attempting to find `makeTensorPtrOp` defining: " << val
135-
<< "\n";
136-
});
137-
138-
if (auto arg = dyn_cast<BlockArgument>(val)) {
139-
Operation *parentOp = arg.getParentBlock()->getParentOp();
140-
141-
Value loopArg;
142-
if (auto forOp = dyn_cast<scf::ForOp>(parentOp))
143-
loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
144-
else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp))
145-
loopArg = whileOp.getInits()[arg.getArgNumber()];
146-
else
147-
llvm_unreachable("Unexpected parent operator");
148-
149-
return findDefiningMakeTensorPtrOp(loopArg);
150-
}
151-
152-
if (auto advanceOp = val.getDefiningOp<tt::AdvanceOp>())
153-
return findDefiningMakeTensorPtrOp(advanceOp.getPtr());
154-
if (auto makePtrOp = val.getDefiningOp<tt::MakeTensorPtrOp>())
155-
return makePtrOp;
156-
if (auto opRes = dyn_cast<OpResult>(val)) {
157-
Operation *defOp = opRes.getOwner();
158-
if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
159-
Value val = forOp.getYieldedValues()[opRes.getResultNumber()];
160-
return findDefiningMakeTensorPtrOp(val);
161-
}
162-
if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) {
163-
Value val = whileOp.getYieldedValues()[opRes.getResultNumber()];
164-
return findDefiningMakeTensorPtrOp(val);
165-
}
166-
if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
167-
// Give up if the 2 possible definitions aren't the same.
168-
Value trueVal = selectOp.getTrueValue(),
169-
falseVal = selectOp.getFalseValue();
170-
std::optional<tt::MakeTensorPtrOp> trueDef =
171-
findDefiningMakeTensorPtrOp(trueVal);
172-
std::optional<tt::MakeTensorPtrOp> falseDef =
173-
findDefiningMakeTensorPtrOp(falseVal);
174-
if (!trueDef || !falseDef || *trueDef != *falseDef)
175-
return std::nullopt;
176-
return trueDef;
177-
}
178-
179-
assert(false && "unhandled operation");
180-
}
181-
182-
return std::nullopt;
183-
}
184-
185130
static bool filterUser(Operation *op) {
186131
// Yield operations trigger updating the layout of the containing loop
187132
// results, don't skip them.
@@ -446,10 +391,11 @@ struct CoalescePass
446391
newArgs.push_back(builder.create<ttg::ConvertLayoutOp>(
447392
op->getLoc(), newType, operand));
448393
} else {
449-
assert(isa<tt::PointerType>(operand.getType()) &&
394+
assert(tt::isTensorPointerType(operand.getType()) &&
450395
"Expecting operand to have blocked pointer type");
451-
auto defOp = findDefiningMakeTensorPtrOp(operand);
452-
assert(defOp && "Expected a make_tensor_ptr operation");
396+
std::optional<tt::MakeTensorPtrOp> defOp =
397+
triton::intel::findDefiningMakeTensorPtrOp(operand);
398+
assert(defOp && "Expecting a MakeTensorPtr operation");
453399
LLVM_DEBUG({
454400
llvm::dbgs() << "[" DEBUG_TYPE "]: Found definition: " << defOp
455401
<< "\n";

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
1414
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
1515
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
16+
#include "intel/include/Utils/Utility.h"
1617

1718
#include "triton/Analysis/Utility.h"
19+
#include "triton/Dialect/Triton/IR/Dialect.h"
1820
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
1921
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
2022
#include <deque>
@@ -731,12 +733,12 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
731733
return false;
732734

733735
// Locate the operation that created the block pointer.
734-
Operation *defOp = ptr.getDefiningOp();
735-
while (auto advanceOp = dyn_cast<AdvanceOp>(defOp))
736-
defOp = advanceOp.getPtr().getDefiningOp();
737-
assert(isa<MakeTensorPtrOp>(defOp) &&
738-
"MakeTensorPtrOp should be the only op that creates a tensor pointer");
739-
auto makeTensorPtrOp = cast<MakeTensorPtrOp>(defOp);
736+
std::optional<triton::MakeTensorPtrOp> defOp =
737+
triton::intel::findDefiningMakeTensorPtrOp(ptr);
738+
if (!defOp)
739+
return false;
740+
741+
triton::MakeTensorPtrOp makeTensorPtrOp = *defOp;
740742

741743
// The DPAS encoding has to be propagated if a conversion from a DPAS layout to
742744
// another layout has been done before.
@@ -1585,8 +1587,9 @@ void hoistConvert(ModuleOp module) {
15851587
}
15861588

15871589
class TritonIntelGPURemoveLayoutConversionsPass
1588-
: public intel::impl::TritonIntelGPURemoveLayoutConversionsBase<
1589-
TritonIntelGPURemoveLayoutConversionsPass> {
1590+
: public triton::gpu::intel::impl::
1591+
TritonIntelGPURemoveLayoutConversionsBase<
1592+
TritonIntelGPURemoveLayoutConversionsPass> {
15901593
public:
15911594
// Cleanup convert ops.
15921595
void cleanupConvertOps() {

third_party/intel/lib/Utils/Utility.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,54 @@ Value findOrCreateIntConstant(Location loc, int val, unsigned bitWidth,
3333
: builder.createOrFold<arith::ConstantIntOp>(loc, val, bitWidth);
3434
}
3535

36+
std::optional<tt::MakeTensorPtrOp> findDefiningMakeTensorPtrOp(Value val) {
37+
if (auto arg = dyn_cast<BlockArgument>(val)) {
38+
Operation *parentOp = arg.getParentBlock()->getParentOp();
39+
40+
Value loopArg;
41+
if (auto forOp = dyn_cast<scf::ForOp>(parentOp))
42+
loopArg = forOp.getInitArgs()[arg.getArgNumber() - 1];
43+
else if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp))
44+
loopArg = whileOp.getInits()[arg.getArgNumber()];
45+
else
46+
llvm_unreachable("Unexpected parent operator");
47+
48+
return findDefiningMakeTensorPtrOp(loopArg);
49+
}
50+
51+
if (auto advanceOp = val.getDefiningOp<tt::AdvanceOp>())
52+
return findDefiningMakeTensorPtrOp(advanceOp.getPtr());
53+
if (auto makePtrOp = val.getDefiningOp<tt::MakeTensorPtrOp>())
54+
return makePtrOp;
55+
if (auto opRes = dyn_cast<OpResult>(val)) {
56+
Operation *defOp = opRes.getOwner();
57+
if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
58+
Value val = forOp.getYieldedValues()[opRes.getResultNumber()];
59+
return findDefiningMakeTensorPtrOp(val);
60+
}
61+
if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) {
62+
Value val = whileOp.getYieldedValues()[opRes.getResultNumber()];
63+
return findDefiningMakeTensorPtrOp(val);
64+
}
65+
if (auto selectOp = dyn_cast<arith::SelectOp>(defOp)) {
66+
// Give up if the 2 possible definitions aren't the same.
67+
Value trueVal = selectOp.getTrueValue(),
68+
falseVal = selectOp.getFalseValue();
69+
std::optional<tt::MakeTensorPtrOp> trueDef =
70+
findDefiningMakeTensorPtrOp(trueVal);
71+
std::optional<tt::MakeTensorPtrOp> falseDef =
72+
findDefiningMakeTensorPtrOp(falseVal);
73+
if (!trueDef || !falseDef || *trueDef != *falseDef)
74+
return std::nullopt;
75+
return trueDef;
76+
}
77+
78+
assert(false && "unhandled operation");
79+
}
80+
81+
return std::nullopt;
82+
}
83+
3684
std::optional<int64_t> getFoldedConstantValue(Operation *op) {
3785
SmallVector<OpFoldResult> results;
3886
if (failed(op->fold(results)))

0 commit comments

Comments
 (0)