Skip to content

Commit 79199c6

Browse files
committed
Address code review comments
Signed-off-by: Ettore Tiotto <[email protected]>
1 parent 438044b commit 79199c6

File tree

3 files changed

+18
-16
lines changed

3 files changed

+18
-16
lines changed

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
9494
mlir::test::registerTestAMDGPUMembarPass();
9595
mlir::test::registerTestTritonAMDGPURangeAnalysis();
9696
mlir::triton::registerConvertTritonToTritonGPUPass();
97-
mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
98-
mlir::triton::intel::registerTritonIntelRemoveMasks();
9997
mlir::triton::intel::registerTritonIntelFuseReshape();
98+
mlir::triton::intel::registerTritonIntelRemoveMasks();
99+
mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
100100
mlir::triton::registerRelayoutTritonGPUPass();
101101
mlir::triton::gpu::registerAllocateSharedMemoryPass();
102102
mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,23 @@ def TritonIntelRemoveMasks
4646

4747
def TritonIntelFuseReshape
4848
: Pass<"triton-intel-fuse-reshape", "mlir::ModuleOp"> {
49-
let summary = "Fuse a tt.reshape operation with a tt.load operation";
49+
let summary = "Fuse a tt.reshape operation with a tt.load operation (block ptrs only)";
5050

5151
let description = [{
52-
This pass attempts to fuse a tt.reshape operation with a tt.load operation using a block pointer.
52+
This pass attempts to fuse a tt.reshape operation with a tt.load operation.
5353
For example, given:
54-
%q_27 = arith.constant 1 : i64
55-
%ptr = tt.make_tensor_ptr %q_view, [%q, %q_23, %q_24], [%q_25, %q_26, %q_27], [%offset_5, %offset_1_13, %q_28]
54+
%ptr = tt.make_tensor_ptr %base_ptr, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %z]
5655
{order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
5756
%load = tt.load %ptr {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x512x64xf16>>
58-
%reshape = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
59-
The transformation would drops the reshape operation and adjust the make_tensor_ptr operation as follows:
60-
%q_27 = arith.constant 1 : i64
61-
%ptr = tt.make_tensor_ptr %q_view, [%q_23, %q_24], [%q_26, %q_27], [%offset_1_13, %offset_5*%q_25+%q_28]
57+
%A = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
58+
%dot = tt.dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
59+
60+
The transformation drops the reshape operation, and generates:
61+
%div = %a / %b
62+
%ptr = tt.make_tensor_ptr %base_ptr, [%s0 * %div + %s1, %s2], [%b, %c], [%x * %div + %y, %z]
6263
{order = array<i32: 1, 0>} : <tensor<512x64xf16>>
63-
%load = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<512x64xf16>>
64+
%A = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<512x64xf16>>
65+
%dot = tt.dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
6466
}];
6567

6668
let dependentDialects = [

third_party/intel/lib/Dialect/Triton/Transforms/FuseReshape.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ class FuseReshape {
9393
// Remove operations that are no longer used.
9494
if (!cleanUp.empty())
9595
tt::intel::eraseOperations(cleanUp);
96-
97-
assert(succeeded(verify(moduleOp)) && "Module verification failed");
9896
}
9997

10098
private:
@@ -165,7 +163,7 @@ class FuseReshape {
165163
isa<tt::MakeTensorPtrOp>(chain.getStart()) &&
166164
"Expecting 'chain' to be rooted by a 'tt.make_tensor_ptr' operation");
167165
assert(isa<tt::ReshapeOp>(chain.getEnd()) &&
168-
"Expecting 'chain' to be terminated by a 'tt.rehape' operation");
166+
"Expecting 'chain' to be terminated by a 'tt.reshape' operation");
169167

170168
auto makeTensorPtrOp = cast<tt::MakeTensorPtrOp>(chain.getStart());
171169
auto reshapeOp = cast<tt::ReshapeOp>(chain.getEnd());
@@ -202,14 +200,16 @@ class FuseReshape {
202200
OperandRange offsets = makeTensorPtrOp.getOffsets();
203201

204202
// Collapse the 3-dim tensor into a 2-dim tensor.
205-
// Given a block pointer with:
203+
// Given a make_tensor_ptr with:
206204
// shape [s0, s1, s2]
207205
// stride [a, b, c]
208206
// offset [x, y, z]
209-
// We create a block pinter with:
207+
// order [2, 1, 0]
208+
// We create a make_tensor_ptr with:
210209
// shape [s0 * a / b + s1, s2]
211210
// stride [b, c]
212211
// offset [x * a / b + y, z]
212+
// order [1, 0]
213213
SmallVector<Value> newShape(makeTensorPtrOp.getShape().drop_front());
214214
SmallVector<Value> newStrides(makeTensorPtrOp.getStrides().drop_front());
215215
SmallVector<Value> newOffsets(makeTensorPtrOp.getOffsets().drop_front());

0 commit comments

Comments
 (0)