Skip to content

Commit c860a38

Browse files
authored
[LoadStoreOpToLLVM] Fix issue of generating invalid 2D block IO operations with pitch < width. (#5555)
When lowering Block IO for regular pointers, the compiler may generate 2D block IO operations whose base address is not 64-byte aligned. To compensate for the offset from the unaligned base address, TritonGen may then emit code in which the pitch becomes smaller than the width, which is invalid for 2D block IO. To prevent this, we shift each tile's base address back to the block's aligned base address. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 8fdf504 commit c860a38

File tree

3 files changed

+109
-59
lines changed

3 files changed

+109
-59
lines changed

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.t
2323

2424
// ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
2525
// ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
26-
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
26+
// ALL-LAYOUT: %[[VAL_132:.*]] = llvm.mlir.constant(0 : i32) : i32
2727
// ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
2828
// ALL-LAYOUT: llvm.mlir.undef : vector<4xi8>
2929
// ALL-LAYOUT-COUNT-4: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<4xi8>
30+
// ALL-LAYOUT: %[[VAL_155:.*]] = llvm.mlir.constant(1 : i32) : i32
31+
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.udiv %[[VAL_132]], %[[VAL_155]] : i32
3032
// ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 8, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
3133
tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #dot_a>>
3234
// ALL-LAYOUT-COUNT-63: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 8, tile_width = 8, tile_height = 8, v_blocks = 1, cache_control = Default}
@@ -59,10 +61,12 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.t
5961

6062
// ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
6163
// ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
62-
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
64+
// ALL-LAYOUT: %[[VAL_132:.*]] = llvm.mlir.constant(0 : i32) : i32
6365
// ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
6466
// ALL-LAYOUT: llvm.mlir.undef : vector<8xi8>
6567
// ALL-LAYOUT-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xi8>
68+
// ALL-LAYOUT: %[[VAL_155:.*]] = llvm.mlir.constant(1 : i32) : i32
69+
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.udiv %[[VAL_132]], %[[VAL_155]] : i32
6670
// ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 8, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
6771
tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #dot_b>>
6872
// ALL-LAYOUT-COUNT-63: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 8, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
@@ -95,10 +99,12 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.t
9599

96100
// ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
97101
// ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
98-
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
102+
// ALL-LAYOUT: %[[VAL_132:.*]] = llvm.mlir.constant(0 : i32) : i32
99103
// ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
100104
// ALL-LAYOUT: llvm.mlir.undef : vector<16xi8>
101105
// ALL-LAYOUT-COUNT-16: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<16xi8>
106+
// ALL-LAYOUT: %[[VAL_155:.*]] = llvm.mlir.constant(2 : i32) : i32
107+
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.udiv %[[VAL_132]], %[[VAL_155]] : i32
102108
// ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 8, v_blocks = 1, cache_control = Default}
103109
tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #slice>>
104110
// ALL-LAYOUT-COUNT-31: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 8, v_blocks = 1, cache_control = Default}
@@ -130,10 +136,12 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 16 : i32, "ttg.t
130136

131137
// ALL-LAYOUT: %[[OFFSET:.*]] = llvm.add %[[OFF_0]], {{.*}} : i32
132138
// ALL-LAYOUT: %[[BASE:.*]] = llvm.getelementptr %[[BASE_PTR]]{{.*}} : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
133-
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
139+
// ALL-LAYOUT: %[[VAL_132:.*]] = llvm.mlir.constant(0 : i32) : i32
134140
// ALL-LAYOUT: %[[OFFSET_Y:.*]] = llvm.select {{.*}}, %[[OFFSET]], %[[HEIGHT]] : i1, i32
135141
// ALL-LAYOUT: llvm.mlir.undef : vector<8xi8>
136142
// ALL-LAYOUT-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xi8>
143+
// ALL-LAYOUT: %[[VAL_155:.*]] = llvm.mlir.constant(2 : i32) : i32
144+
// ALL-LAYOUT: %[[OFFSET_X:.*]] = llvm.udiv %[[VAL_132]], %[[VAL_155]] : i32
137145
// ALL-LAYOUT: triton_gen.2Dblockstore {{.*}}, %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 4, v_blocks = 1, cache_control = Default}
138146
tt.store %0, %cst {ttig.block_io = "row_major", boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<256x64xi8, #blocked>>
139147
// ALL-LAYOUT-COUNT-7: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 32, tile_height = 4, v_blocks = 1, cache_control = Default}
@@ -217,7 +225,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
217225
// COM: When boundary check is absent:
218226
// CHECK: %[[baseWidth:.*]] = llvm.mlir.constant(64 : i32)
219227
// CHECK: %[[base1:.*]] = llvm.getelementptr %[[base]][%[[OFFSET_X]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, f16
220-
// CHECK: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
228+
// CHECK: %[[VAL_132:.*]] = llvm.mlir.constant(0 : i32) : i32
221229
// CHECK: %[[baseHeight:.*]] = llvm.mlir.constant(8 : i32)
222230
// CHECK: %[[OFF:.*]] = llvm.mul %[[OFFSET_Y]], %[[PITCH]] : i32
223231
// CHECK: %[[base:.*]] = llvm.getelementptr %[[base1]][%[[OFF]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
@@ -227,7 +235,8 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
227235
// CHECK-COUNT-7: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
228236
// CHECK: %[[VAL0:.*]] = llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
229237
// CHECK: %[[VAL:.*]] = llvm.bitcast %[[VAL0]] : vector<8xf16> to vector<8xi16>
230-
238+
// CHECK: %[[VAL_155:.*]] = llvm.mlir.constant(1 : i32) : i32
239+
// CHECK: %[[OFFSET_X:.*]] = llvm.udiv %[[VAL_132]], %[[VAL_155]] : i32
231240
// CHECK: triton_gen.2Dblockstore %[[base]], %[[baseWidth]], %[[baseHeight]], %[[PITCH]], %[[OFFSET_X]], %[[OFFSET_Y]], %[[VAL]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
232241
// CHECK-COUNT-3: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
233242

@@ -304,10 +313,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
304313
// CHECK: %[[VAL_182:.*]] = llvm.xor %[[VAL_168]], %[[VAL_181]] : i32
305314
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_182]] : i32
306315
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_177]] : i32
307-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
308-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
309316
// CHECK: llvm.mlir.undef : vector<8xf16>
310317
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
318+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
319+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
311320
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
312321

