Skip to content

Commit c2908db

Browse files
committed
Address code review comments
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent edfef77 commit c2908db

File tree

2 files changed

+38
-25
lines changed

2 files changed

+38
-25
lines changed

test/Triton/Intel/FuseReshape/fuse-reshape.mlir

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
66
%c1_i32 = arith.constant 1 : i32
77
%c2_i32 = arith.constant 2 : i32
88
%c1_i64 = arith.constant 1 : i64
9-
%c2_i64 = arith.constant 2 : i64
109
%c4_i64 = arith.constant 4 : i64
10+
%c64_i64 = arith.constant 4 : i64
1111
%c1024_i64 = arith.constant 1024 : i64
1212
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
13-
%0 = tt.make_tensor_ptr %arg1, [%c2_i64, %c1_i64, %c1024_i64], [%c1024_i64, %c4_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
13+
%0 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64, %c1024_i64], [%c1024_i64, %c4_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
1414
%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
1515
%3 = tt.load %0 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x256xbf16>>
1616
%4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
@@ -20,16 +20,17 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
2020
// CHECK-LABEL: fuseLoadWithReshape1
2121
// CHECK-NOT: tt.reshape
2222
// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c4_i64 : i64
23-
// CHECK: [[MUL1:%.*]] = arith.muli %c2_i64, [[DIV]] : i64
24-
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1_i64 : i64
23+
// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
24+
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c4_i64_0 : i64
2525
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
2626
// CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
2727
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
2828
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
29-
// CHECK: [[TRUNC:%.*]] = arith.trunci %c1_i64 : i64 to i32
30-
// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD2]], [[TRUNC]] : i32
29+
// CHECK: [[ADD3:%.*]] = arith.addi %c1_i32, %c32_i32 : i32
30+
// CHECK: [[TRUNC:%.*]] = arith.trunci %c4_i64_0 : i64 to i32
31+
// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD3]], [[TRUNC]] : i32
3132
// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<32x256xbf16>) {
32-
// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
33+
// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x256xbf16>>
3334
// CHECK: scf.yield [[LOAD_B]] : tensor<32x256xbf16>
3435
// CHECK: } else {
3536
// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<32x256xbf16>
@@ -71,7 +72,7 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
7172
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
7273
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
7374
// CHECK: scf.for
74-
// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x32xbf16>>
75+
// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] : !tt.ptr<tensor<256x32xbf16>>
7576
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
7677

7778
// -----
@@ -106,7 +107,7 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
106107
%15 = arith.extsi %M : i32 to i64
107108
%16 = arith.extsi %K : i32 to i64
108109
%17 = arith.extsi %stride_am : i32 to i64
109-
%18 = tt.make_tensor_ptr %a_ptr, [%c1_i64, %15, %16], [%c1_i64, %17, %c1_i64], [%c0_i32, %14, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x256x32xf32>>
110+
%18 = tt.make_tensor_ptr %a_ptr, [%c1_i64, %15, %16], [%c1_i64, %17, %c1_i64], [%c0_i32, %c128_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x256x32xf32>>
110111
%19 = arith.muli %13, %c128_i32 : i32
111112
%20 = arith.extsi %N : i32 to i64
112113
%21 = arith.extsi %stride_bk : i32 to i64
@@ -134,13 +135,15 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
134135
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %15 : i64
135136
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
136137
// CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
137-
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %14 : i32
138+
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c128_i32 : i32
138139
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
140+
// CHECK: [[CST_256:%.*]] = arith.constant 256 : i32
141+
// CHECK: [[ADD3:%.*]] = arith.addi %c128_i32, [[CST_256]] : i32
139142
// CHECK: [[TRUNC:%.*]] = arith.trunci [[EXT_M]] : i64 to i32
140-
// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD2]], [[TRUNC]] : i32
143+
// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD3]], [[TRUNC]] : i32
141144
// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
142145
// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<256x32xf32>) {
143-
// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
146+
// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<256x32xf32>>
144147
// CHECK: scf.yield [[LOAD_A]] : tensor<256x32xf32>
145148
// CHECK: } else {
146149
// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x32xf32>
@@ -153,7 +156,7 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
153156

154157
// COM: tt.load -> tt.reshape -> tt.dot chain, in 2 loops.
155158
// COM: Where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
156-
tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
159+
tt.func public @fuseLoadWithReshape4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
157160
%c0_i32 = arith.constant 0 : i32
158161
%c1_i32 = arith.constant 1 : i32
159162
%c2_i32 = arith.constant 2 : i32
@@ -185,7 +188,7 @@ tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.p
185188
tt.return
186189

187190
}
188-
// CHECK-LABEL: fuseLoadWithTrans4
191+
// CHECK-LABEL: fuseLoadWithReshape4
189192
// CHECK-NOT: tt.reshape
190193
// CHECK: [[DIV1:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
191194
// CHECK: [[MUL11:%.*]] = arith.muli %c1_i64, [[DIV1]] : i64

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ scf::IfOp createIfBlock(OpBuilder &builder, Location loc, arith::CmpIOp condOp,
6666
return ifOp;
6767
}
6868

69-
scf::IfOp createCheckedLoad(OpBuilder &builder, arith::CmpIOp condOp,
69+
scf::IfOp createCheckedLoad(OpBuilder &builder, arith::CmpIOp cmpOp,
7070
tt::LoadOp loadOp) {
71-
scf::IfOp ifOp = createIfBlock(builder, loadOp.getLoc(), condOp, loadOp);
71+
scf::IfOp ifOp = createIfBlock(builder, loadOp.getLoc(), cmpOp, loadOp);
7272
loadOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) {
7373
if (auto yieldOp = dyn_cast<scf::YieldOp>(operand.getOwner()))
7474
return yieldOp->getParentOp() != ifOp;
@@ -311,7 +311,7 @@ class FuseReshape {
311311
// strides: [50, 5, 1] -> [ 5, 1]
312312
//
313313
// Consider a load offset of [1, 11, 1], this access is clearly
314-
// out-of-bound in dim 1 (11 > 10). However, the new offset is not
314+
// out-of-bound in dim 1 (11 > 10). However, the new offset is no
315315
// longer out-of-bound (5 < 210).
316316
auto newLoadOp =
317317
cast<tt::LoadOp>(mapping.lookup(static_cast<Operation *>(loadOp)));
@@ -322,19 +322,29 @@ class FuseReshape {
322322
case 1:
323323
// intentional fall-through
324324
case 2: {
325-
SmallVector<int> newBoundaryCheck{boundaryCheck[0] - 1};
326-
if (boundaryCheck.size() == 2)
325+
SmallVector<int> newBoundaryCheck;
326+
if ((boundaryCheck[0] - 1) != 0)
327+
newBoundaryCheck.push_back((boundaryCheck[0] - 1));
328+
if (boundaryCheck.size() == 2 && (boundaryCheck[1] - 1) != 0)
327329
newBoundaryCheck.push_back(boundaryCheck[1] - 1);
328-
newLoadOp.setBoundaryCheck({newBoundaryCheck});
330+
331+
newLoadOp.setBoundaryCheck(newBoundaryCheck);
329332

330333
if (llvm::any_of(newBoundaryCheck, [&](unsigned boundIdx) {
331-
return boundIdx == newOutermostDimIdx;
334+
return boundIdx == newOutermostDimIdx + 1;
332335
})) {
333-
Value lhs = newOffsets[newOutermostDimIdx];
334-
Value rhs = shapes[newOutermostDimIdx + 1];
336+
unsigned oldIdx = newOutermostDimIdx + 1;
337+
auto tensorType = cast<RankedTensorType>(loadOp.getResult().getType());
338+
Type elemType = tensorType.getElementType();
339+
ArrayRef<int64_t> resShape = tensorType.getShape();
340+
auto add = builder.create<arith::AddIOp>(
341+
loc, offsets[oldIdx],
342+
builder.create<arith::ConstantIntOp>(loc, offsets[oldIdx].getType(),
343+
resShape[oldIdx]));
335344
auto cmpOp = builder.create<arith::CmpIOp>(
336-
loc, arith::CmpIPredicate::ult, lhs,
337-
builder.create<arith::TruncIOp>(loc, lhs.getType(), rhs));
345+
loc, arith::CmpIPredicate::ult, add,
346+
builder.create<arith::TruncIOp>(loc, add.getResult().getType(),
347+
shapes[oldIdx]));
338348
createCheckedLoad(builder, cmpOp, newLoadOp);
339349
}
340350
} break;

0 commit comments

Comments (0)