Commit 762a7d1
[AMD][CanonicalizePointers] Propagate the attributes during the rewrites (#4815)
This fixes an issue where the IR was not properly vectorized in Triton (we were relying on a backend pass that was not always able to do the right thing). The general problem was that we were not propagating the attributes of the operation being rewritten. The specific problem was that block argument attributes take the form `tt.property_argi` for the i-th block argument, so some extra work was needed to propagate those correctly. This PR addresses the problem by adding a vector of attributes to the `FatPtr` structure. We do not propagate attributes unconditionally, but only set them where the original IR had them set.
1 parent 6c3e3ae commit 762a7d1
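
To make the `tt.property_argi` convention concrete: an attribute named `tt.divisibility_arg1` on an operation attaches the `tt.divisibility` hint to that operation's block argument number 1. Below is a minimal, hypothetical C++ sketch of the suffix matching (the helper `stripArgSuffix` is illustrative only; the real logic lives in `collectFatPointerAttributes` in the diff below):

// Hypothetical sketch of the `tt.property_argi` naming convention: an
// "_arg<i>" suffix names the block argument an attribute applies to, and
// stripping it recovers the base attribute name.
#include <optional>
#include <string>

#include "llvm/ADT/StringRef.h"

// stripArgSuffix("tt.divisibility_arg1", 1) == "tt.divisibility"
// stripArgSuffix("tt.divisibility_arg1", 2) == std::nullopt
std::optional<std::string> stripArgSuffix(llvm::StringRef attrName,
                                          unsigned argNumber) {
  std::string suffix = "_arg" + std::to_string(argNumber);
  if (!attrName.ends_with(suffix))
    return std::nullopt;
  return attrName.drop_back(suffix.size()).str();
}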

File tree: 3 files changed, +116 −11 lines changed

lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp

Lines changed: 1 addition & 0 deletions
@@ -461,6 +461,7 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
   auto newForOp =
       rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                   forOp.getStep(), newLoopArg);
+  newForOp->setAttrs(forOp->getAttrs());
   // When there are no iter args, the loop body terminator will be created.
   // Since we always create it below, remove the terminator if it was created.
   if (!newForOp.getBody()->empty())
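
This one-line change is load-bearing: `rewriter.create` builds the new loop with an empty attribute dictionary, so any discardable attributes on the original `scf.for` (including the `tt.divisibility_argi` hints this PR starts consuming) would silently disappear when the pipeliner rebuilds the kernel loop. A minimal sketch of the pattern, assuming a hypothetical helper `recreateLoop` in place of the pipeliner's actual loop construction:

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

// Hypothetical helper: when an op is recreated rather than updated in
// place, its discardable attributes must be copied over explicitly.
static mlir::scf::ForOp recreateLoop(mlir::RewriterBase &rewriter,
                                     mlir::scf::ForOp forOp,
                                     mlir::Value newUpperBound,
                                     mlir::ValueRange newIterArgs) {
  auto newForOp = rewriter.create<mlir::scf::ForOp>(
      forOp.getLoc(), forOp.getLowerBound(), newUpperBound, forOp.getStep(),
      newIterArgs);
  // The new op starts with no attributes; copy the originals (e.g. the
  // tt.divisibility_arg<i> hints) so later passes can still see them.
  newForOp->setAttrs(forOp->getAttrs());
  return newForOp;
}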

test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 31 additions & 0 deletions
@@ -546,3 +546,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: tt.func @forOpWithHints
+  tt.func @forOpWithHints(%arg0: !tt.ptr<f32>, %init : tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c128 = arith.constant 128 : index
+    %0 = tt.get_program_id x : i32
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    %3 = tt.splat %0 : i32 -> tensor<1024xi32, #blocked>
+    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
+    %5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
+    %6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    %52:2 = scf.for %arg9 = %c0 to %c128 step %c1 iter_args(%arg1 = %6, %arg2 = %init) -> (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) {
+      %9 = tt.load %arg1 : tensor<1024x!tt.ptr<f32>, #blocked>
+      // CHECK: tt.addptr {{.*}}, {{.*}} {tt.divisibility = dense<16> : tensor<1xi32>}
+      %11 = tt.addptr %arg1, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+      %12 = tt.addptr %11, %3 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+      %10 = arith.addf %9, %arg2 : tensor<1024xf32, #blocked>
+      scf.yield %12, %10 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>
+    } {"tt.divisibility_arg1" = dense<[16]> : tensor<1xi32>}
+    // CHECK: tt.divisibility_arg1
+    // CHECK-SAME: tt.divisibility_arg4
+    %8 = tt.addptr %52#0, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
+    %11 = tt.load %8 : tensor<1024x!tt.ptr<f32>, #blocked>
+    tt.return %11 : tensor<1024xf32, #blocked>
+  }
+}

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 84 additions & 11 deletions
@@ -1,11 +1,14 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Block.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
@@ -15,6 +18,8 @@
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
@@ -85,6 +90,8 @@ class PointerCanonicalizer {
   Value offset;
   // Flag to express if we can narrow the uses of the offset down to 32 bits
   bool canNarrow = false;
+  // Collection of attributes that need to be applied to the pointer
+  SmallVector<NamedAttribute> attributes;
 
   // Utility copy functions
   FatPtr copy(Value newBasePtr, Value newOffset) {
@@ -96,6 +103,11 @@ class PointerCanonicalizer {
   FatPtr copyWithOffset(Value newBase) {
     return FatPtr{newBase, offset, canNarrow};
   }
+  // Attribute functions
+  void setAttr(NamedAttribute attr) { attributes.push_back(attr); }
+  void setAttrs(ArrayRef<NamedAttribute> attrs) {
+    llvm::append_range(attributes, attrs);
+  }
 };
 
 // Rewrite any operation that needs a pointer
@@ -104,8 +116,15 @@ class PointerCanonicalizer {
   // Start from an argument of a function and propagate its fat pointers
   LogicalResult rewritePointer(Value argPtr);
 
+  // Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is
+  // obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape
+  // and adding the offset to it.
   Value createTensorPointer(FatPtr fatPtr, Location loc);
 
+  // Push the attributes of the given operation `op` to the fat pointer
+  // corresponding to `val`
+  void collectFatPointerAttributes(Operation *op, Value val);
+
   // Rewrite a given function, canonicalizing the different pointer arguments of
   // the region
   LogicalResult rewriteFunction(triton::FuncOp funcOp);
@@ -269,6 +288,46 @@ Value createTensorZero(IRRewriter &rw, Location loc, RankedTensorType type) {
 
 } // namespace
 
+void PointerCanonicalizer::collectFatPointerAttributes(Operation *op,
+                                                       Value val) {
+  auto addBlockArgumentAttr = [&](BlockArgument arg) {
+    // If the value is a block parameter, the operation can specify
+    // an attribute for the given parameter by using `tt.property_argi`
+    // where `argi` refers to the arg number of the given parameter.
+    // So we need to iterate through the property, find the right one
+    // and push the property onto the pointers attributes.
+    llvm::SmallString<8> scratchStr;
+    for (NamedAttribute namedAttr : op->getAttrs()) {
+      scratchStr.clear();
+      llvm::raw_svector_ostream sstream(scratchStr);
+      sstream << "_arg" << arg.getArgNumber();
+      StringRef attrName = namedAttr.getName().getValue();
+      if (attrName.ends_with(scratchStr)) {
+        StringRef newAttrName = attrName.drop_back(scratchStr.size());
+        namedAttr.setName(rewriter.getStringAttr(newAttrName));
+        pointers[val].setAttr(namedAttr);
+        // Propagate the argument to the offset if it is also a block argument
+        if (auto offsetArg = dyn_cast<BlockArgument>(pointers[val].offset)) {
+          scratchStr.clear();
+          sstream << newAttrName << "_arg" << offsetArg.getArgNumber();
+          op->setAttr(scratchStr, namedAttr.getValue());
+        }
+      }
+    }
+  };
+
+  // If it is the i-th block argument, then look if the operation defined some
+  // _argi attribute and add it to the fat pointer attributes
+  if (auto arg = dyn_cast<BlockArgument>(val)) {
+    addBlockArgumentAttr(arg);
+    return;
+  }
+
+  // Otherwise add the attributes of the operation to the fat pointer
+  for (NamedAttribute attr : op->getAttrs())
+    pointers[val].setAttr(attr);
+}
+
 // Offset extraction logic for an addition op:
 // decompose(A+B) = {U(A)+U(B), NU(A)+NU(B)}
 std::pair<Value, Value>
@@ -372,9 +431,6 @@ PointerCanonicalizer::decomposeOffsetFromExpr(Location loc, Value expr,
   return offsets;
 }
 
-// Create a tensor pointer from a fat pointer `fatPtr`. The tensor pointer is
-// obtained by splatting the `fatPtr.basePtr` using the `fatPtr.offset` shape
-// and adding the offset to it.
 Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) {
   Value basePtr = fatPtr.basePtr;
   Value offset = fatPtr.offset;
@@ -390,9 +446,12 @@ Value PointerCanonicalizer::createTensorPointer(FatPtr fatPtr, Location loc) {
   Value tensorPtr =
       rewriter.create<triton::SplatOp>(loc, tensorPtrType, basePtr);
 
-  tensorPtr =
+  auto addPtrOp =
       rewriter.create<triton::AddPtrOp>(loc, tensorPtrType, tensorPtr, offset);
-  return tensorPtr;
+
+  addPtrOp->setAttrs(fatPtr.attributes);
+
+  return addPtrOp.getResult();
 }
 
 // Rewrite a memory operation
@@ -477,6 +536,9 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp,
   newPtr = rewriter.create<triton::AddPtrOp>(curLoc, newPtr.getType(), newPtr,
                                              scalarConst);
   pointers[nextPtr] = fatPtr.copyWithOffset(newPtr);
+  // If we are updating the tensor pointer with a uniform value, we can
+  // propagate the attributes of the tensor pointer to the fat pointer.
+  pointers[nextPtr].setAttrs(fatPtr.attributes);
   opToDelete.insert(addPtrOp);
   return success();
 }
@@ -496,6 +558,7 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp,
   Value fatPtrOffset = fatPtr.offset;
   bool canNarrow = fatPtr.canNarrow;
   Value newOffset = fatPtrOffset;
+  bool propagateAtrs = true;
   if (!isZeroConst(nonUniformOffset)) {
     Type addPtrOffsetType = getElementTypeOrSelf(nonUniformOffset);
     canNarrow = canNarrow && canNarrowOffset(fatPtrOffset, nonUniformOffset);
@@ -507,9 +570,15 @@ LogicalResult PointerCanonicalizer::rewriteAddPtrOp(triton::AddPtrOp addPtrOp,
 
     newOffset =
         rewriter.create<arith::AddIOp>(curLoc, nonUniformOffset, fatPtrOffset);
+    propagateAtrs = false;
   }
   opToDelete.insert(addPtrOp);
   pointers[nextPtr] = FatPtr{newPtr, newOffset, canNarrow};
+
+  // If we are updating the tensor pointer with a uniform value, we can
+  // propagate the attributes of the tensor pointer to the fat pointer.
+  if (propagateAtrs)
+    pointers[nextPtr].setAttrs(fatPtr.attributes);
   return success();
 }
 
@@ -537,9 +606,12 @@ LogicalResult PointerCanonicalizer::rewriteForOp(scf::ForOp forOp,
   // This is making sure we visit the uses within the forOp region
   Value arg = newForOp.getTiedLoopRegionIterArg(forOperand);
   size_t numIterArgs = newForOp.getNumRegionIterArgs();
-  pointers[arg] =
-      FatPtr{newForOp.getRegionIterArg(numIterArgs - 2),
-             newForOp.getRegionIterArg(numIterArgs - 1), fatPtr.canNarrow};
+  pointers[arg] = fatPtr.copy(newForOp.getRegionIterArg(numIterArgs - 2),
+                              newForOp.getRegionIterArg(numIterArgs - 1));
+
+  // Collect attributes before continuing the visit
+  collectFatPointerAttributes(newForOp, arg);
+
   for (OpOperand &use : arg.getUses())
     queue.push_back(&use);
 
@@ -548,7 +620,6 @@
   size_t numResults = newForOp->getNumResults();
   pointers[nextPtr] = fatPtr.copy(newForOp->getResult(numResults - 2),
                                   newForOp.getResult(numResults - 1));
-
  opToDelete.insert(forOp);
  return success();
}
@@ -864,11 +935,13 @@ LogicalResult PointerCanonicalizer::rewritePointer(Value argPtr) {
       res = materializeFatPointer(op, curLoc, curOperand->get());
     });
 
-    // Keep propagating the fat pointer down the IR
-    if (nextPtr)
+    // Collect the attributes and keep propagating the fat pointer down the IR
+    if (nextPtr) {
+      collectFatPointerAttributes(curOp, nextPtr);
       for (OpOperand &use : nextPtr.getUses())
         if (!opToDelete.contains(use.getOwner()))
           queue.push_back(&use);
+    }
   }
   return success();
 }
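
Note the `propagateAtrs` flag above: the fat pointer's attributes are carried over only when the update is uniform or the non-uniform part is a known zero. One plausible reading of this rule is that folding in a genuinely non-uniform offset can invalidate per-pointer hints such as divisibility, so they are dropped in that case. A hypothetical standalone sketch of the rule (names are illustrative, not from the pass):

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"

// Stand-in for the pass's FatPtr: a pointer plus accumulated hints.
struct FatPtrSketch {
  llvm::SmallVector<mlir::NamedAttribute> attributes;
};

// Mirrors the `propagateAtrs` flag: copy the hints to the rewritten
// pointer only while the offset update is uniform.
void propagateHints(FatPtrSketch &next, const FatPtrSketch &cur,
                    bool offsetIsNonUniform) {
  if (!offsetIsNonUniform)
    next.attributes.append(cur.attributes.begin(), cur.attributes.end());
}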
