Skip to content

Commit 438044b

Browse files
committed
[RemoveLayoutConversions]: Update index computations
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent b3d9c5a commit 438044b

File tree

2 files changed

+39
-48
lines changed

2 files changed

+39
-48
lines changed

test/Triton/Intel/FuseReshape/fuse-reshape.mlir

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
77
%c2_i32 = arith.constant 2 : i32
88
%c1_i64 = arith.constant 1 : i64
99
%c2_i64 = arith.constant 2 : i64
10-
%c3_i64 = arith.constant 3 : i64
10+
%c4_i64 = arith.constant 4 : i64
1111
%c1024_i64 = arith.constant 1024 : i64
1212
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
13-
%0 = tt.make_tensor_ptr %arg1, [%c2_i64, %c1_i64, %c1024_i64], [%c3_i64, %c1024_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
13+
%0 = tt.make_tensor_ptr %arg1, [%c2_i64, %c1_i64, %c1024_i64], [%c1024_i64, %c4_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
1414
%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
1515
%3 = tt.load %0 {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x32x256xbf16>>
1616
%4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
@@ -19,12 +19,14 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
1919
}
2020
// CHECK-LABEL: fuseLoadWithReshape1
2121
// CHECK-NOT: tt.reshape
22-
// CHECK: [[MUL1:%.*]] = arith.muli %c3_i64, %c2_i64 : i64
23-
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
24-
// CHECK: [[TRUNC:%.*]] = arith.trunci %c3_i64 : i64 to i32
25-
// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c2_i32 : i32
26-
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
27-
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1_i64, [[ADD1]]], [%c1024_i64, %c1_i64], [%c1_i32, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
22+
// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c4_i64 : i64
23+
// CHECK: [[MUL1:%.*]] = arith.muli %c2_i64, [[DIV]] : i64
24+
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1_i64 : i64
25+
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
26+
// CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
27+
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
28+
29+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
2830
// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
2931
// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
3032

@@ -34,14 +36,14 @@ tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1:
3436
// COM: where the 'make_tensor_ptr' result is not loop carried.
3537
tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1: !tt.ptr<bf16>) {
3638
%c0_i32 = arith.constant 0 : i32
37-
%c1_i32 = arith.constant 1 : i32
38-
%c1_i64 = arith.constant 1 : i64
3939
%c32_i32 = arith.constant 32 : i32
4040
%c1024_i32 = arith.constant 1024 : i32
41+
%c32_i64 = arith.constant 32 : i64
42+
%c1_i64 = arith.constant 1 : i64
4143
%c512_i64 = arith.constant 512 : i64
4244
%c1024_i64 = arith.constant 1024 : i64
4345
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
44-
%0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c1_i64], [%c512_i64, %c1_i64, %c1024_i64], [%c1_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
46+
%0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c32_i64], [%c1024_i64, %c1_i64, %c512_i64], [%c32_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
4547
%res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
4648
%1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
4749
%3 = tt.load %0 {boundaryCheck = array<i32: 2, 1>} : !tt.ptr<tensor<1x256x32xbf16>>
@@ -54,19 +56,20 @@ tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1:
5456
}
5557
// CHECK-LABEL: fuseLoadWithReshape2
5658
// CHECK-NOT: tt.reshape
57-
// CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, %c512_i64 : i64
58-
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c1024_i64 : i64
59-
// CHECK: [[TRUNC:%.*]] = arith.trunci %c512_i64 : i64 to i32
60-
// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c1_i32 : i32
61-
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c32_i32 : i32
62-
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1_i64], [%c1_i64, %c1024_i64], [[[ADD2]], %c0_i32] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
59+
// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c512_i64 : i64
60+
// CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, [[DIV]] : i64
61+
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c32_i64 : i64
62+
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
63+
// CHECK: [[MUL2:%.*]] = arith.muli %c32_i32, [[TRUNC]] : i32
64+
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
65+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
6366
// CHECK: scf.for
6467
// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1, 0>} : !tt.ptr<tensor<256x32xbf16>>
6568
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
6669

