@@ -46,21 +46,23 @@ def TritonIntelRemoveMasks
4646
4747def TritonIntelFuseReshape
4848 : Pass<"triton-intel-fuse-reshape", "mlir::ModuleOp"> {
49- let summary = "Fuse a tt.reshape operation with a tt.load operation";
49+ let summary = "Fuse a tt.reshape operation with a tt.load operation (block ptrs only)";
5050
5151 let description = [{
52- This pass attempts to fuse a tt.reshape operation with a tt.load operation using a block pointer .
52+ This pass attempts to fuse a tt.reshape operation with a tt.load operation that uses a block pointer.
5353 For example, given:
54- %q_27 = arith.constant 1 : i64
55- %ptr = tt.make_tensor_ptr %q_view, [%q, %q_23, %q_24], [%q_25, %q_26, %q_27], [%offset_5, %offset_1_13, %q_28]
54+ %ptr = tt.make_tensor_ptr %base_ptr, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %z]
5655 {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
5756 %load = tt.load %ptr {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x512x64xf16>>
58- %reshape = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
59- The transformation would drops the reshape operation and adjust the make_tensor_ptr operation as follows:
60- %q_27 = arith.constant 1 : i64
61- %ptr = tt.make_tensor_ptr %q_view, [%q_23, %q_24], [%q_26, %q_27], [%offset_1_13, %offset_5*%q_25+%q_28]
57+ %A = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
58+ %D = tt.dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
59+
60+ The transformation drops the reshape operation, and generates:
61+ %div = %a / %b
62+ %ptr = tt.make_tensor_ptr %base_ptr, [%s0 * %div + %s1, %s2], [%b, %c], [%x * %div + %y, %z]
6263 {order = array<i32: 1, 0>} : <tensor<512x64xf16>>
63- %load = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<512x64xf16>>
64+ %A = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<512x64xf16>>
65+ %D = tt.dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
6466 }];
6567
6668 let dependentDialects = [
0 commit comments