Skip to content

Commit 65d9862

Browse files
authored
[BACKEND] Preserve tt attrs in AddPtr combine and canonicalize (#7113)
Combine and canonicalize currently remove the attributes that relate to alignment, so we don't get vectorization of stores in some cases; this change allows them to keep those attributes.
1 parent 116de33 commit 65d9862

File tree

7 files changed

+78
-3
lines changed

7 files changed

+78
-3
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#ifndef TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_
2+
#define TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_
3+
4+
#include "mlir/Support/LLVM.h"
5+
#include "triton/Dialect/Triton/IR/Dialect.h"
6+
7+
namespace mlir::triton {
8+
9+
// Returns the discardable attributes of `op` whose names are listed in
10+
// `allowList`, in allowList order. Names that have no matching discardable
// attribute on `op` are skipped; `op` itself is not modified.
11+
[[nodiscard]] SmallVector<NamedAttribute>
12+
filterDiscardableAttrs(Operation *op, ArrayRef<StringRef> allowList);
13+
14+
} // namespace mlir::triton
15+
#endif // TRITON_DIALECT_TRITON_IR_DISCARDABLE_ATTRIBUTES_H_

lib/Dialect/Triton/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_public_tablegen_target(TritonCanonicalizeIncGen)
44

55
add_triton_library(TritonIR
66
Dialect.cpp
7+
DiscardableAttributes.cpp
78
Ops.cpp
89
Traits.cpp
910
Types.cpp
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "mlir/Support/LLVM.h"
2+
#include "triton/Dialect/Triton/IR/Dialect.h"
3+
4+
namespace mlir::triton {
5+
6+
// Collects, in allowList order, every discardable attribute of `op` whose
// name appears in `allowList`. Names with no matching attribute on `op`
// contribute nothing to the result.
// NOTE(review): the includes visible above this definition do not name
// DiscardableAttributes.h, so unless it is pulled in transitively this
// definition is not checked against the header's declaration -- confirm.
SmallVector<NamedAttribute>
7+
filterDiscardableAttrs(Operation *op, ArrayRef<StringRef> allowList) {
8+
SmallVector<NamedAttribute> propagatedAttrs;
9+
for (auto attrName : allowList) {
10+
Attribute attr = op->getDiscardableAttr(attrName);
11+
// A null Attribute is treated as "op has no attribute with this name",
// so only names that are actually present are recorded.
if (attr)
12+
propagatedAttrs.emplace_back(attrName, attr);
13+
}
14+
return propagatedAttrs;
15+
}
16+
17+
} // namespace mlir::triton

lib/Dialect/Triton/Transforms/Combine.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Support/LogicalResult.h"
77
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
88
#include "triton/Dialect/Triton/IR/Dialect.h"
9+
#include "triton/Dialect/Triton/IR/DiscardableAttributes.h"
910
#include "triton/Dialect/Triton/Transforms/Passes.h"
1011

1112
namespace mlir::triton {

lib/Dialect/Triton/Transforms/Combine.td

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,15 @@ def CombineDotAddFRevPattern : Pat<
3939
// Note: leave (sub %c0, %c0) canceling to ArithDialect
4040
// (ref: ArithCanonicalization.td)
4141
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
42+
43+
// Native-code action taking two values: $0 is a result of the source op and
// $1 a result of the destination op. It sets the destination op's
// discardable attributes to the source op's tt.divisibility, tt.contiguity
// and tt.constancy attributes (whichever of them are present), as selected
// by triton::filterDiscardableAttrs.
def CopyDiscardableAttrs: NativeCodeCallVoid<
44+
"$1.getOwner()->setDiscardableAttrs(triton::filterDiscardableAttrs($0.getOwner(), "
45+
"{\"tt.divisibility\", \"tt.contiguity\", \"tt.constancy\"}))">;
46+
4247
// Rewrites addptr(addptr($ptr, $idx0), $idx1) into
// addptr($ptr, addi($idx0, $idx1)) when isAddPtrOffsetCombinable($idx0, $idx1)
// holds. As of this change the matched outer addptr is bound to $src and the
// replacement addptr to $dest, and the supplemental CopyDiscardableAttrs
// action carries the allow-listed tt.* attributes from $src over to $dest.
def CombineAddPtrPattern : Pat<
43-
(TT_AddPtrOp (TT_AddPtrOp $ptr, $idx0), $idx1),
44-
(TT_AddPtrOp $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
45-
[(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)]>;
48+
(TT_AddPtrOp:$src (TT_AddPtrOp $ptr, $idx0), $idx1),
49+
(TT_AddPtrOp:$dest $ptr, (Arith_AddIOp $idx0, $idx1, DefOverflow)),
50+
[(Constraint<CPred<"isAddPtrOffsetCombinable($0, $1)">> $idx0, $idx1)],
51+
[(CopyDiscardableAttrs $src, $dest)]>;
4652

4753
#endif

test/Triton/combine.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,29 @@ tt.func @test_combine_addptr_pattern(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f3
8484
tt.return %ptr1 : tensor<8x!tt.ptr<f32>>
8585
}
8686

87+
// CHECK-LABEL: @test_combine_addptr_pattern_discardableattrs
88+
tt.func @test_combine_addptr_pattern_discardableattrs(%base: !tt.ptr<f32>) -> !tt.ptr<f32> {
// Two chained scalar addptr ops with constant offsets 8 and 4 are expected to
// fold into one addptr by 12, and the tt.divisibility / tt.contiguity /
// tt.constancy attributes written on the outer addptr are expected to appear
// on the combined op.
89+
%off0 = arith.constant 8 : i32
90+
%off1 = arith.constant 4 : i32
91+
// CHECK-NEXT: %[[cst:.*]] = arith.constant 12 : i32
92+
// CHECK-NEXT: %0 = tt.addptr %{{.*}}, %[[cst]] {tt.constancy = 8 : i32, tt.contiguity = 512 : i32, tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
93+
%ptr0 = tt.addptr %base, %off0 : !tt.ptr<f32>, i32
94+
%ptr1 = tt.addptr %ptr0, %off1 {tt.divisibility = 16 : i32, tt.constancy = 8 : i32, tt.contiguity = 512 : i32} : !tt.ptr<f32>, i32
95+
96+
tt.return %ptr1 : !tt.ptr<f32>
97+
}
98+
99+
// CHECK-LABEL: @test_combine_addptr_pattern_discardableattrs_disallowed
100+
tt.func @test_combine_addptr_pattern_discardableattrs_disallowed(%base: !tt.ptr<f32>) -> !tt.ptr<f32> {
// The outer addptr carries tt.divisibility plus tt.disallowed, a name that is
// not among the attributes the combine pattern propagates; the combined
// addptr is expected to keep only tt.divisibility.
101+
%off0 = arith.constant 8 : i32
102+
%off1 = arith.constant 4 : i32
103+
// CHECK-NEXT: %[[cst:.*]] = arith.constant 12 : i32
104+
// CHECK-NEXT: %0 = tt.addptr %{{.*}}, %[[cst]] {tt.divisibility = 16 : i32} : !tt.ptr<f32>, i32
105+
%ptr0 = tt.addptr %base, %off0 : !tt.ptr<f32>, i32
106+
%ptr1 = tt.addptr %ptr0, %off1 {tt.divisibility = 16 : i32, tt.disallowed = 8 : i32} : !tt.ptr<f32>, i32
107+
108+
tt.return %ptr1 : !tt.ptr<f32>
109+
}
87110
// CHECK-LABEL: @test_combine_addptr_pattern_i64
88111
tt.func @test_combine_addptr_pattern_i64(%base: !tt.ptr<f32>) -> tensor<8x!tt.ptr<f32>> {
89112
%off0 = arith.constant 10 : i64

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Transforms/DialectConversion.h"
1818
#include "triton/Analysis/Utility.h"
1919
#include "triton/Dialect/Triton/IR/Dialect.h"
20+
#include "triton/Dialect/Triton/IR/DiscardableAttributes.h"
2021
#include "triton/Dialect/Triton/IR/Types.h"
2122
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
2223
#include "llvm/ADT/STLExtras.h"
@@ -561,12 +562,20 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
561562
RewriterBase::InsertionGuard guard(rewriter);
562563
rewriter.setInsertionPoint(addPtrOp);
563564

565+
// Query all discardable attributes that we want to preserve
566+
std::array<StringRef, 3> propagateList{"tt.divisibility", "tt.contiguity",
567+
"tt.constancy"};
568+
SmallVector<NamedAttribute> propagatedAttrs =
569+
tt::filterDiscardableAttrs(addPtrOp.getOperation(), propagateList);
570+
564571
// If it is a scalar pointer update, simply bump the base pointer
565572
if (llvm::isa<tt::PointerType>(addPtrOp.getPtr().getType())) {
566573
assert(llvm::isa<IntegerType>(origOffset.getType()) &&
567574
"expected offset to be integer type");
568575
auto newAddPtrOp = rewriter.create<tt::AddPtrOp>(
569576
curLoc, fatPtrBase.getType(), fatPtrBase, origOffset);
577+
newAddPtrOp->setDiscardableAttrs(propagatedAttrs);
578+
570579
rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}});
571580
fatPtrs[{newAddPtrOp, fatPtrOffset}] =
572581
fatPtrs.at({fatPtrBase, fatPtrOffset});
@@ -581,6 +590,8 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
581590
maybeGetOrCreateScalarConstant(rewriter, curLoc, origOffset)) {
582591
tt::AddPtrOp newAddPtrOp = rewriter.create<tt::AddPtrOp>(
583592
curLoc, fatPtrBase.getType(), fatPtrBase, *scalarConst);
593+
newAddPtrOp->setDiscardableAttrs(propagatedAttrs);
594+
584595
rewriter.replaceOpWithMultiple(addPtrOp, {{newAddPtrOp, fatPtrOffset}});
585596
// If we are updating the tensor pointer with a constant value, we can
586597
// propagate the attributes of the tensor pointer to the fat pointer.
@@ -596,6 +607,7 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
596607

597608
auto newAddPtrOp = rewriter.create<tt::AddPtrOp>(
598609
curLoc, fatPtrBase.getType(), fatPtrBase, uniformOffset);
610+
newAddPtrOp->setDiscardableAttrs(propagatedAttrs);
599611

600612
// Vector offset update (if any): bump the tensor offset
601613
bool canNarrow = fatPtrs.at({fatPtrBase, fatPtrOffset}).canNarrow;

0 commit comments

Comments
 (0)