
Commit 737e7b3

[triton-raise-block-ptr]: Fix lowering of tt.addptr where the ptr operand is yielded by a previous tt.advance operation (#3296)

Fixes issue #3295.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent c74b88e commit 737e7b3
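
At its core, the change teaches rewriteAddPtrOp to handle a tt.addptr whose pointer operand has already been rewritten into a tt.advance: the pass walks back through the chain of tt.advance operations to the tt.make_tensor_ptr that created the block pointer, then emits a new tt.advance on top of the previous one (see the C++ diff below). A minimal, self-contained sketch of that walk, for illustration only; it assumes the Triton dialect headers already used by the pass, and the helper name is invented here:

#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

namespace tt = mlir::triton;

// Follow a chain of tt.advance ops from `ptr` back to the tt.make_tensor_ptr
// that created the block pointer; returns a null op if the chain is not
// rooted at a tt.make_tensor_ptr.
static tt::MakeTensorPtrOp findRootMakeTensorPtr(mlir::Value ptr) {
  while (auto advanceOp = ptr.getDefiningOp<tt::AdvanceOp>())
    ptr = advanceOp.getPtr();
  return ptr.getDefiningOp<tt::MakeTensorPtrOp>();
}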

File tree

2 files changed: +112, -47 lines
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s

module {
  tt.func @kernel(
    %arg0 : !tt.ptr<bf16>,
    %arg1 : i32
  )
  {
    %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    // offset = 0, size = 4, stride = 1
    %1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
    // offset = [0,0], size = [4,1], stride = [1,0]
    %2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32>
    // offset = [0,0], size = [4,256], stride = [1,0]
    %arg1splat = tt.splat %arg1 : i32 -> tensor<4x256xi32>
    %offset3 = arith.addi %2, %arg1splat : tensor<4x256xi32>
    // offset = [%arg1,0], size = [4,256], stride = [1,0]
    %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
    // offset = 0, size = 256, stride = 1
    %4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
    // offset = [0,0], size = [1,256], stride = [0,1]
    %5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32>
    // offset = [0,0], size = [4,256], stride = [0,1]
    %6 = arith.constant 5 : i32
    %splat6 = tt.splat %6 : i32 -> tensor<4x256xi32>
    %scale5 = arith.muli %5, %splat6 : tensor<4x256xi32>
    // offset = [0,0], size = [4,256], stride = [0,5]
    %7 = arith.addi %offset3, %scale5 : tensor<4x256xi32>
    // offset = [%arg1, 0], size = [4, 256], stride = [1, 5]
    %8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
    %9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
    // source: %arg0, offset = [%arg1, 0], size = [4, 256], stride = [1, 5]
    %10 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x256x!tt.ptr<bf16>>
    %12 = tt.addptr %9, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
    // source: %arg0, offset = [%arg1+%arg1, 0], size = [4, 256], stride = [2, 10]
    %13 = tt.load %12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<4x256x!tt.ptr<bf16>>
    %14 = arith.addf %10, %13 : tensor<4x256xbf16>
    %16 = tt.addptr %12, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
    // source: %arg0, offset = [%arg1+%arg1+%arg1, 0], size = [4, 256], stride = [3, 15]
    tt.store %16, %14 : tensor<4x256x!tt.ptr<bf16>>
    tt.return
  }
}

// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: i32) {
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64
// CHECK: [[VAR_0_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_1_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_1_:%.+]] = tt.load [[VAR_0_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_2_:%.+]] = tt.advance [[VAR_0_]], {{\[}}[[PARAM_1_]], [[CST_0_i32]]] : <tensor<4x256xbf16>>
// CHECK: [[VAR_3_:%.+]] = tt.load [[VAR_2_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK-DAG: [[VAR_4_:%.+]] = arith.addf [[VAR_1_]], [[VAR_3_]] : tensor<4x256xbf16>
// CHECK-DAG: [[VAR_5_:%.+]] = tt.advance [[VAR_2_]], {{\[}}[[PARAM_1_]], [[CST_0_i32]]] : <tensor<4x256xbf16>>
// CHECK: tt.store [[VAR_5_]], [[VAR_4_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: tt.return
// CHECK: }
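
The interesting contrast in this test is between the input, which addresses memory through a tensor of scalar pointers (tensor<4x256x!tt.ptr<bf16>>), and the expected output, which uses a single block pointer (!tt.ptr<tensor<4x256xbf16>>) produced by tt.make_tensor_ptr and advanced twice. A small sketch of how that distinction can be checked, assuming only the triton::isTensorPointerType helper that the pass itself relies on in the change below; the wrapper name is invented for illustration:

#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Types.h"

// True for block pointers such as !tt.ptr<tensor<4x256xbf16>> (the type
// produced by tt.make_tensor_ptr / tt.advance); false for ordinary pointer
// tensors such as tensor<4x256x!tt.ptr<bf16>>.
static bool isBlockPointer(mlir::Value v) {
  return mlir::triton::isTensorPointerType(v.getType());
}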

third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp

Lines changed: 54 additions & 47 deletions
@@ -12,6 +12,7 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Types.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
@@ -373,6 +374,8 @@ struct PtrState {

   Value createTTAdvanceOp(Value ptr, tt::MakeTensorPtrOp makeTPtrOp,
                           OpBuilder &builder, Location loc) const {
+    assert(triton::isTensorPointerType(ptr.getType()) &&
+           "Expecting a block ptr");
     SmallVector<Value> newOffsets;
     for (const auto &[offset, stride] :
          llvm::zip(offsets, makeTPtrOp.getStrides()))
@@ -676,44 +679,13 @@ struct TritonRaiseBlockPointer
   }

   LogicalResult rewriteAddPtrOp(tt::AddPtrOp op) {
-    LLVM_DEBUG(llvm::dbgs() << "Rewriting: " << *op << "\n");
-
     OpBuilder builder(op);
     Location loc = op.getLoc();
     Value ptr = op.getPtr();

-    auto fillOffsets = [&](Value offset, unsigned rank,
-                           SmallVector<Value> &offsets) {
-      switch (rank) {
-      case 1:
-        offsets.push_back(offset);
-        break;
-      case 2:
-        offsets.push_back(
-            findOrCreateConstant(loc, 0, offsetBitwidth, builder));
-        offsets.push_back(offset);
-        break;
-      default:
-        llvm_unreachable("unexpected rank");
-      }
-    };
-
-    auto getConstantValue = [](arith::ConstantOp cstOp) {
-      TypedAttr cstVal = cstOp.getValue();
-      APInt val;
-      if (auto attr = dyn_cast<DenseIntElementsAttr>(cstVal))
-        val = attr.getSplatValue<APInt>();
-      else if (auto attr = dyn_cast<IntegerAttr>(cstVal))
-        val = attr.getValue();
-      else
-        assert(false && "unexpected constant type");
-
-      return val;
-    };
-
-    // If the ptr has already been mapped (i.e. rewritten into a block
-    // pointer), rewrite the AddPtrOp using and AdvanceOp.
+    // Case 1: the ptr has already been mapped.
     if (Value mappedV = ptrMap.lookupOrNull(ptr)) {
+      // Case 1a: the ptr has been mapped to a make_tensor_ptr operation.
       if (auto makeTPtrOp = mappedV.getDefiningOp<tt::MakeTensorPtrOp>()) {
         PtrState state;
         if (failed(visitOperand(op.getOffset(), state, loc, builder)))
@@ -726,20 +698,60 @@ struct TritonRaiseBlockPointer
         cleanUp.insert(op);
         ptrMap.map(op.getResult(), advanceOp);

-        LLVM_DEBUG(llvm::dbgs()
-                   << "Rewrote:\n\t" << op << "\nto:\n\t" << advanceOp << "\n");
+        LLVM_DEBUG({
+          auto modOp =
+              builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
+          llvm::dbgs() << "Module:\n" << modOp << "\n";
+          llvm::dbgs() << "Rewrote:\n\t" << op << "\nto:\n\t" << advanceOp
+                       << "\n";
+        });
+
+        return success();
+      }
+
+      // Case 1b: the ptr has been mapped to a tt.advance operation.
+      if (auto advanceOp = mappedV.getDefiningOp<tt::AdvanceOp>()) {
+        PtrState state;
+        if (failed(visitOperand(op.getOffset(), state, loc, builder)))
+          return failure();
+
+        // Skip through a chain of tt.advance operations...
+        Value ptr = advanceOp.getPtr();
+        while (auto advanceOp = ptr.getDefiningOp<tt::AdvanceOp>())
+          ptr = advanceOp.getPtr();
+
+        // ... until we find the make_tensor_ptr operation defining the block
+        // ptr feeding the first tt.advance operation.
+        auto makeTPtrOp = ptr.getDefiningOp<tt::MakeTensorPtrOp>();
+        assert(makeTPtrOp && "Expected a MakeTensorPtrOp");
+
+        Value newAdvanceOp = state.createTTAdvanceOp(advanceOp.getResult(),
+                                                     makeTPtrOp, builder, loc);
+
+        cleanUp.insert(op);
+        ptrMap.map(op.getResult(), newAdvanceOp);
+
+        LLVM_DEBUG({
+          llvm::dbgs() << "Rewrote:\n\t" << op << "\nto:\n\t" << newAdvanceOp
+                       << "\n";
+          auto modOp =
+              builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
+          llvm::dbgs() << "Module:\n" << modOp << "\n";
+        });
+
         return success();
-      } else {
-        llvm_unreachable("Did not find tt::MakeTensorPtrOp");
       }
+
+      llvm_unreachable("Unexpected mappedV defining operation");
     }

+    // Case 2: the ptr has not previously been mapped.
     // If the addptr operation increments a scalar pointer, give up.
     Value result = op.getResult();
     if (!isa<RankedTensorType>(result.getType()))
       return failure();

-    // Otherwise, rewrite the AddPtrOp using PtrState.
+    // Otherwise, rewrite the AddPtrOp.
     PtrState state;
     if (failed(visitOperandAddptr(op, state, loc, builder)))
       return failure();
@@ -750,16 +762,11 @@ struct TritonRaiseBlockPointer
     Value makePtrOp = state.createTTMakeTensorPtrOp(builder, loc);
     knownPtrs[makePtrOp] = std::move(state);

-    ptrMap.map(result, makePtrOp);
-
-    LLVM_DEBUG(llvm::dbgs()
-               << "Rewrote:\n\t" << op << "\nto:\n\t" << makePtrOp << "\n");
-
-    // AddPtrOps that have been rewritten and no longer used in the code must
-    // be removed in the pass to avoid type matching issue.
     cleanUp.insert(op);
+    ptrMap.map(result, makePtrOp);

     LLVM_DEBUG({
+      llvm::dbgs() << "Rewrote:\n\t" << op << "\nto:\n\t" << makePtrOp << "\n";
       auto modOp =
           builder.getBlock()->getParentOp()->getParentOfType<ModuleOp>();
       llvm::dbgs() << "Module:\n" << modOp << "\n";
@@ -915,8 +922,8 @@ struct TritonRaiseBlockPointer
     }

     // This operand must be an iter-arg of an inner-loop in a multiple-level
-    // nested loop, which means its PtrState must have already been populated
-    // during rewriteForOp of the parent loop.
+    // nested loop, which means its PtrState must have already been
+    // populated during rewriteForOp of the parent loop.
     state = knownPtrs[operand];
     return success();
   }
