Skip to content

Commit 74d5059

Browse files
authored
Fix assertion in findDefiningMakeTensorPtrOp (#4606)
Fixes issue #4605.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 5f4ce67 commit 74d5059

File tree

2 files changed

+51
-18
lines changed

2 files changed

+51
-18
lines changed

test/TritonIntelGPU/dot-operands.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,3 +303,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
303303
// CHECK-LABEL: doNotFuseLoadWithTrans4
304304
// CHECK: tt.trans
305305
}
306+
307+
// -----
308+
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
309+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
310+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
311+
// COM: Ensure tt.trans is not fused with tt.load when the load uses a pointer yielded by a function call.
312+
tt.func @func(%cond: i1, %p1: !tt.ptr<tensor<32x64xf16, #linear>>, %p2: !tt.ptr<tensor<32x64xf16, #linear>>) -> !tt.ptr<tensor<32x64xf16, #linear>> attributes {noinline = true} {
313+
%0 = arith.select %cond, %p1, %p2 : i1, !tt.ptr<tensor<32x64xf16, #linear>>
314+
tt.return %0 : !tt.ptr<tensor<32x64xf16, #linear>>
315+
}
316+
tt.func public @doNotFuseLoadWithTrans5(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %cond: i1) {
317+
%c32_i32 = arith.constant 32 : i32
318+
%c0_i32 = arith.constant 0 : i32
319+
%c64_i64 = arith.constant 64 : i64
320+
%c1_i64 = arith.constant 1 : i64
321+
%cst_3 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
322+
%7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
323+
%9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #linear>>
324+
%24 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
325+
%25 = tt.load %24 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
326+
%29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
327+
%adv1 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
328+
%adv2 = tt.advance %9, [%c0_i32, %arg13] : <tensor<32x64xf16, #linear>>
329+
%adv3 = tt.call @func(%cond, %adv1, %adv2) : (i1, !tt.ptr<tensor<32x64xf16, #linear>>, !tt.ptr<tensor<32x64xf16, #linear>>) -> !tt.ptr<tensor<32x64xf16, #linear>>
330+
%load1 = tt.load %adv3 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
331+
%trans1 = tt.trans %load1 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
332+
%dot1 = tt.dot %25, %trans1, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
333+
%76 = arith.addi %arg13, %c32_i32 : i32
334+
scf.yield %76 : i32
335+
}
336+
tt.return
337+
}
338+
// CHECK-LABEL: doNotFuseLoadWithTrans5
339+
// CHECK: tt.trans
340+
}

third_party/intel/lib/Utils/Utility.cpp

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include "mlir/Dialect/Arith/IR/Arith.h"
33
#include "mlir/Dialect/SCF/IR/SCF.h"
44
#include "mlir/Dialect/UB/IR/UBOps.h"
5+
#include "mlir/Interfaces/LoopLikeInterface.h"
56
#include "mlir/Transforms/DialectConversion.h"
67
#include "triton/Dialect/Triton/IR/Dialect.h"
78
#include <optional>
@@ -52,20 +53,17 @@ std::optional<tt::MakeTensorPtrOp> findDefiningMakeTensorPtrOp(Value val) {
5253

5354
if (auto poisonOp = val.getDefiningOp<ub::PoisonOp>())
5455
return std::nullopt;
56+
if (auto callOp = val.getDefiningOp<tt::CallOp>())
57+
return std::nullopt;
5558
if (auto advanceOp = val.getDefiningOp<tt::AdvanceOp>())
5659
return findDefiningMakeTensorPtrOp(advanceOp.getPtr());
5760
if (auto makePtrOp = val.getDefiningOp<tt::MakeTensorPtrOp>())
5861
return makePtrOp;
5962
if (auto opRes = dyn_cast<OpResult>(val)) {
6063
Operation *defOp = opRes.getOwner();
61-
if (auto forOp = dyn_cast<scf::ForOp>(defOp)) {
62-
Value val = forOp.getYieldedValues()[opRes.getResultNumber()];
63-
return findDefiningMakeTensorPtrOp(val);
64-
}
65-
if (auto whileOp = dyn_cast<scf::WhileOp>(defOp)) {
66-
Value val = whileOp.getYieldedValues()[opRes.getResultNumber()];
67-
return findDefiningMakeTensorPtrOp(val);
68-
}
64+
if (auto loopOp = dyn_cast<LoopLikeOpInterface>(defOp))
65+
return findDefiningMakeTensorPtrOp(
66+
loopOp.getYieldedValues()[opRes.getResultNumber()]);
6967
if (auto ifOp = dyn_cast<scf::IfOp>(defOp)) {
7068
// Give up if the 2 possible definitions aren't the same.
7169
Region &thenRgn = ifOp.getThenRegion();
@@ -78,10 +76,10 @@ std::optional<tt::MakeTensorPtrOp> findDefiningMakeTensorPtrOp(Value val) {
7876
cast<scf::YieldOp>(elseRgn.getBlocks().front().getTerminator());
7977
Value thenVal = thenYieldOp->getOperand(opRes.getResultNumber()),
8078
elseVal = elseYieldOp->getOperand(opRes.getResultNumber());
81-
std::optional<tt::MakeTensorPtrOp> thenDef =
82-
findDefiningMakeTensorPtrOp(thenVal);
83-
std::optional<tt::MakeTensorPtrOp> elseDef =
84-
findDefiningMakeTensorPtrOp(elseVal);
79+
std::optional<tt::MakeTensorPtrOp> thenDef = findDefiningMakeTensorPtrOp(
80+
thenVal),
81+
elseDef = findDefiningMakeTensorPtrOp(
82+
elseVal);
8583
if (!thenDef || !elseDef || *thenDef != *elseDef)
8684
return std::nullopt;
8785
return thenDef;
@@ -90,10 +88,10 @@ std::optional<tt::MakeTensorPtrOp> findDefiningMakeTensorPtrOp(Value val) {
9088
// Give up if the 2 possible definitions aren't the same.
9189
Value trueVal = selectOp.getTrueValue(),
9290
falseVal = selectOp.getFalseValue();
93-
std::optional<tt::MakeTensorPtrOp> trueDef =
94-
findDefiningMakeTensorPtrOp(trueVal);
95-
std::optional<tt::MakeTensorPtrOp> falseDef =
96-
findDefiningMakeTensorPtrOp(falseVal);
91+
std::optional<tt::MakeTensorPtrOp> trueDef = findDefiningMakeTensorPtrOp(
92+
trueVal),
93+
falseDef = findDefiningMakeTensorPtrOp(
94+
falseVal);
9795
if (!trueDef || !falseDef || *trueDef != *falseDef)
9896
return std::nullopt;
9997
return trueDef;
@@ -143,8 +141,8 @@ Value getFinalValue(Value value) {
143141
assert(value && "Expecting a valid value");
144142
Operation *defOp = value.getDefiningOp();
145143
if (!defOp) {
146-
// look init values outside the loop
147-
BlockArgument blockArg = cast<BlockArgument>(value);
144+
// Look up init values outside the loop.
145+
auto blockArg = cast<BlockArgument>(value);
148146
Operation *parentOp = blockArg.getOwner()->getParentOp();
149147
if (scf::ForOp forOp = dyn_cast<scf::ForOp>(parentOp)) {
150148
if (blockArg == forOp.getInductionVar())

0 commit comments

Comments
 (0)