Skip to content

Commit c1fb63a

Browse files
Remove tritongen lowering from blockptr_store.mlir (#4340)
The purpose of `blockptr_store.mlir` is for testing the `tt.store` lowering, and removing `--convert-tritongen-to-llvm` can simplify the lit test, the lowering of tritongen operations are covered in other lit tests. Signed-off-by: Whitney Tsang <[email protected]>
1 parent ce49a59 commit c1fb63a

File tree

1 file changed

+14
-45
lines changed

1 file changed

+14
-45
lines changed

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 14 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
1+
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
22

3-
// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return}
43
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
54
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
65
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
@@ -39,24 +38,23 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
3938
// CHECK: llvm.mul %[[VAL_1]], %[[CST_16]] : i32
4039
// CHECK: llvm.mlir.undef : vector<8xf16>
4140
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
42-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}}
41+
// CHECK: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
4342
// CHECK: llvm.mlir.undef : vector<8xf16>
4443
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
45-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}}
44+
// CHECK: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
4645
// CHECK: llvm.mlir.undef : vector<8xf16>
4746
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
48-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}}
47+
// CHECK: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
4948
// CHECK: llvm.mlir.undef : vector<8xf16>
5049
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
51-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i{{.*}}
50+
// CHECK: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
5251
tt.store %13, %12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xf16, #dpas>>
5352
tt.return
5453
}
5554
}
5655

5756
// -----
5857

59-
// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return}
6058
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
6159
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
6260
// CHECK-LABEL: llvm.func spir_kernelcc @dpas_layout_2d_store_rep_cluster_4_2(
@@ -167,7 +165,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
167165
// CHECK: %[[VAL_160:.*]] = llvm.extractvalue %[[VAL_71]][63]
168166

169167
// CHECK: %[[HEIGHT_i32:.*]] = llvm.trunc %[[HEIGHT_i64]] : i64 to i32
170-
// CHECK: %[[WIDTH_i32:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
168+
// CHECK: %[[baseHeight:.*]] = llvm.trunc %[[WIDTH_i64]] : i64 to i32
171169
// CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32
172170
// CHECK: %[[baseWidth:.*]] = llvm.mul %[[HEIGHT_i32]], %[[CST_2]] : i32
173171
// CHECK: %[[basePitch:.*]] = llvm.mul %[[ROW_STRIDE_i32]], %[[CST_2]] : i32
@@ -206,15 +204,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
206204
// CHECK: %[[VAL_197:.*]] = llvm.bitcast %[[VAL_196]] : vector<8xf16> to vector<8xi16>
207205
// CHECK: %[[VAL_198:.*]] = llvm.trunc %[[offsetY]] : i32 to i32
208206
// CHECK: %[[VAL_199:.*]] = llvm.trunc %[[offsetX]] : i32 to i32
209-
// CHECK: %[[VAL_200:.*]] = llvm.mlir.constant(8 : i32) : i32
210-
// CHECK: %[[VAL_201:.*]] = llvm.alloca %[[VAL_200]] x i16 : (i32) -> !llvm.ptr
211-
// CHECK: llvm.store %[[VAL_197]], %[[VAL_201]] : vector<8xi16>, !llvm.ptr
212-
// CHECK: %[[VAL_202:.*]] = llvm.mlir.constant(1 : i32) : i32
213-
// CHECK: %[[VAL_203:.*]] = llvm.mlir.constant(0 : i32) : i32
214-
// CHECK: %[[VAL_204:.*]] = llvm.mlir.undef : vector<2xi32>
215-
// CHECK: %[[VAL_205:.*]] = llvm.insertelement %{{.*}}, %[[VAL_204]]{{\[}}%[[VAL_203]] : i32] : vector<2xi32>
216-
// CHECK: %[[VAL_206:.*]] = llvm.insertelement %[[VAL_198]], %[[VAL_205]]{{\[}}%[[VAL_202]] : i32] : vector<2xi32>
217-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_201]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], %[[VAL_206]])
207+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], %[[VAL_199]], %[[VAL_198]], %[[VAL_197]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
218208

219209
// COM: replica [0, 1]
220210
// CHECK: %[[VAL_207:.*]] = llvm.mlir.constant(16 : i32) : i32
@@ -229,10 +219,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
229219
// CHECK: %[[VAL_223:.*]] = llvm.insertelement %[[VAL_111]], %[[VAL_221]]{{\[}}{{.*}} : i32] : vector<8xf16>
230220
// CHECK: %[[VAL_225:.*]] = llvm.insertelement %[[VAL_112]], %[[VAL_223]]{{\[}}{{.*}} : i32] : vector<8xf16>
231221
// CHECK: %[[VAL_226:.*]] = llvm.bitcast %[[VAL_225]] : vector<8xf16> to vector<8xi16>
232-
// CHECK: %[[VAL_229:.*]] = llvm.mlir.constant(8 : i32) : i32
233-
// CHECK: %[[VAL_230:.*]] = llvm.alloca %[[VAL_229]] x i16 : (i32) -> !llvm.ptr
234-
// CHECK: llvm.store %[[VAL_226]], %[[VAL_230]] : vector<8xi16>, !llvm.ptr
235-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_230]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}})
222+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_226]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
236223