313322
// COM: replica [0, 1]
@@ -331,10 +340,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
331340
// CHECK: %[[VAL_223:.*]] = llvm.xor %[[VAL_209]], %[[VAL_222]] : i32
332341
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_223]] : i32
333342
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_218]] : i32
334-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
335-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
336343
// CHECK: llvm.mlir.undef : vector<8xf16>
337344
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
345+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
346+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
338347
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
339348

340349
// COM: replica [1, 0]
@@ -358,10 +367,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
358367
// CHECK: %[[VAL_264:.*]] = llvm.xor %[[VAL_249]], %[[VAL_263]] : i32
359368
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_264]] : i32
360369
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_259]] : i32
361-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
362-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
363370
// CHECK: llvm.mlir.undef : vector<8xf16>
364371
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
372+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
373+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
365374
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
366375

367376
// COM: replica [1, 1]
@@ -385,10 +394,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
385394
// CHECK: %[[VAL_306:.*]] = llvm.xor %[[VAL_292]], %[[VAL_305]] : i32
386395
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_306]] : i32
387396
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_301]] : i32
388-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
389-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
390397
// CHECK: llvm.mlir.undef : vector<8xf16>
391398
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
399+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
400+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
392401
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
393402

394403
// COM: replica [2, 0]
@@ -412,10 +421,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
412421
// CHECK: %[[VAL_347:.*]] = llvm.xor %[[VAL_332]], %[[VAL_346]] : i32
413422
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_347]] : i32
414423
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_342]] : i32
415-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
416-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
417424
// CHECK: llvm.mlir.undef : vector<8xf16>
418425
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
426+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
427+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
419428
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
420429

421430
// COM: replica [2, 1]
@@ -439,10 +448,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
439448
// CHECK: %[[VAL_389:.*]] = llvm.xor %[[VAL_375]], %[[VAL_388]] : i32
440449
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_389]] : i32
441450
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_384]] : i32
442-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
443-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
444451
// CHECK: llvm.mlir.undef : vector<8xf16>
445452
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
453+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
454+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
446455
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
447456

448457
// COM: replica [3, 0]
@@ -466,10 +475,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
466475
// CHECK: %[[VAL_430:.*]] = llvm.xor %[[VAL_415]], %[[VAL_429]] : i32
467476
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_430]] : i32
468477
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_425]] : i32
469-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
470-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
471478
// CHECK: llvm.mlir.undef : vector<8xf16>
472479
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
480+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
481+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
473482
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
474483

475484
// COM: replica [3, 1]
@@ -493,10 +502,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
493502
// CHECK: %[[VAL_472:.*]] = llvm.xor %[[VAL_458]], %[[VAL_471]] : i32
494503
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_472]] : i32
495504
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_467]] : i32
496-
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
497-
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
498505
// CHECK: llvm.mlir.undef : vector<8xf16>
499506
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
507+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
508+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
500509
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
501510

502511
tt.store %13, %cst {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x32xf16, #dpas>>

0 commit comments

Comments
 (0)