6770
// -----
6871

69-
// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop
72+
// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop
7073
// COM: Where the 'make_tensor_ptr' result is loop carried.
7174
tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %c_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bk: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) {
7275
%c127_i32 = arith.constant 127 : i32
@@ -118,12 +121,13 @@ tt.func public @test_matmul(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
118121
}
119122
// CHECK-LABEL: test_matmul
120123
// CHECK-NOT: tt.reshape
121-
// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, %c1_i64 : i64
122-
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %16 : i64
123-
// CHECK: [[TRUNC:%.*]] = arith.trunci %c1_i64 : i64 to i32
124-
// CHECK: [[MUL2:%.*]] = arith.muli [[TRUNC]], %c0_i32 : i32
125-
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
126-
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [%15, [[ADD1]]], [%17, %c1_i64], [%14, [[ADD2]]] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
124+
// CHECK: [[DIV:%.*]] = arith.divui %c1_i64, %17 : i64
125+
// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
126+
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %15 : i64
127+
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
128+
// CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
129+
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %14 : i32
130+
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
127131
// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
128132
// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf32>>
129133
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 11 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -201,37 +201,24 @@ class FuseReshape {
201201
OperandRange strides = makeTensorPtrOp.getStrides();
202202
OperandRange offsets = makeTensorPtrOp.getOffsets();
203203

204-
#if 0
205-
// order=2,1,0 --> idx = 2 (row major) --> idx we want = 1
206-
// order=2,0,1 --> idx = 1 (column major) --> idx we want == 0
207-
204+
// Collapse the 3-dim tensor into a 2-dim tensor.
205+
// Given a block pointer with:
206+
// shape [s0, s1, s2]
207+
// stride [a, b, c]
208+
// offset [x, y, z]
209+
// We create a block pointer with:
210+
// shape [s0 * a / b + s1, s2]
211+
// stride [b, c]
212+
// offset [x * a / b + y, z]
208213
SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
209-
newShape[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
210-
loc, builder.create<arith::MulIOp>(loc, strides[0], shapes[0]),
211-
newShape[innermostDimIdx - 1]);
212214
SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
213215
SmallVector<Value> newOffsets(makeTensorPtrOp.getOffsets().drop_front());
214-
newOffsets[innermostDimIdx - 1] = builder.create<arith::AddIOp>(
215-
loc,
216-
builder.create<arith::MulIOp>(
217-
loc,
218-
builder.create<arith::TruncIOp>(loc, offsets[0].getType(),
219-
strides[0]),
220-
offsets[0]),
221-
newOffsets[innermostDimIdx - 1]);
222-
#else
223-
// order=2,1,0 --> idx = 2 (row major) --> idx we want = 0
224-
// order=2,0,1 --> idx = 1 (column major) --> idx we want == 1
225216

226217
unsigned newInnermostDimIdx = (innermostDimIdx - 1);
227218
unsigned newOutermostDimIdx = !newInnermostDimIdx;
228-
229-
SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
230-
SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
231-
SmallVector<Value> newOffsets(makeTensorPtrOp.getOffsets().drop_front());
232-
233219
auto div = builder.create<arith::DivUIOp>(loc, strides[0],
234220
newStrides[newOutermostDimIdx]);
221+
235222
newShape[newOutermostDimIdx] = builder.create<arith::AddIOp>(
236223
loc, builder.create<arith::MulIOp>(loc, shapes[0], div),
237224
newShape[newOutermostDimIdx]);
@@ -241,7 +228,7 @@ class FuseReshape {
241228
loc, offsets[0],
242229
builder.create<arith::TruncIOp>(loc, offsets[0].getType(), div)),
243230
newOffsets[newOutermostDimIdx]);
244-
#endif
231+
245232
Value ptr = builder.create<tt::MakeTensorPtrOp>(
246233
loc, newPtrType, makeTensorPtrOp.getBase(), newShape, newStrides,
247234
newOffsets,

0 commit comments

Comments
 (0)