237224
// COM: replica [1, 0]
238225
// CHECK: %[[VAL_236:.*]] = llvm.mlir.constant(8 : i32) : i32
@@ -249,10 +236,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
249236
// CHECK: %[[VAL_254:.*]] = llvm.insertelement %[[VAL_119]], %[[VAL_252]]{{\[}}{{.*}} : i32] : vector<8xf16>
250237
// CHECK: %[[VAL_256:.*]] = llvm.insertelement %[[VAL_120]], %[[VAL_254]]{{\[}}{{.*}} : i32] : vector<8xf16>
251238
// CHECK: %[[VAL_257:.*]] = llvm.bitcast %[[VAL_256]] : vector<8xf16> to vector<8xi16>
252-
// CHECK: %[[VAL_260:.*]] = llvm.mlir.constant(8 : i32) : i32
253-
// CHECK: %[[VAL_261:.*]] = llvm.alloca %[[VAL_260]] x i16 : (i32) -> !llvm.ptr
254-
// CHECK: llvm.store %[[VAL_257]], %[[VAL_261]] : vector<8xi16>, !llvm.ptr
255-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_261]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}})
239+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_257]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
256240

257241
// COM: replica [1, 1]
258242
// CHECK: %[[VAL_267:.*]] = llvm.mlir.constant(16 : i32) : i32
@@ -267,10 +251,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
267251
// CHECK: %[[VAL_283:.*]] = llvm.insertelement %[[VAL_127]], %[[VAL_281]]{{\[}}{{.*}} : i32] : vector<8xf16>
268252
// CHECK: %[[VAL_285:.*]] = llvm.insertelement %[[VAL_128]], %[[VAL_283]]{{\[}}{{.*}} : i32] : vector<8xf16>
269253
// CHECK: %[[VAL_286:.*]] = llvm.bitcast %[[VAL_285]] : vector<8xf16> to vector<8xi16>
270-
// CHECK: %[[VAL_289:.*]] = llvm.mlir.constant(8 : i32) : i32
271-
// CHECK: %[[VAL_290:.*]] = llvm.alloca %[[VAL_289]] x i16 : (i32) -> !llvm.ptr
272-
// CHECK: llvm.store %[[VAL_286]], %[[VAL_290]] : vector<8xi16>, !llvm.ptr
273-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_290]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}})
254+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_286]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
274255

275256
// COM: replica [2, 0]
276257
// CHECK: %[[VAL_296:.*]] = llvm.mlir.constant(16 : i32) : i32
@@ -287,10 +268,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
287268
// CHECK: %[[VAL_314:.*]] = llvm.insertelement %[[VAL_135]], %[[VAL_312]]{{\[}}{{.*}} : i32] : vector<8xf16>
288269
// CHECK: %[[VAL_316:.*]] = llvm.insertelement %[[VAL_136]], %[[VAL_314]]{{\[}}{{.*}} : i32] : vector<8xf16>
289270
// CHECK: %[[VAL_317:.*]] = llvm.bitcast %[[VAL_316]] : vector<8xf16> to vector<8xi16>
290-
// CHECK: %[[VAL_320:.*]] = llvm.mlir.constant(8 : i32) : i32
291-
// CHECK: %[[VAL_321:.*]] = llvm.alloca %[[VAL_320]] x i16 : (i32) -> !llvm.ptr
292-
// CHECK: llvm.store %[[VAL_317]], %[[VAL_321]] : vector<8xi16>, !llvm.ptr
293-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_321]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}})
271+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_317]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
294272

