Skip to content

Commit 6a95135

Browse files
authored
Fix issue in block store lowering when packed types are used. (#4936)
This PR fixes an issue in the block store lowering process when dealing with packed types in the Intel GPU backend. The fix addresses a mismatch between how offsets are calculated for linear layouts and what 2D block I/O operations expect.
---------
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent ad0e8d9 commit 6a95135

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
292292
// CHECK: %[[VAL_189:.*]] = llvm.xor %[[VAL_185]], %[[VAL_188]] : i32
293293
// CHECK: %[[VAL_190:.*]] = llvm.mlir.constant(0 : i32) : i32
294294
// CHECK: %[[VAL_191:.*]] = llvm.xor %[[VAL_185]], %[[VAL_190]] : i32
295-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_191]] : i32
295+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_191]] : i32
296296
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_189]] : i32
297+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
298+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
297299
// CHECK: llvm.mlir.undef : vector<8xf16>
298300
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
299301
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
@@ -309,8 +311,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
309311
// CHECK: %[[VAL_216:.*]] = llvm.xor %[[VAL_210]], %[[VAL_215]] : i32
310312
// CHECK: %[[VAL_217:.*]] = llvm.mlir.constant(0 : i32) : i32
311313
// CHECK: %[[VAL_218:.*]] = llvm.xor %[[VAL_211]], %[[VAL_217]] : i32
312-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_218]] : i32
314+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_218]] : i32
313315
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_216]] : i32
316+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
317+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
314318
// CHECK: llvm.mlir.undef : vector<8xf16>
315319
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
316320
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
@@ -326,8 +330,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
326330
// CHECK: %[[VAL_240:.*]] = llvm.xor %[[VAL_236]], %[[VAL_239]] : i32
327331
// CHECK: %[[VAL_241:.*]] = llvm.mlir.constant(0 : i32) : i32
328332
// CHECK: %[[VAL_242:.*]] = llvm.xor %[[VAL_235]], %[[VAL_241]] : i32
329-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_242]] : i32
333+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_242]] : i32
330334
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_240]] : i32
335+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
336+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
331337
// CHECK: llvm.mlir.undef : vector<8xf16>
332338
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
333339
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
@@ -343,8 +349,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
343349
// CHECK: %[[VAL_266:.*]] = llvm.xor %[[VAL_261]], %[[VAL_265]] : i32
344350
// CHECK: %[[VAL_267:.*]] = llvm.mlir.constant(0 : i32) : i32
345351
// CHECK: %[[VAL_268:.*]] = llvm.xor %[[VAL_262]], %[[VAL_267]] : i32
346-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_268]] : i32
352+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_268]] : i32
347353
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_266]] : i32
354+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
355+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
348356
// CHECK: llvm.mlir.undef : vector<8xf16>
349357
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
350358
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
@@ -360,8 +368,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
360368
// CHECK: %[[VAL_291:.*]] = llvm.xor %[[VAL_287]], %[[VAL_290]] : i32
361369
// CHECK: %[[VAL_292:.*]] = llvm.mlir.constant(0 : i32) : i32
362370
// CHECK: %[[VAL_293:.*]] = llvm.xor %[[VAL_286]], %[[VAL_292]] : i32
363-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_293]] : i32
371+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_293]] : i32
364372
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_291]] : i32
373+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
374+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
365375
// CHECK: llvm.mlir.undef : vector<8xf16>
366376
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
367377
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
@@ -377,8 +387,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
377387
// CHECK: %[[VAL_317:.*]] = llvm.xor %[[VAL_312]], %[[VAL_316]] : i32
378388
// CHECK: %[[VAL_318:.*]] = llvm.mlir.constant(0 : i32) : i32
379389
// CHECK: %[[VAL_319:.*]] = llvm.xor %[[VAL_313]], %[[VAL_318]] : i32
380-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_319]] : i32
390+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_319]] : i32
381391
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_317]] : i32
392+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
393+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
382394
// CHECK: llvm.mlir.undef : vector<8xf16>
383395
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
384396
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
@@ -394,8 +406,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
394406
// CHECK: %[[VAL_342:.*]] = llvm.xor %[[VAL_338]], %[[VAL_341]] : i32
395407
// CHECK: %[[VAL_343:.*]] = llvm.mlir.constant(0 : i32) : i32
396408
// CHECK: %[[VAL_344:.*]] = llvm.xor %[[VAL_337]], %[[VAL_343]] : i32
397-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_344]] : i32
409+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_344]] : i32
398410
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_342]] : i32
411+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
412+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
399413
// CHECK: llvm.mlir.undef : vector<8xf16>
400414
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
401415
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
@@ -411,8 +425,10 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
411425
// CHECK: %[[VAL_368:.*]] = llvm.xor %[[VAL_363]], %[[VAL_367]] : i32
412426
// CHECK: %[[VAL_369:.*]] = llvm.mlir.constant(0 : i32) : i32
413427
// CHECK: %[[VAL_370:.*]] = llvm.xor %[[VAL_364]], %[[VAL_369]] : i32
414-
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[OFF_1]], %[[VAL_370]] : i32
428+
// CHECK: %[[ADD:.*]] = llvm.add %[[OFF_1]], %[[VAL_370]] : i32
415429
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[OFF_0]], %[[VAL_368]] : i32
430+
// CHECK: %[[NUM_PACKED_VALS:.*]] = llvm.mlir.constant(1 : i32) : i32
431+
// CHECK-NEXT: %[[OFFSET_X:.*]] = llvm.udiv %[[ADD]], %[[NUM_PACKED_VALS]] : i32
416432
// CHECK: llvm.mlir.undef : vector<8xf16>
417433
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
418434
// CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[WIDTH_IN_BYTES]], %[[HEIGHT]], %[[ROW_STRIDE_IN_BYTES]], %[[OFFSET_X]], %[[OFFSET_Y]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2901,6 +2901,12 @@ struct StoreOpToBlockIOConversion
29012901
// The offsetX is number of elements instead of packed elements.
29022902
addrElem = b.gep(ptr_ty(ctx, 1), eltTy, addrElem, offsetX);
29032903
offsetX = b.i32_val(0);
2904+
} else {
2905+
assert(numPackedVals > 0 &&
2906+
"numPackedVals should be greater than zero.");
2907+
// The offsetX of linear layout is in original elements.
2908+
// The 2d block io requires the offsetX in number of packed elements.
2909+
offsetX = b.udiv(offsetX, b.i32_val(numPackedVals));
29042910
}
29052911
if (!boundaryCheck.contains(rowDim)) {
29062912
baseHeight = b.i32_val(tileHeight);

0 commit comments

Comments
 (0)