Skip to content

Commit f5a5b0f

Browse files
authored
[AMD] Disable canonicalize-pointer if ub.poison (#7092)
Temp fix for issue with loop fusion + pointer canonicalization until all offsets can be assumed to be i32.
1 parent c6d9624 commit f5a5b0f

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1468,3 +1468,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
14681468
// CHECK: %[[VAL_21:.*]] = tt.load %[[VAL_20]], %[[VAL_11]] : tensor<256x!tt.ptr<i32>>
14691469
// CHECK: tt.return %[[VAL_21]] : tensor<256xi32>
14701470
// CHECK: }
1471+
1472+
// -----
1473+
1474+
module attributes {"ttg.num-warps" = 4 : i32} {
1475+
tt.func @ifOpPoison(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>, %arg2: i1) -> tensor<1024xf32> {
1476+
%c1024_i32 = arith.constant 1024 : i32
1477+
// expected-remark@+1 {{skipping canonicalize-pointers due to ub.poison}}
1478+
%poison = ub.poison : tensor<1024x!tt.ptr<f32>>
1479+
%0 = tt.get_program_id x : i32
1480+
%1 = arith.muli %0, %c1024_i32 : i32
1481+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
1482+
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
1483+
%4 = arith.addi %3, %2 : tensor<1024xi32>
1484+
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
1485+
%6 = scf.if %arg2 -> (tensor<1024x!tt.ptr<f32>>) {
1486+
%8 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
1487+
scf.yield %8 : tensor<1024x!tt.ptr<f32>>
1488+
} else {
1489+
scf.yield %poison : tensor<1024x!tt.ptr<f32>>
1490+
}
1491+
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>>
1492+
tt.return %7 : tensor<1024xf32>
1493+
}
1494+
}

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
44
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
55
#include "mlir/Dialect/SCF/IR/SCF.h"
6+
#include "mlir/Dialect/UB/IR/UBOps.h"
67
#include "mlir/IR/Attributes.h"
78
#include "mlir/IR/Block.h"
89
#include "mlir/IR/BuiltinAttributes.h"
@@ -634,9 +635,6 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
634635
}
635636
};
636637

637-
using ConversionCallbackFn =
638-
std::function<std::optional<LogicalResult>(Type, SmallVectorImpl<Type> &)>;
639-
640638
/// Rewrite init args and result type and bb args.
641639
class ConvertSCFForOp : public PointerCanonicalizationPattern<scf::ForOp> {
642640
using PointerCanonicalizationPattern::PointerCanonicalizationPattern;
@@ -1464,6 +1462,12 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
14641462
});
14651463

14661464
auto func = getOperation();
1465+
auto walkResult = func->walk<WalkOrder::PreOrder>([](ub::PoisonOp op) {
1466+
op.emitRemark("skipping canonicalize-pointers due to ub.poison");
1467+
return WalkResult::interrupt();
1468+
});
1469+
if (walkResult.wasInterrupted())
1470+
return;
14671471

14681472
FatPointers fatPrs;
14691473
PatternRewriter rewriter(&getContext());

0 commit comments

Comments
 (0)