Skip to content

Commit 85f5eb4

Browse files
committed
Drop boundary check and make isCandidate more restrictive
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent c2908db commit 85f5eb4

File tree

2 files changed

+22
-122
lines changed

2 files changed

+22
-122
lines changed

test/Triton/Intel/FuseReshape/fuse-reshape.mlir

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
1212
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
1313
%0 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64, %c1024_i64], [%c1024_i64, %c4_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
1414
%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
15-
%3 = tt.load %0 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x256xbf16>>
15+
%3 = tt.load %0 {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x32x256xbf16>>
1616
%4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
1717
%5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
1818
tt.return
@@ -26,17 +26,8 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
2626
// CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
2727
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
2828
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
29-
// CHECK: [[ADD3:%.*]] = arith.addi %c1_i32, %c32_i32 : i32
30-
// CHECK: [[TRUNC:%.*]] = arith.trunci %c4_i64_0 : i64 to i32
31-
// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD3]], [[TRUNC]] : i32
32-
// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<32x256xbf16>) {
33-
// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x256xbf16>>
34-
// CHECK: scf.yield [[LOAD_B]] : tensor<32x256xbf16>
35-
// CHECK: } else {
36-
// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<32x256xbf16>
37-
// CHECK: scf.yield [[ZERO]] : tensor<32x256xbf16>
38-
// CHECK: }
39-
// CHECK: tt.dot {{.*}}, [[IF_RES]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
29+
// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x256xbf16>>
30+
// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
4031

4132
// -----
4233

@@ -113,7 +104,7 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
113104
%21 = arith.extsi %stride_bk : i32 to i64
114105
%22 = tt.make_tensor_ptr %b_ptr, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x128xf32>>
115106
%accumulator:3 = scf.for %k = %c0_i32 to %K step %c32_i32 iter_args(%a_block_ptr = %18, %b_block_ptr = %22, %accumulator_0 = %cst) -> (!tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>) : i32 {
116-
%25 = tt.load %a_block_ptr {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x256x32xf32>>
107+
%25 = tt.load %a_block_ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x256x32xf32>>
117108
%26 = tt.reshape %25 : tensor<1x256x32xf32> -> tensor<256x32xf32>
118109
%27 = tt.load %b_block_ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x128xf32>>
119110
%28 = tt.dot %26, %27, %cst, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
@@ -137,20 +128,10 @@ tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16
137128
// CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
138129
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c128_i32 : i32
139130
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
140-
// CHECK: [[CST_256:%.*]] = arith.constant 256 : i32
141-
// CHECK: [[ADD3:%.*]] = arith.addi %c128_i32, [[CST_256]] : i32
142-
// CHECK: [[TRUNC:%.*]] = arith.trunci [[EXT_M]] : i64 to i32
143-
// CHECK: [[COND:%.*]] = arith.cmpi ult, [[ADD3]], [[TRUNC]] : i32
144131
// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
145-
// CHECK: [[IF_RES:%.*]] = scf.if [[COND]] -> (tensor<256x32xf32>) {
146132
// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<256x32xf32>>
147-
// CHECK: scf.yield [[LOAD_A]] : tensor<256x32xf32>
148-
// CHECK: } else {
149-
// CHECK: [[ZERO:%.*]] = arith.constant dense<0.000000e+00> : tensor<256x32xf32>
150-
// CHECK: scf.yield [[ZERO]] : tensor<256x32xf32>
151-
// CHECK: }
152-
// CHECK: tt.dot [[IF_RES]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
153-
// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>
133+
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
134+
// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>
154135

155136
// -----
156137

@@ -186,7 +167,6 @@ tt.func public @fuseLoadWithReshape4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt
186167
scf.yield %add : i32
187168
}
188169
tt.return
189-
190170
}
191171
// CHECK-LABEL: fuseLoadWithReshape4
192172
// CHECK-NOT: tt.reshape

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 16 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -24,59 +24,6 @@ namespace mlir::triton::intel {
2424

2525
namespace {
2626

27-
scf::IfOp createIfBlock(OpBuilder &builder, Location loc, arith::CmpIOp condOp,
28-
tt::LoadOp loadOp) {
29-
assert(isa<RankedTensorType>(loadOp.getType()) &&
30-
"Unexpected load result type");
31-
32-
auto tensorType = cast<RankedTensorType>(loadOp.getType());
33-
assert(tensorType.getShape().size() == 2);
34-
Type elemType = tensorType.getElementType();
35-
36-
builder.setInsertionPointAfter(loadOp);
37-
auto ifOp = builder.create<scf::IfOp>(loc, tensorType, condOp, true, true);
38-
loadOp->moveBefore(ifOp.thenBlock(), ifOp.thenBlock()->end());
39-
builder.setInsertionPointAfter(loadOp);
40-
builder.create<scf::YieldOp>(loc, loadOp->getResult(0));
41-
42-
builder.setInsertionPointToStart(ifOp.elseBlock());
43-
tt::PaddingOption padding = (!loadOp.getPadding())
44-
? tt::PaddingOption::PAD_ZERO
45-
: *loadOp.getPadding();
46-
DenseElementsAttr denseAttr = nullptr;
47-
switch (padding) {
48-
case tt::PaddingOption::PAD_ZERO: {
49-
denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
50-
builder.getZeroAttr(elemType));
51-
} break;
52-
case tt::PaddingOption::PAD_NAN: {
53-
assert(elemType.isF128() && "Expecting a floating point type");
54-
auto NaNVal =
55-
APFloat::getNaN(cast<FloatType>(elemType).getFloatSemantics());
56-
denseAttr = DenseElementsAttr::get(cast<ShapedType>(tensorType),
57-
builder.getFloatAttr(elemType, NaNVal));
58-
} break;
59-
default:
60-
llvm_unreachable("Unhandled padding kind");
61-
}
62-
assert(denseAttr && "Expecting a valid attribute");
63-
64-
Value other = builder.create<arith::ConstantOp>(loc, tensorType, denseAttr);
65-
builder.create<scf::YieldOp>(loc, other);
66-
return ifOp;
67-
}
68-
69-
scf::IfOp createCheckedLoad(OpBuilder &builder, arith::CmpIOp cmpOp,
70-
tt::LoadOp loadOp) {
71-
scf::IfOp ifOp = createIfBlock(builder, loadOp.getLoc(), cmpOp, loadOp);
72-
loadOp->replaceUsesWithIf(ifOp, [&](OpOperand &operand) {
73-
if (auto yieldOp = dyn_cast<scf::YieldOp>(operand.getOwner()))
74-
return yieldOp->getParentOp() != ifOp;
75-
return true;
76-
});
77-
return ifOp;
78-
};
79-
8027
// Transform:
8128
// %one = arith.constant 1 : i64
8229
// %ptr = make_tensor_ptr %q_view, [%q, %q_23, %q_24],
@@ -298,24 +245,12 @@ class FuseReshape {
298245
propagateToUsers(ptr, chain, mapping);
299246
cleanUp.insert(makeTensorPtrOp);
300247

301-
// We have collapsed 2 dimensions into one, therefore we might have to
302-
// materialize the boundary check for the new collapsed dimension. There
303-
// are 2 possibilities:
304-
// a) if the load checks only the innermost dimension, we are ok because
305-
// we haven't collapsed that dimension
306-
// b) if the load checks the new outermost dimension, the boundary check
307-
// on the load is not sufficient and we have to materialize the
308-
// correct boundary check. Example:
309-
// OLD PTR NEW PTR
310-
// shape: [20, 10, 5] -> [210, 5]
311-
// strides: [50, 5, 1] -> [ 5, 1]
312-
//
313-
// Consider a load offset of [1, 11, 1], this access is clearly
314-
// out-of-bound in dim 1 (11 > 10). However, the new offset is no
315-
// longer out-of-bound (5 < 210).
248+
// We have collapsed 2 dimensions into one, therefore we need to adjust the
249+
// boundary check of the new load.
316250
auto newLoadOp =
317251
cast<tt::LoadOp>(mapping.lookup(static_cast<Operation *>(loadOp)));
318252
ArrayRef<int> boundaryCheck = newLoadOp.getBoundaryCheck();
253+
319254
switch (boundaryCheck.size()) {
320255
case 0:
321256
break;
@@ -327,26 +262,7 @@ class FuseReshape {
327262
newBoundaryCheck.push_back((boundaryCheck[0] - 1));
328263
if (boundaryCheck.size() == 2 && (boundaryCheck[1] - 1) != 0)
329264
newBoundaryCheck.push_back(boundaryCheck[1] - 1);
330-
331265
newLoadOp.setBoundaryCheck(newBoundaryCheck);
332-
333-
if (llvm::any_of(newBoundaryCheck, [&](unsigned boundIdx) {
334-
return boundIdx == newOutermostDimIdx + 1;
335-
})) {
336-
unsigned oldIdx = newOutermostDimIdx + 1;
337-
auto tensorType = cast<RankedTensorType>(loadOp.getResult().getType());
338-
Type elemType = tensorType.getElementType();
339-
ArrayRef<int64_t> resShape = tensorType.getShape();
340-
auto add = builder.create<arith::AddIOp>(
341-
loc, offsets[oldIdx],
342-
builder.create<arith::ConstantIntOp>(loc, offsets[oldIdx].getType(),
343-
resShape[oldIdx]));
344-
auto cmpOp = builder.create<arith::CmpIOp>(
345-
loc, arith::CmpIPredicate::ult, add,
346-
builder.create<arith::TruncIOp>(loc, add.getResult().getType(),
347-
shapes[oldIdx]));
348-
createCheckedLoad(builder, cmpOp, newLoadOp);
349-
}
350266
} break;
351267
default:
352268
// Note: while selecting candidates, we already ensured that the original
@@ -361,11 +277,13 @@ class FuseReshape {
361277
// Where:
362278
// - the reshape operation drops the outermost dimension of the operand,
363279
// which is a 3-dim tensor with outermost dimension extent equal to one
364-
// - the reshape result is used by the dot operation
280+
// - the reshape result is used by a dot operation
365281
// - the reshape operation uses the result of a 3-dim load operation on a
366282
// block pointer (transitively) defined by a `make_tensor_ptr` operation
367283
// - the block pointer points to a tensor that has extent equal to 1 on the
368284
// outermost dimension
285+
// - the load operation doesn't have boundary checks on either of the
286+
// dimensions collapsed
369287
bool isCandidate(tt::ReshapeOp reshapeOp) const {
370288
assert(reshapeOp && "Expecting a valid reshape operation");
371289

@@ -384,8 +302,7 @@ class FuseReshape {
384302
return false;
385303
}
386304

387-
// Check whether \p reshapeOp is used by a `dotOp` (directly or
388-
// indirectly).
305+
// Check whether \p reshapeOp is used by a `dotOp`.
389306
auto usedByDotOp = [](tt::ReshapeOp reshapeOp) {
390307
if (!reshapeOp->hasOneUse())
391308
return false;
@@ -405,6 +322,7 @@ class FuseReshape {
405322
if (!usedByDotOp(reshapeOp))
406323
return false;
407324

325+
// The reshape operation uses the result of a load operation.
408326
Operation *defOp = reshapeOp.getSrc().getDefiningOp();
409327
if (!defOp || !isa<tt::LoadOp>(defOp))
410328
return false;
@@ -413,6 +331,8 @@ class FuseReshape {
413331
if (!loadOp->hasOneUse())
414332
return false;
415333

334+
// The load uses a 3-dim block pointer defined by a make_tensor_ptr
335+
// operation.
416336
Type ptrType = loadOp.getPtr().getType();
417337
if (!tt::isTensorPointerType(ptrType))
418338
return false;
@@ -432,14 +352,14 @@ class FuseReshape {
432352
if (order.front() != tensorTy.getRank() - 1)
433353
return false;
434354

435-
// Ensure that the innermost stride is one.
436355
unsigned innermostDimIdx = 0;
437-
for (int i : order) {
438-
if (i == 0)
356+
for (int idx : order) {
357+
if (idx == 0)
439358
break;
440359
++innermostDimIdx;
441360
}
442361

362+
// Ensure that the innermost stride is one.
443363
auto strides = makeTensorPtrOp->getStrides();
444364
Value innermostStride = strides[innermostDimIdx];
445365
if (!innermostStride.getDefiningOp() ||
@@ -451,9 +371,9 @@ class FuseReshape {
451371
if (integerCst.value() != 1)
452372
return false;
453373

454-
// Ensure the load boundary check doesn't check the outermost dimension.
455-
return llvm::none_of(loadOp.getBoundaryCheck(),
456-
[](int val) { return val == 0; });
374+
// Ensure the load operation checks at most the innermost dimension.
375+
return llvm::all_of(loadOp.getBoundaryCheck(),
376+
[&](int idx) { return idx == innermostDimIdx; });
457377
}
458378

459379
// Prune chains that cannot be handled during fusion. For example,

0 commit comments

Comments
 (0)