Skip to content

Commit dfc5a66

Browse files
authored
[LoadStoreOpToLLVM] Improve the 2D block IO lowering for DPAS and DotOp layout. (#5425)
In the DPAS layout, three data types are involved in the block load and dot product (DPAS) computation flow: 1. load2DGenXType – represents the raw data loaded from memory. 2. packedDPASOperandType – the packed form of data used as DPAS operands. 3. unpackedType – the unpacked form used for intermediate transformations. For non-DPAS layouts, only the first two types are used. The data flow proceeds as follows: 1. A 2D block load operation fetches data into `load2DGenXType` values. 2. Vector shuffles reorganize the loaded data into DPAS operand fragments. 3. Bitcasts convert between packed and unpacked representations to prepare data for computation. 4. The `tt.dot` (DPAS) operation consumes packed operands to perform the dot product. During optimization, redundant pack/unpack and bitcast operations are removed, resulting in a simplified sequence: - A single block load (load_2d) - Shuffle operations defining operand layout - A DPAS instruction consuming packed operands Conceptually, the combination of `packedDPASOperandType` and `shufflevector` determines how input data maps to the DPAS computation flow. Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 462a07e commit dfc5a66

File tree

2 files changed

+118
-4
lines changed

2 files changed

+118
-4
lines changed

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
369369
tt.func public @regular_pointer_block_io(%arg0: !tt.ptr<f16>) {
370370

371371
%a_mask = arith.constant dense<true> : tensor<256x64xi1, #mma>
372-
%a_other = arith.constant dense<0.00e+00> : tensor<256x64xf16, #mma>
372+
%a_other = arith.constant dense<1.00e+00> : tensor<256x64xf16, #mma>
373373
// CHECK-NOT: llvm.cond_br
374374

375375
%0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #mma}>>
@@ -389,7 +389,6 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
389389
// CHECK: %[[TOP_LEFT_MASK_BOOL_64:.*]] = llvm.extractvalue {{.*}}[64] : !llvm.struct<(i1, i1, {{.*}}
390390
// CHECK: %[[TOP_LEFT_MASK_BOOL_96:.*]] = llvm.extractvalue {{.*}}[96] : !llvm.struct<(i1, i1, {{.*}}
391391

392-
393392
// CHECK: %[[BLOCK_SHAPE_Y:.*]] = llvm.mlir.constant(16 : i32) : i32
394393
// CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
395394
// CHECK: %[[VAL_2886:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -402,6 +401,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
402401
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
403402
// CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
404403
// CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
404+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
405+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
406+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
407+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
408+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
409+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
410+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
411+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
412+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
413+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
414+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
415+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
405416

406417
// CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
407418
// CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -414,6 +425,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
414425
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
415426
// CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
416427
// CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
428+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
429+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
430+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
431+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
432+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
433+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
434+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
435+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
436+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
437+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
438+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
439+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
417440

418441
// CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
419442
// CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -426,6 +449,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
426449
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
427450
// CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
428451
// CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
452+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
453+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
454+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
455+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
456+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
457+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
458+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
459+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
460+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
461+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
462+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
463+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
429464

430465
// CHECK: %[[TOP_LEFT_PTR:.*]] = llvm.ptrtoint {{.*}} : !llvm.ptr<1> to i64
431466
// CHECK: %[[VAL_3046:.*]] = llvm.call spir_funccc @_Z17sub_group_shufflelj(%[[TOP_LEFT_PTR]], {{.*}}) {convergent, no_unwind, will_return} : (i64, i32) -> i64
@@ -438,6 +473,18 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32, "ttg.thr
438473
// CHECK: %[[PRED_BOOL:.*]] = llvm.trunc %[[PRED]] : i8 to i1
439474
// CHECK: %[[BASE_Y_0:.*]] = llvm.select %[[PRED_BOOL]], %[[CST0_1]], %[[BLOCK_SHAPE_Y]] : i1, i32
440475
// CHECK: %[[LOAD_0:.*]] = triton_gen.2Dblockload {{.*}}, %[[BASE_Y_0]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 16, v_blocks = 2
476+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<32xi16>
477+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
478+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
479+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [8, 9, 10, 11, 12, 13, 14, 15] : vector<32xi16>
480+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
481+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
482+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [16, 17, 18, 19, 20, 21, 22, 23] : vector<32xi16>
483+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
484+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
485+
// CHECK: %[[DECOMPOSED_DATA:.*]] = llvm.shufflevector %[[LOAD_0]], %[[LOAD_0]] [24, 25, 26, 27, 28, 29, 30, 31] : vector<32xi16>
486+
// CHECK-NEXT: %[[UNPACKED_TYPE:.*]] = llvm.bitcast %[[DECOMPOSED_DATA]] : vector<8xi16> to vector<8xf16>
487+
// CHECK: llvm.select %[[PRED_BOOL]], %[[UNPACKED_TYPE]], {{.*}} : i1, vector<8xf16>
441488
%11 = tt.load %10, %a_mask, %a_other {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
442489

443490
tt.return

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2682,10 +2682,66 @@ struct LoadOpToBlockIOConversion
26822682

26832683
bool useVNNIFormat = false;
26842684
Type packedDPASOperandType;
2685-
if (hasDotDpasEncoding(tensorType)) {
2685+
if (hasDpasEncoding(tensorType) || hasDotDpasEncoding(tensorType)) {
2686+
2687+
// For the DPAS layout, there are three types of block loads used.
2688+
// (For non-DPAS layouts, only two types are involved.)
2689+
// 1. load2DGenXType – the raw data type loaded from memory.
2690+
// 2. packedDPASOperandType – the packed form used as DPAS operands. (This is null for non-DPAS layouts.)
2691+
// 3. unpackedType – the unpacked form used for intermediate transformations.
2692+
//
2693+
// clang-format off
2694+
// The `tt.load` operation generates the following block load sequence:
2695+
// %0 = load_2d %ptr : <load2DGenXType>
2696+
// %1 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
2697+
// <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
2698+
// %2 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
2699+
// <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
2700+
// %3 = bitcast %1 : <packedDPASOperandType> -> <unpackedType>
2701+
// %4 = bitcast %2 : <packedDPASOperandType> -> <unpackedType>
2702+
// <operations for packLLElements>
2703+
// clang-format on
2704+
//
2705+
// The `tt.dot` operation generates the DPAS instruction sequence:
2706+
// clang-format off
2707+
// <operations for unpackLLElements>
2708+
// %5 = bitcast %3 : <unpackedType> -> <packedDPASOperandType>
2709+
// %6 = bitcast %4 : <unpackedType> -> <packedDPASOperandType>
2710+
// %7 = dpas %5, %6, %other : <packedDPASOperandType>, <packedDPASOperandType>, <packedDPASOperandType>
2711+
// clang-format on
2712+
//
2713+
// The LLVM optimizer eliminates redundant pack/unpack element pairs
2714+
// and corresponding bitcast operations. The final optimized IR for
2715+
// the dot product becomes:
2716+
//
2717+
// clang-format off
2718+
// %0 = load_2d %ptr : <load2DGenXType>
2719+
// %1 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
2720+
// <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
2721+
// %2 = shufflevector <load2DGenXType> %0, <load2DGenXType> %0,
2722+
// <8 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
2723+
// %3 = dpas %1, %2, %other : <packedDPASOperandType>, <packedDPASOperandType>, <packedDPASOperandType>
2724+
// clang-format on
2725+
//
2726+
// The `packedDPASOperandType` together with the `shufflevector`
2727+
// operations defines the computation flow for the dot product.
2728+
26862729
DpasEncodingAttr::OpIdx opIdx = getOpIdx(tensorType);
26872730
auto dpasLayout = getDpasLayout(tensorType);
2688-
if (opIdx == DpasEncodingAttr::OpIdx::OperandB) {
2731+
switch (opIdx) {
2732+
case DpasEncodingAttr::OpIdx::OperandA: {
2733+
unsigned elemsPerLanePerDPASInst =
2734+
product<unsigned>(dpasLayout.getDPASInstShapeA()) / threadsPerWarp;
2735+
// The 2D block load contains at least one DotOp A operand.
2736+
if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
2737+
packedDPASOperandType = LLVM::getVectorType(
2738+
packedType, elemsPerLanePerDPASInst / numPackedVals);
2739+
unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
2740+
}
2741+
} break;
2742+
case DpasEncodingAttr::OpIdx::OperandB: {
2743+
assert(numPackedVals == 1 &&
2744+
"invalid number of packed values for DPAS operand B.");
26892745
unsigned elemsPerLanePerDPASInst =
26902746
product<unsigned>(dpasLayout.getDPASInstShapeB()) / threadsPerWarp;
26912747
// The 2D block load contains at least one DotOp B operand.
@@ -2709,6 +2765,17 @@ struct LoadOpToBlockIOConversion
27092765
}
27102766
unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
27112767
}
2768+
} break;
2769+
case DpasEncodingAttr::OpIdx::OperandC: {
2770+
unsigned elemsPerLanePerDPASInst =
2771+
product<unsigned>(dpasLayout.getDPASInstShapeC()) / threadsPerWarp;
2772+
// The 2D block load contains at least one DotOp C operand.
2773+
if (numElemsPerLoad >= elemsPerLanePerDPASInst) {
2774+
packedDPASOperandType = LLVM::getVectorType(
2775+
packedType, elemsPerLanePerDPASInst / numPackedVals);
2776+
unpackedType = LLVM::getVectorType(eltTy, elemsPerLanePerDPASInst);
2777+
}
2778+
} break;
27122779
}
27132780
}
27142781
SmallVector<Value> unpackedLoadedVals(numElems);

0 commit comments

Comments
 (0)