
Commit 0d2235f

Fuse reshape with load op (#5268)
This PR introduces a new transformation in the Triton XPU compiler pipeline. The transformation's objective is to remove reshape operations on 3-dim block pointers. Example:

```
%desc = tt.make_tensor_ptr %ptr, [%s0, %s1, %s2], [%x, %y, %z], [%o0, %o1, %o2] {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
%val = tt.load %desc {boundaryCheck = array<i32: 1, 2>} : !tt.ptr<tensor<1x512x64xf16>>
%res = tt.reshape %val : tensor<1x512x64xf16> -> tensor<512x64xf16>
```

becomes:

```
%desc = tt.make_tensor_ptr %ptr, [%s0*%x/%y + %s1, %s2], [%y, %z], [%o0*%x/%y + %o1, %o2] {order = array<i32: 1, 0>} : <tensor<512x64xf16>>
%res = tt.load %desc {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<512x64xf16>>
```

---------

Signed-off-by: Ettore Tiotto <[email protected]>
Co-authored-by: Whitney Tsang <[email protected]>
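Why the rewrite is address-preserving: with strides [%x, %y, %z] and offsets [%o0, %o1, %o2], element (0, j, k) of the 3-dim tile sits at base + %o0\*%x + (%o1+j)\*%y + (%o2+k)\*%z; when %y evenly divides %x (which the generated %x/%y division presumes), the %o0\*%x term folds into the second dimension as (%o0\*%x/%y)\*%y, which is exactly the fused size and offset above. A minimal sketch checking that identity with hypothetical values (mine, not part of the commit):

```python
# Minimal sketch (not from the commit): checks that the fused 2-D block
# pointer addresses the same elements as the original 3-D one, assuming
# the tile's outer dimension has extent 1 and %y evenly divides %x.
x, y, z = 1024, 4, 1       # hypothetical strides [%x, %y, %z]
o0, o1, o2 = 2, 1, 0       # hypothetical offsets [%o0, %o1, %o2]
assert x % y == 0          # precondition implied by the %x/%y division

div = x // y
for j in range(512):       # rows of the reshaped tensor<512x64xf16>
    for k in range(64):    # columns
        addr_3d = o0 * x + (o1 + j) * y + (o2 + k) * z      # element (0, j, k)
        addr_2d = (o0 * div + o1 + j) * y + (o2 + k) * z    # fused 2-D pointer
        assert addr_3d == addr_2d

print("fused offsets:", [o0 * div + o1, o2])  # [%o0*%x/%y + %o1, %o2]
```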
1 parent 7cab56d commit 0d2235f

7 files changed: +809 −1 lines changed

bin/RegisterTritonDialects.h

Lines changed: 2 additions & 1 deletion
```diff
@@ -94,8 +94,9 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::test::registerTestAMDGPUMembarPass();
   mlir::test::registerTestTritonAMDGPURangeAnalysis();
   mlir::triton::registerConvertTritonToTritonGPUPass();
-  mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
+  mlir::triton::intel::registerTritonIntelFuseReshape();
   mlir::triton::intel::registerTritonIntelRemoveMasks();
+  mlir::triton::intel::registerTritonIntelTensorDescToBlockPointer();
   mlir::triton::registerRelayoutTritonGPUPass();
   mlir::triton::gpu::registerAllocateSharedMemoryPass();
   mlir::triton::gpu::registerTritonGPUAllocateWarpGroups();
```
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@

```mlir
// RUN: triton-opt %s -split-input-file -triton-intel-fuse-reshape | FileCheck %s

// COM: tt.load -> tt.reshape -> tt.dot chain, not in a loop.
tt.func public @fuseLoadWithReshape1(%arg0: !tt.ptr<tensor<256x32xbf16>>, %arg1: !tt.ptr<bf16>) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c1_i64 = arith.constant 1 : i64
  %c4_i64 = arith.constant 4 : i64
  %c64_i64 = arith.constant 4 : i64
  %c1024_i64 = arith.constant 1024 : i64
  %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
  %0 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64, %c1024_i64], [%c1024_i64, %c4_i64, %c1_i64], [%c2_i32, %c1_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x256xbf16>>
  %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xbf16>>
  %3 = tt.load %0 {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x32x256xbf16>>
  %4 = tt.reshape %3 : tensor<1x32x256xbf16> -> tensor<32x256xbf16>
  %5 = tt.dot %1, %4, %cst, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape1
// CHECK-NOT: tt.reshape
// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c4_i64 : i64
// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c4_i64_0 : i64
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
// CHECK: [[MUL2:%.*]] = arith.muli %c2_i32, [[TRUNC]] : i32
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c1_i32 : i32
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [[[ADD1]], %c1024_i64], [%c4_i64, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xbf16>>
// CHECK: [[LOAD_B:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<32x256xbf16>>
// CHECK: tt.dot {{.*}}, [[LOAD_B]], {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>

// -----

// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop.
// COM: where the 'make_tensor_ptr' result is not loop carried.
tt.func public @fuseLoadWithReshape2(%arg0: !tt.ptr<tensor<32x256xbf16>>, %arg1: !tt.ptr<bf16>) {
  %c0_i32 = arith.constant 0 : i32
  %c32_i32 = arith.constant 32 : i32
  %c1024_i32 = arith.constant 1024 : i32
  %c32_i64 = arith.constant 32 : i64
  %c1_i64 = arith.constant 1 : i64
  %c512_i64 = arith.constant 512 : i64
  %c1024_i64 = arith.constant 1024 : i64
  %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32>
  %0 = tt.make_tensor_ptr %arg1, [%c512_i64, %c1024_i64, %c32_i64], [%c1024_i64, %c1_i64, %c512_i64], [%c32_i32, %c32_i32, %c0_i32] {order = array<i32: 2, 0, 1>} : <tensor<1x256x32xbf16>>
  %res:2 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32) -> (tensor<256x256xf32>, i32) : i32 {
    %1 = tt.load %arg0 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xbf16>>
    %3 = tt.load %0 {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<1x256x32xbf16>>
    %2 = tt.reshape %3 : tensor<1x256x32xbf16> -> tensor<256x32xbf16>
    %4 = tt.dot %2, %1, %arg4, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>
    %5 = arith.addi %arg5, %c32_i32 : i32
    scf.yield %4, %5 : tensor<256x256xf32>, i32
  }
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape2
// CHECK-NOT: tt.reshape
// CHECK: [[DIV:%.*]] = arith.divui %c1024_i64, %c512_i64 : i64
// CHECK: [[MUL1:%.*]] = arith.muli %c512_i64, [[DIV]] : i64
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %c32_i64 : i64
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
// CHECK: [[MUL2:%.*]] = arith.muli %c32_i32, [[TRUNC]] : i32
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c0_i32 : i32
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg1, [%c1024_i64, [[ADD1]]], [%c1_i64, %c512_i64], [%c32_i32, [[ADD2]]] {order = array<i32: 0, 1>} : <tensor<256x32xbf16>>
// CHECK: scf.for
// CHECK: [[LOAD_A:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x32xbf16>>
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xbf16> * tensor<32x256xbf16> -> tensor<256x256xf32>

// -----

// COM: tt.load -> tt.reshape -> tt.dot chain, in a loop,
// COM: where the 'make_tensor_ptr' result is loop carried.
tt.func public @fuseLoadWithReshape3(%a_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %b_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %c_ptr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %M: i32 {tt.divisibility = 16 : i32}, %N: i32 {tt.divisibility = 16 : i32}, %K: i32 {tt.divisibility = 16 : i32}, %stride_am: i32 {tt.divisibility = 16 : i32}, %stride_bk: i32 {tt.divisibility = 16 : i32}, %stride_cm: i32 {tt.divisibility = 16 : i32}) {
  %c127_i32 = arith.constant 127 : i32
  %c255_i32 = arith.constant 255 : i32
  %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32>
  %c32_i32 = arith.constant 32 : i32
  %c128_i32 = arith.constant 128 : i32
  %c0_i32 = arith.constant 0 : i32
  %c1_i64 = arith.constant 1 : i64
  %c256_i32 = arith.constant 256 : i32
  %c4_i32 = arith.constant 4 : i32
  %0 = tt.get_program_id x : i32
  %1 = arith.addi %M, %c255_i32 : i32
  %2 = arith.divsi %1, %c256_i32 : i32
  %3 = arith.addi %N, %c127_i32 : i32
  %4 = arith.divsi %3, %c128_i32 : i32
  %5 = arith.muli %4, %c4_i32 : i32
  %6 = arith.divsi %0, %5 : i32
  %7 = arith.muli %6, %c4_i32 : i32
  %8 = arith.subi %2, %7 : i32
  %9 = arith.minsi %8, %c4_i32 : i32
  %10 = arith.remsi %0, %5 : i32
  %11 = arith.remsi %10, %9 : i32
  %12 = arith.addi %7, %11 : i32
  %13 = arith.divsi %10, %9 : i32
  %14 = arith.muli %12, %c256_i32 : i32
  %15 = arith.extsi %M : i32 to i64
  %16 = arith.extsi %K : i32 to i64
  %17 = arith.extsi %stride_am : i32 to i64
  %18 = tt.make_tensor_ptr %a_ptr, [%c1_i64, %15, %16], [%c1_i64, %17, %c1_i64], [%c0_i32, %c128_i32, %c0_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x256x32xf32>>
  %19 = arith.muli %13, %c128_i32 : i32
  %20 = arith.extsi %N : i32 to i64
  %21 = arith.extsi %stride_bk : i32 to i64
  %22 = tt.make_tensor_ptr %b_ptr, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x128xf32>>
  %accumulator:3 = scf.for %k = %c0_i32 to %K step %c32_i32 iter_args(%a_block_ptr = %18, %b_block_ptr = %22, %accumulator_0 = %cst) -> (!tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>) : i32 {
    %25 = tt.load %a_block_ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x256x32xf32>>
    %26 = tt.reshape %25 : tensor<1x256x32xf32> -> tensor<256x32xf32>
    %27 = tt.load %b_block_ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x128xf32>>
    %28 = tt.dot %26, %27, %cst, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
    %29 = arith.addf %accumulator_0, %28 : tensor<256x128xf32>
    %30 = tt.advance %a_block_ptr, [%c0_i32, %c0_i32, %c32_i32] : <tensor<1x256x32xf32>>
    %31 = tt.advance %b_block_ptr, [%c32_i32, %c0_i32] : <tensor<32x128xf32>>
    scf.yield %30, %31, %29 : !tt.ptr<tensor<1x256x32xf32>>, !tt.ptr<tensor<32x128xf32>>, tensor<256x128xf32>
  }
  %23 = arith.extsi %stride_cm : i32 to i64
  %24 = tt.make_tensor_ptr %c_ptr, [%15, %20], [%23, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x128xf32>>
  tt.store %24, %accumulator#2 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x128xf32>>
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape3
// CHECK-NOT: tt.reshape
// CHECK: [[EXT_M:%.*]] = arith.extsi %arg3 : i32 to i64
// CHECK: [[DIV:%.*]] = arith.divui %c1_i64, %17 : i64
// CHECK: [[MUL1:%.*]] = arith.muli %c1_i64, [[DIV]] : i64
// CHECK: [[ADD1:%.*]] = arith.addi [[MUL1]], %15 : i64
// CHECK: [[TRUNC:%.*]] = arith.trunci [[DIV]] : i64 to i32
// CHECK: [[MUL2:%.*]] = arith.muli %c0_i32, [[TRUNC]] : i32
// CHECK: [[ADD2:%.*]] = arith.addi [[MUL2]], %c128_i32 : i32
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr %arg0, [[[ADD1]], %16], [%17, %c1_i64], [[[ADD2]], %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf32>>
// CHECK: scf.for {{.*}} = %c0_i32 to {{.*}} step %c32_i32 iter_args([[ARG:%.*]] = [[PTR]]
// CHECK: [[LOAD_A:%.*]] = tt.load [[ARG]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<256x32xf32>>
// CHECK: tt.dot [[LOAD_A]], {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf32> * tensor<32x128xf32> -> tensor<256x128xf32>
// CHECK: tt.advance [[ARG]], [%c0_i32, %c32_i32] : <tensor<256x32xf32>>

// -----

// COM: tt.load -> tt.reshape -> tt.dot chain, in 2 loops,
// COM: where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
tt.func public @fuseLoadWithReshape4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
  %c0_i32 = arith.constant 0 : i32
  %c1_i32 = arith.constant 1 : i32
  %c2_i32 = arith.constant 2 : i32
  %c32_i32 = arith.constant 32 : i32
  %c1_i64 = arith.constant 1 : i64
  %c64_i64 = arith.constant 64 : i64
  %c256_i64 = arith.constant 256 : i64
  %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32>
  %7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16>>
  %9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c256_i64, %c64_i64], [%c256_i64, %c64_i64, %c1_i64], [%c0_i32, %c1_i32, %c2_i32] {order = array<i32: 2, 1, 0>} : <tensor<1x32x64xf16>>
  %10 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x32xf16>>
  %11 = tt.load %10 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16>>
  %res1:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
    %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
    %load = tt.load %adv : !tt.ptr<tensor<1x32x64xf16>>
    %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
    %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
    %add = arith.addi %arg4, %c32_i32 : i32
    scf.yield %add : i32
  }
  %res2:1 = scf.for %arg3 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg4 = %arg0) -> (i32) : i32 {
    %adv = tt.advance %9, [%arg4, %c0_i32] : <tensor<1x32x64xf16>>
    %load = tt.load %adv : !tt.ptr<tensor<1x32x64xf16>>
    %reshape = tt.reshape %load : tensor<1x32x64xf16> -> tensor<32x64xf16>
    %dot = tt.dot %11, %reshape, %cst, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
    %add = arith.addi %arg4, %c32_i32 : i32
    scf.yield %add : i32
  }
  tt.return
}
// CHECK-LABEL: fuseLoadWithReshape4
// CHECK-NOT: tt.reshape
// CHECK: [[DIV1:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
// CHECK: [[MUL11:%.*]] = arith.muli %c1_i64, [[DIV1]] : i64
// CHECK: [[ADD11:%.*]] = arith.addi [[MUL11]], %c256_i64 : i64
// CHECK: [[TRUNC1:%.*]] = arith.trunci [[DIV1]] : i64 to i32
// CHECK: [[MUL21:%.*]] = arith.muli %c0_i32, [[TRUNC1]] : i32
// CHECK: [[ADD21:%.*]] = arith.addi [[MUL21]], %c1_i32 : i32
// CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD11]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD21]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
// CHECK: [[DIV2:%.*]] = arith.divui %c256_i64, %c64_i64 : i64
// CHECK: [[MUL12:%.*]] = arith.muli %c1_i64, [[DIV2]] : i64
// CHECK: [[ADD12:%.*]] = arith.addi [[MUL12]], %c256_i64 : i64
// CHECK: [[TRUNC2:%.*]] = arith.trunci [[DIV2]] : i64 to i32
// CHECK: [[MUL22:%.*]] = arith.muli %c0_i32, [[TRUNC2]] : i32
// CHECK: [[ADD22:%.*]] = arith.addi [[MUL22]], %c1_i32 : i32
// CHECK: [[PTR2:%.*]] = tt.make_tensor_ptr %arg2, [[[ADD12]], %c64_i64], [%c64_i64, %c1_i64], [[[ADD22]], %c2_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16>>
// CHECK: scf.for
// CHECK: [[ADV:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<32x64xf16>>
// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
// CHECK: scf.yield
// CHECK: scf.for
// CHECK: [[ADV:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<32x64xf16>>
// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV]] : !tt.ptr<tensor<32x64xf16>>
// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x32xf16> * tensor<32x64xf16> -> tensor<64x64xf32>
// CHECK: scf.yield
```
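As a quick sanity check (mine, not part of the commit), the folded make_tensor_ptr operands that @fuseLoadWithReshape1's CHECK lines expect can be reproduced from that test's constants (note that %c64_i64 holds the value 4 there):

```python
# Reproduces the folded operands checked in fuseLoadWithReshape1.
shape   = [1, 4, 1024]   # [%c1_i64, %c64_i64 (holds 4), %c1024_i64]
strides = [1024, 4, 1]   # [%c1024_i64, %c4_i64, %c1_i64]
offsets = [2, 1, 0]      # [%c2_i32, %c1_i32, %c0_i32]

div = strides[0] // strides[1]                   # [[DIV]]  = 256
fused_shape0  = shape[0] * div + shape[1]        # [[ADD1]] = 260
fused_offset0 = offsets[0] * div + offsets[1]    # [[ADD2]] = 513
assert (div, fused_shape0, fused_offset0) == (256, 260, 513)
# Fused pointer: shape [260, 1024], strides [4, 1], offsets [513, 0],
# order [1, 0], type <tensor<32x256xbf16>>.
```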

third_party/intel/backend/compiler.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -201,6 +201,7 @@ def make_ttir(mod, metadata, opt):
         passes.common.add_cse(pm)
         passes.common.add_licm(pm)
         intel.passes.ttir.add_remove_masks(pm)
+        intel.passes.ttir.add_fuse_reshape(pm)
         passes.common.add_canonicalizer(pm)
         passes.ttir.add_combine(pm)
         passes.ttir.add_reorder_broadcast(pm)
```

third_party/intel/include/Dialect/Triton/Transforms/Passes.td

Lines changed: 26 additions & 0 deletions
```diff
@@ -44,4 +44,30 @@ def TritonIntelRemoveMasks
   ];
 }
 
+def TritonIntelFuseReshape
+    : Pass<"triton-intel-fuse-reshape", "mlir::ModuleOp"> {
+  let summary = "Fuse a tt.reshape operation with a tt.load operation (block ptrs only)";
+
+  let description = [{
+    This pass attempts to fuse a tt.reshape operation with a tt.load operation.
+    For example, given:
+      %ptr = tt.make_tensor_ptr %base_ptr, [%s0, %s1, %s2], [%a, %b, %c], [%x, %y, %z]
+             {order = array<i32: 2, 1, 0>} : <tensor<1x512x64xf16>>
+      %load = tt.load %ptr {boundaryCheck = array<i32: 2>} : !tt.ptr<tensor<1x512x64xf16>>
+      %A = tt.reshape %load : tensor<1x512x64xf16> -> tensor<512x64xf16>
+      %dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
+
+    The transformation drops the reshape operation, and generates:
+      %div = %a / %b
+      %ptr = tt.make_tensor_ptr %base_ptr, [%s0 * %div + %s1, %s2], [%b, %c], [%x * %div + %y, %z]
+             {order = array<i32: 1, 0>} : <tensor<512x64xf16>>
+      %A = tt.load %ptr {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<512x64xf16>>
+      %dot %A, ... : tensor<512x64xf16> x tensor<64x32xf16> -> tensor<512x32xf16>
+  }];
+
+  let dependentDialects = [
+    "mlir::triton::TritonDialect"
+  ];
+}
+
 #endif // TRITON_DIALECT_TRITON_INTEL_TRANSFORMS_PASSES
```
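The diff for FuseReshape.cpp itself is not among the files shown in this excerpt, so the exact legality conditions are not visible here. Read off the description and the tests, a rough sketch of a guard for the order = [2, 1, 0] case might look like the hypothetical helper below; test @fuseLoadWithReshape2 shows that other order permutations are handled too, folding the unit dimension into the dimension adjacent to it in memory.

```python
# Hypothetical guard (my reading of the pass description and tests; the
# authoritative logic lives in FuseReshape.cpp, whose diff is omitted
# above). Covers only the order = [2, 1, 0] layout.
def can_fuse_reshape(tile_shape, reshaped_shape, strides):
    """tile_shape: the block pointer's tensor type, e.g. [1, 512, 64]."""
    if len(tile_shape) != 3 or len(reshaped_shape) != 2:
        return False
    a, b, _c = strides
    return (
        tile_shape[0] == 1                    # reshape drops a unit outer dim
        and reshaped_shape == tile_shape[1:]  # the other extents are unchanged
        and b != 0 and a % b == 0             # %a folds evenly onto stride %b
    )

assert can_fuse_reshape([1, 512, 64], [512, 64], [1024, 4, 1])
assert not can_fuse_reshape([1, 512, 64], [512, 64], [1000, 3, 1])
```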

third_party/intel/lib/Dialect/Triton/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -1,4 +1,5 @@
 add_triton_library(TritonIntelTransforms
+  FuseReshape.cpp
   RemoveMasks.cpp
   TensorDescToBlockPointer.cpp
```