
Commit 6c3d943

[AMD] Fix pointer canonicalizer when propagating discardable attrs (#7242)
Propagating divisibility and other discardable attrs needs to make sure that the ranks of the source and destination instructions match. When the ranks don't match, it isn't possible to trivially propagate those attributes: a per-dimension attribute such as tt.divisibility = dense<[1, 16]> on a rank-2 tensor has no direct meaning on the scalar addptr produced by the rewrite. Allow propagation of attributes for matching ranks, and also between rank 1 and scalar.
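For illustration only, the rank rule can be restated as a minimal standalone C++ sketch (effectiveRank and canPropagateAttrs are hypothetical helper names, not code from this patch): a scalar pointer counts as rank 1, so attributes survive rewrites between same-rank values and between rank-1 tensors and scalars, and are dropped otherwise.

#include <cassert>
#include <optional>

// Hypothetical sketch: std::nullopt models a scalar pointer, which the
// pass treats as rank 1 when deciding whether attrs may be propagated.
static int effectiveRank(std::optional<int> tensorRank) {
  return tensorRank ? *tensorRank : 1;
}

static bool canPropagateAttrs(std::optional<int> srcRank,
                              std::optional<int> dstRank) {
  return effectiveRank(srcRank) == effectiveRank(dstRank);
}

int main() {
  assert(canPropagateAttrs(1, 1));             // same rank: propagate
  assert(canPropagateAttrs(1, std::nullopt));  // rank-1 tensor <-> scalar: propagate
  assert(!canPropagateAttrs(2, std::nullopt)); // rank-2 tensor -> scalar: drop
  return 0;
}

In the actual pass the same predicate is computed from RankedTensorType, as the diff below shows.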
1 parent: cb30573

File tree

2 files changed: +62 −3 lines

test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 51 additions & 0 deletions
@@ -1516,3 +1516,54 @@ module attributes {"ttg.num-warps" = 4 : i32} {
     tt.return %7 : tensor<1024xf32>
   }
 }
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func @propagate_divisibility(%arg0: !tt.ptr<f32>) -> tensor<1024xf32> {
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.splat %1 : i32 -> tensor<1024xi32>
+    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+    %4 = tt.addptr %3, %2 {tt.divisibility = 16 : i32, misc.misc = 3 : i32} : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
+    %5 = tt.load %4 : tensor<1024x!tt.ptr<f32>>
+    tt.return %5 : tensor<1024xf32>
+  }
+}
+
+// CHECK-LABEL: tt.func @propagate_divisibility(
+// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
+// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] {tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
+// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
+// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x!tt.ptr<f32>>
+// CHECK: tt.return %[[VAL_6]] : tensor<1024xf32>
+// CHECK: }
+
+// -----
+
+module attributes {"ttg.num-warps" = 4 : i32} {
+  tt.func @divisiblity_changeing_dims(%arg0: !tt.ptr<f32>) -> tensor<1024x32xf32> {
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    %2 = tt.splat %1 : i32 -> tensor<1024x32xi32>
+    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x32x!tt.ptr<f32>>
+    %4 = tt.addptr %3, %2 {tt.divisibility = dense<[1, 16]> : tensor<2xi32>} : tensor<1024x32x!tt.ptr<f32>>, tensor<1024x32xi32>
+    %5 = tt.load %4 : tensor<1024x32x!tt.ptr<f32>>
+    tt.return %5 : tensor<1024x32xf32>
+  }
+}
+
+// CHECK-LABEL: tt.func @divisiblity_changeing_dims(
+// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<1024x32xf32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1024 : i32
+// CHECK: %[[VAL_2:.*]] = tt.get_program_id x : i32
+// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_2]], %[[VAL_1]] : i32
+// CHECK: %[[VAL_4:.*]] = tt.addptr %[[VAL_0]], %[[VAL_3]] : !tt.ptr<f32>, i32
+// CHECK: %[[VAL_5:.*]] = tt.splat %[[VAL_4]] : !tt.ptr<f32> -> tensor<1024x32x!tt.ptr<f32>>
+// CHECK: %[[VAL_6:.*]] = tt.load %[[VAL_5]] : tensor<1024x32x!tt.ptr<f32>>
+// CHECK: tt.return %[[VAL_6]] : tensor<1024x32xf32>
+// CHECK: }

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 11 additions & 3 deletions
@@ -567,14 +567,22 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
                                          "tt.constancy"};
     SmallVector<NamedAttribute> propagatedAttrs =
         tt::filterDiscardableAttrs(addPtrOp.getOperation(), propagateList);
+    auto currPtrTy = llvm::dyn_cast<RankedTensorType>(addPtrOp.getType());
+    int currPtrRank = currPtrTy ? currPtrTy.getRank() : 1;
+    auto doSetDiscardableAttrs = [&](tt::AddPtrOp newAddPtrOp) {
+      auto newPtrTy = llvm::dyn_cast<RankedTensorType>(newAddPtrOp.getType());
+      int newPtrRank = newPtrTy ? newPtrTy.getRank() : 1;
+      if (newPtrRank == currPtrRank)
+        newAddPtrOp->setDiscardableAttrs(propagatedAttrs);
+    };

     // If it is a scalar pointer update, simply bump the base pointer
     if (llvm::isa<tt::PointerType>(addPtrOp.getPtr().getType())) {
       assert(llvm::isa<IntegerType>(origOffset.getType()) &&
              "expected offset to be integer type");
       auto newAddPtrOp = rewriter.create<tt::AddPtrOp>(
           curLoc, fatPtrBase.getType(), fatPtrBase, origOffset);
-      newAddPtrOp->setDiscardableAttrs(propagatedAttrs);
+      doSetDiscardableAttrs(newAddPtrOp);

       rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}});
       fatPtrs[{newAddPtrOp, fatPtrOffset}] =
@@ -590,7 +598,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
                    maybeGetOrCreateScalarConstant(rewriter, curLoc, origOffset)) {
       tt::AddPtrOp newAddPtrOp = rewriter.create<tt::AddPtrOp>(
           curLoc, fatPtrBase.getType(), fatPtrBase, *scalarConst);
-      newAddPtrOp->setDiscardableAttrs(propagatedAttrs);
+      doSetDiscardableAttrs(newAddPtrOp);

       rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}});
       // If we are updating the tensor pointer with a constant value, we can
@@ -607,7 +615,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {

     auto newAddPtrOp = rewriter.create<tt::AddPtrOp>(
         curLoc, fatPtrBase.getType(), fatPtrBase, uniformOffset);
-    newAddPtrOp->setDiscardableAttrs(propagatedAttrs);
+    doSetDiscardableAttrs(newAddPtrOp);

     // Vector offset update (if any): bump the tensor offset
     bool canNarrow = fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow;
