Commit 458d741

[LoadStoreOpToLLVM] Improve the 2D block IO lowering for DPAS and DotOp layout.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent da48f1a commit 458d741

2 files changed: +122 additions, -4 deletions

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 51 additions & 2 deletions
@@ -369,7 +369,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
   tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f16>) {

     %a_mask = arith.constant dense<true> : tensor<256x64xi1, #mma>
-    %a_other = arith.constant dense<0.00e+00> : tensor<256x64xf16, #mma>
+    %a_other = arith.constant dense<1.00e+00> : tensor<256x64xf16, #mma>
     // CHECK-NOT: llvm.cond_br

     %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #mma}>>
@@ -388,7 +388,8 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[TOP_LEFT_MASK_BOOL_32:.*]] = llvm.extractvalue {{.*}}[32] : !llvm.struct<(i1, i1, {{.*}}
     // CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}}
     // CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}}
-
+    // CHECK: %[[VAL_2878:.*]] = llvm.extractvalue {{.*}}[126] : !llvm.struct<(f16, f16, {{.*}}
+    // CHECK: %[[VAL_2879:.*]] = llvm.extractvalue {{.*}}[127] : !llvm.struct<(f16, f16, {{.*}}

     // CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
     // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
@@ -402,6 +403,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
     // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>

     // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
     // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -414,6 +427,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
     // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>

     // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
     // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -426,6 +451,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
     // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>

     // CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
     // CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -438,6 +475,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
     // CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
     // CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
     // CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
+    // CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
+    // CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
+    // CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
     %11 = tt.load %10, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>

     tt.return
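
The added CHECK lines above all exercise the same pattern: one triton_gen.2Dblockload (16x16 tile, v_blocks = 2, f16 elements) yields vector<32xi16> per lane, which llvm.shufflevector decomposes into four 8-element slices; each slice is bitcast to vector<8xf16> and selected against the `other` value when the load predicate is false. The following plain C++ sketch is only a lane-local model of that behavior for illustration; the names (LanePayload, OperandSlice, decomposeLoad, kF16One) are invented here and do not appear in the pass.

    #include <array>
    #include <cstdint>

    // One lane's payload from a 16x16, v_blocks = 2, f16 block load: 512
    // elements per 16-lane sub-group, i.e. 32 i16 values per lane.
    using LanePayload = std::array<uint16_t, 32>;
    // One decomposed slice handed to the DPAS-layout packing: 8 elements.
    using OperandSlice = std::array<uint16_t, 8>;

    // f16 bit pattern of 1.0, matching the `other` constant in the test.
    constexpr uint16_t kF16One = 0x3C00;

    // Mirrors the four shufflevector/bitcast/select triples checked above:
    // slice s takes indices [8*s, 8*s + 7]; if the predicate is false, every
    // element of the slice is replaced by the `other` value.
    inline std::array<OperandSlice, 4> decomposeLoad(const LanePayload &loaded,
                                                     bool pred) {
      std::array<OperandSlice, 4> slices{};
      for (unsigned s = 0; s < 4; ++s)
        for (unsigned i = 0; i < 8; ++i)
          slices[s][i] = pred ? loaded[8 * s + i] : kF16One;
      return slices;
    }

Applying the select to the already decomposed vector<8xf16> slices, rather than to the raw load result, lines up with the pack/unpack folding described in the comment added to LoadStoreOpToLLVM.cpp below.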

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 71 additions & 2 deletions
@@ -2682,10 +2682,66 @@ struct LoadOpToBlockIOConversion

     bool useVNNIFormat = false;
     Type packedDPASOperandType;
-    if (hasDotDpasEncoding(tensorType)) {
+    if (hasDpasEncoding(tensorType) || hasDotDpasEncoding(tensorType)) {
+
+      // For the DPAS layout, there are three types involved in the block
+      // load. (For non-DPAS layouts, only two types are involved.)
+      // 1. load2DGenXType - the raw vector type returned by the 2D block
+      //    load.
+      // 2. packedDPASOperandType - the packed vector type of a single DPAS
+      //    operand. (This is null for non-DPAS layouts.)
+      // 3. unpackedType - the vector type of the unpacked elements.
+      //
+      // clang-format off
+      // The `tt.load` operation generates the following block load sequence:
+      //   %0 = load_2d %ptr : <load2DGenXType>
+      //   %1 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
+      //          <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+      //   %2 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
+      //          <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+      //   %3 = bitcast %1 : <packedDPASOperandType> -> <unpackedType>
+      //   %4 = bitcast %2 : <packedDPASOperandType> -> <unpackedType>
+      //   <operations for packLLElements>
+      // clang-format on
+      //
+      // The `tt.dot` operation generates the DPAS instruction sequence:
+      // clang-format off
+      //   <operations for unpackLLElements>
+      //   %5 = bitcast %3 : <unpackedType> -> <packedDPASOperandType>
+      //   %6 = bitcast %4 : <unpackedType> -> <packedDPASOperandType>
+      //   %7 = dpas %5, %6, %other : <packedDPASOperandType>, <packedDPASOperandType>, <packedDPASOperandType>
+      // clang-format on
+      //
+      // The LLVM optimizer eliminates the redundant pack/unpack element
+      // pairs and the corresponding bitcast operations. The final optimized
+      // IR for the dot product becomes:
+      //
+      // clang-format off
+      //   %0 = load_2d %ptr : <load2DGenXType>
+      //   %1 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
+      //          <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
+      //   %2 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
+      //          <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+      //   %3 = dpas %1, %2, %other : <packedDPASOperandType>, <packedDPASOperandType>, <packedDPASOperandType>
+      // clang-format on
+      //
+      // The `packedDPASOperandType`, together with the `shufflevector`
+      // operations, defines the computation flow for the dot product.
+
       DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
       auto dpasLayout = getDpasLayout(tensorType);
-      if (opIdx == DpasEncodingAttr::OpIdx::OperandB) {
+      switch (opIdx) {
+      case DpasEncodingAttr::OpIdx::OperandA: {
+        unsigned elemsPerLanePerDPASInst =
+            product<unsigned>(dpasLayout.getDPASInstShapeA()) / threadsPerWarp;
+        // Block 2D contains at least one DotOp A.
+        if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
+          packedDPASOperandType = LLVM::getVectorType(
+              packedType, elemsPerLanePerDPASInst / numPackedVals);
+          unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
+        }
+      } break;
+      case DpasEncodingAttr::OpIdx::OperandB: {
+        assert(numPackedVals == 1 &&
+               "invalid number of packed values for DPAS operand B.");
         unsigned elemsPerLanePerDPASInst =
             product<unsigned>(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
         // Block 2D contains at least one DotOp B.
@@ -2709,6 +2765,19 @@ struct LoadOpToBlockIOConversion
           }
           unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
         }
+      } break;
+      case DpasEncodingAttr::OpIdx::OperandC: {
+        unsigned elemsPerLanePerDPASInst =
+            product<unsigned>(dpasLayout.getDPASInstShapeC()) / threadsPerWarp;
+        // Block 2D contains at least one DotOp C.
+        if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
+          packedDPASOperandType = LLVM::getVectorType(
+              packedType, elemsPerLanePerDPASInst / numPackedVals);
+          unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
+        }
+      } break;
+      default:
+        llvm_unreachable("unexpected OpIdx type.");
       }
     }
     SmallVector<Value> unpackedLoadedVals(numElems);
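
All three cases of the new switch share the same arithmetic: the per-lane element count comes from the operand's DPAS instruction shape divided by the sub-group size, and the widths of packedDPASOperandType and unpackedType follow from it. Below is a minimal standalone C++ sketch of that arithmetic only, not the pass code; the 16x16 tile, 16-lane sub-group, and packing factor of 2 are hypothetical example values, not values queried from a real DpasEncodingAttr.

    #include <cassert>
    #include <cstdio>

    // Per-lane element count: the operand's DPAS instruction shape split
    // evenly across the lanes of the sub-group (mirrors
    // product(getDPASInstShape*) / threadsPerWarp in the pass).
    unsigned elemsPerLanePerDPASInst(unsigned rows, unsigned cols,
                                     unsigned threadsPerWarp) {
      return rows * cols / threadsPerWarp;
    }

    int main() {
      const unsigned threadsPerWarp = 16; // hypothetical sub-group size
      const unsigned numPackedVals = 2;   // hypothetical packing factor

      // Hypothetical 16x16 operand tile -> 16 elements per lane.
      unsigned perLane = elemsPerLanePerDPASInst(16, 16, threadsPerWarp);
      assert(perLane % numPackedVals == 0);

      // packedDPASOperandType would be <perLane/numPackedVals x packedType>,
      // unpackedType would be <perLane x eltTy>.
      std::printf("packed width: %u, unpacked width: %u\n",
                  perLane / numPackedVals, perLane);
      return 0;
    }

The guard numElemsPerLoad >= elemsPerLanePerDPASInst in each case enables the packed path only when one 2D block load covers at least one whole DPAS operand per lane; otherwise packedDPASOperandType stays null and the plain two-type lowering described in the comment above is used.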