295273
// COM: replica [2, 1]
296274
// CHECK: %[[VAL_327:.*]] = llvm.mlir.constant(16 : i32) : i32
@@ -305,10 +283,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
305283
// CHECK: %[[VAL_343:.*]] = llvm.insertelement %[[VAL_143]], %[[VAL_341]]{{\[}}{{.*}} : i32] : vector<8xf16>
306284
// CHECK: %[[VAL_345:.*]] = llvm.insertelement %[[VAL_144]], %[[VAL_343]]{{\[}}{{.*}} : i32] : vector<8xf16>
307285
// CHECK: %[[VAL_346:.*]] = llvm.bitcast %[[VAL_345]] : vector<8xf16> to vector<8xi16>
308-
// CHECK: %[[VAL_349:.*]] = llvm.mlir.constant(8 : i32) : i32
309-
// CHECK: %[[VAL_350:.*]] = llvm.alloca %[[VAL_349]] x i16 : (i32) -> !llvm.ptr
310-
// CHECK: llvm.store %[[VAL_346]], %[[VAL_350]] : vector<8xi16>, !llvm.ptr
311-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_350]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}})
286+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_346]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
312287

313288
// COM: replica [3, 0]
314289
// CHECK: %[[VAL_356:.*]] = llvm.mlir.constant(24 : i32) : i32
@@ -325,10 +300,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
325300
// CHECK: %[[VAL_374:.*]] = llvm.insertelement %[[VAL_151]], %[[VAL_372]]{{\[}}{{.*}} : i32] : vector<8xf16>
326301
// CHECK: %[[VAL_376:.*]] = llvm.insertelement %[[VAL_152]], %[[VAL_374]]{{\[}}{{.*}} : i32] : vector<8xf16>
327302
// CHECK: %[[VAL_377:.*]] = llvm.bitcast %[[VAL_376]] : vector<8xf16> to vector<8xi16>
328-
// CHECK: %[[VAL_380:.*]] = llvm.mlir.constant(8 : i32) : i32
329-
// CHECK: %[[VAL_381:.*]] = llvm.alloca %[[VAL_380]] x i16 : (i32) -> !llvm.ptr
330-
// CHECK: llvm.store %[[VAL_377]], %[[VAL_381]] : vector<8xi16>, !llvm.ptr
331-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_381]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}})
303+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_377]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
332304

333305
// COM: replica [3, 1]
334306
// CHECK: %[[VAL_387:.*]] = llvm.mlir.constant(16 : i32) : i32
@@ -343,10 +315,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
343315
// CHECK: %[[VAL_403:.*]] = llvm.insertelement %[[VAL_159]], %[[VAL_401]]{{\[}}{{.*}} : i32] : vector<8xf16>
344316
// CHECK: %[[VAL_405:.*]] = llvm.insertelement %[[VAL_160]], %[[VAL_403]]{{\[}}{{.*}} : i32] : vector<8xf16>
345317
// CHECK: %[[VAL_406:.*]] = llvm.bitcast %[[VAL_405]] : vector<8xf16> to vector<8xi16>
346-
// CHECK: %[[VAL_409:.*]] = llvm.mlir.constant(8 : i32) : i32
347-
// CHECK: %[[VAL_410:.*]] = llvm.alloca %[[VAL_409]] x i16 : (i32) -> !llvm.ptr
348-
// CHECK: llvm.store %[[VAL_406]], %[[VAL_410]] : vector<8xi16>, !llvm.ptr
349-
// CHECK: llvm.call spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i({{.*}}, %[[VAL_410]], %[[BASE_PTR]], %{{.*}}, %[[WIDTH_i32]], %[[basePitch]], {{.*}})
318+
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_406]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
350319

351320
tt.store %13, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x32xf16, #dpas>>
352321
tt.return

0 commit comments

Comments
 (0)