Skip to content

Commit 02602ad

Browse files
address review comments
1 parent 85f5eb4 commit 02602ad

File tree

3 files changed

+7
-51
lines changed

3 files changed

+7
-51
lines changed

test/Triton/Intel/FuseReshape/fuse-reshape.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
6363
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
6464
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
6565
// CHECK: scf.for
66-
// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] : !tt.ptr<tensor<256x32xbf16>>
66+
// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x32xbf16>>
6767
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
6868

6969
// -----

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -53,15 +53,15 @@ def TritonIntelFuseReshape
5353
For example, given:
5454
%ptr = tt.make_tensor_ptr %base_ptr, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %z]
5555
{order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
56-
%load = tt.load %ptr {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x512x64xf16>>
56+
%load = tt.load %ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x512x64xf16>>
5757
%A = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
5858
%dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
5959

6060
The transformation drops the reshape operation, and generates:
6161
%div = %a / %b
6262
%ptr = tt.make_tensor_ptr %base_ptr, [%s0 * %div + %s1, %s2], [%b, %c], [%x * %div + %y, %z]
6363
{order = array<i32: 1, 0>} : <tensor<512x64xf16>>
64-
%A = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<512x64xf16>>
64+
%A = tt.load %ptr {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<512x64xf16>>
6565
%dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
6666
}];
6767

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 4 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -24,23 +24,6 @@ namespace mlir::triton::intel {
2424

2525
namespace {
2626

27-
// Transform:
28-
// %one = arith.constant 1 : i64
29-
// %ptr = make_tensor_ptr %q_view, [%q, %q_23, %q_24],
30-
// [%q_25, %q_26, %one], [%offset_5, %offset_1_13, %q_28]
31-
// {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
32-
// %load = tt.load %ptr {boundaryCheck = array<i32: 1, 2>}
33-
// : !tt.ptr<tensor<1x512x64xf16>>
34-
// %a = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
35-
// tt.dot(%a, ...)
36-
// into:
37-
// %one = arith.constant 1 : i64
38-
// %ptr = make_tensor_ptr %q_view, [%q_23, %q_24], [%q_26, %one],
39-
// [%offset_1_13, %offset_5*%q_25+%q_28]
40-
// {order = array<i32: 1, 0>} : <tensor<512x64xf16>>
41-
// %a = tt.load %ptr {boundaryCheck = array<i32: 0, 1>}
42-
// : !tt.ptr<tensor<512x64xf16>>
43-
// tt.dot(%a, ...)
4427
class FuseReshape {
4528
private:
4629
SmallPtrSet<Operation *, 8> cleanUp;
@@ -250,25 +233,10 @@ class FuseReshape {
250233
auto newLoadOp =
251234
cast<tt::LoadOp>(mapping.lookup(static_cast<Operation *>(loadOp)));
252235
ArrayRef<int> boundaryCheck = newLoadOp.getBoundaryCheck();
253-
254-
switch (boundaryCheck.size()) {
255-
case 0:
256-
break;
257-
case 1:
258-
// intentional fall-through
259-
case 2: {
260-
SmallVector<int> newBoundaryCheck;
261-
if ((boundaryCheck[0] - 1) != 0)
262-
newBoundaryCheck.push_back((boundaryCheck[0] - 1));
263-
if (boundaryCheck.size() == 2 && (boundaryCheck[1] - 1) != 0)
264-
newBoundaryCheck.push_back(boundaryCheck[1] - 1);
265-
newLoadOp.setBoundaryCheck(newBoundaryCheck);
266-
} break;
267-
default:
268-
// Note: while selecting candidates, we already ensured that the original
269-
// load's boundary check doesn't check dim zero. So its max rank should
270-
// be 2.
271-
assert(boundaryCheck.size() != 3 && "Unexpected boundary check rank");
236+
for (int idx : boundaryCheck) {
237+
assert(idx == (newInnermostDimIdx + 1) &&
238+
"Unexpected boundary check idx");
239+
newLoadOp.setBoundaryCheck({static_cast<int>(newInnermostDimIdx)});
272240
}
273241
}
274242

@@ -359,18 +327,6 @@ class FuseReshape {
359327
++innermostDimIdx;
360328
}
361329

362-
// Ensure that the innermost stride is one.
363-
auto strides = makeTensorPtrOp->getStrides();
364-
Value innermostStride = strides[innermostDimIdx];
365-
if (!innermostStride.getDefiningOp() ||
366-
!isa<arith::ConstantIntOp>(innermostStride.getDefiningOp()))
367-
return false;
368-
369-
auto integerCst =
370-
cast<arith::ConstantIntOp>(innermostStride.getDefiningOp());
371-
if (integerCst.value() != 1)
372-
return false;
373-
374330
// Ensure the load operation checks at most the innermost dimension.
375331
return llvm::all_of(loadOp.getBoundaryCheck(),
376332
[&](int idx) { return idx == innermostDimIdx; });

0 commit comments

Comments
 (0)