
Commit ee755e8

Load a single DPAS B matrix instead of two per 2D block IO instruction from transposed memory (#2628)

Loading a single DPAS B matrix per 2D block IO instruction from a column-major matrix in memory yields better performance for flash attention. Unlike loads from a row-major matrix, a value returned by a single 2D transposed block IO that packs more than one DPAS B operand cannot be used as DPAS operands directly: the value has to be shuffled in registers before it is passed to the DPAS instruction, and IGC does not currently optimize this shuffle away.
1 parent 9952acf commit ee755e8
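To make that shuffle concrete, here is a minimal, self-contained C++ sketch (illustrative names, not the actual lowering code) of what the large-block path has to do: one transposed 2D block read returns a register vector in which two VNNI-packed DPAS B operands are interleaved, and an even/odd shuffle pair is needed to separate them.

```cpp
#include <array>
#include <cstdint>

// Illustrative only: model one lane's result of a large transposed 2D block
// read carrying two DPAS B operands interleaved in VNNI order. The two loops
// mirror the llvm.shufflevector masks [0,2,...,14] and [1,3,...,15] that the
// lowering emits before the values can feed the DPAS instruction.
using Loaded16 = std::array<uint32_t, 16>; // raw result of the block read
using OperandB = std::array<uint32_t, 8>;  // one DPAS B operand per lane

void splitInterleavedOperands(const Loaded16 &loaded, OperandB &b0,
                              OperandB &b1) {
  for (int i = 0; i < 8; ++i) {
    b0[i] = loaded[2 * i];     // even elements -> first operand
    b1[i] = loaded[2 * i + 1]; // odd elements  -> second operand
  }
}
```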

File tree

3 files changed (+35, -24 lines)

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ADVANCED_PATH",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
+    "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B",
     "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 21 additions & 17 deletions
@@ -1,4 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
+// RUN: TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
 
 // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
 // CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
@@ -204,22 +205,25 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   %c0_i32 = arith.constant 0 : i32
   %c32_i64 = arith.constant 32 : i64
   %21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
-  // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-  // CHECK: %[[VAL_68:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-  // CHECK: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-  // CHECK: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
-  // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-  // CHECK: %[[VAL_103:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-  // CHECK: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-  // CHECK: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
-  // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-  // CHECK: %[[VAL_138:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-  // CHECK: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-  // CHECK: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
-  // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-  // CHECK: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-  // CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-  // CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+  // COM: One DPAS operand B per load instruction.
+  // SMALL-BLOCK-SIZE-TRANS-B-COUNT-8: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+  // COM: Two interleaved DPAS operand B per load instruction. Need to shuffle the loaded value to decompose the VNNI format DPAS operand B.
+  // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_68:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_103:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_138:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+  // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
   %45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
   tt.return
 }
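Both RUN lines compile the same kernel; only the expected lowering differs. With the flag set, FileCheck counts eight 32b_16r8x1c transposed reads, each delivering exactly one DPAS B operand with no shuffles. On the default large-block path, the same 64x32 tile is covered by four 32b_32r8x1c reads, each followed by the even/odd shufflevector pair that splits the 16 loaded dwords into two operands. Either way the tile yields eight DPAS B operands; the difference is whether IGC has to see the extra register shuffles.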

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 13 additions & 7 deletions
@@ -627,13 +627,19 @@ struct LoadOpConversion
 
         std::swap(tileHeight, tileWidth);
 
-        // We can decompose the matrix returned by transposed large 2d load
-        // when threads per warp < column size. Otherwise we have to load one
-        // operand per inst.
-        // Note: the tileHeight and numOperandsPer2DLoadM are the column size
-        // now.
-        numOperandsPer2DLoadM =
-            (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
+        if (triton::tools::getBoolEnv(
+                "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B")) {
+          // Only load 1 operand per inst on row.
+          numOperandsPer2DLoadM = 1;
+        } else {
+          // We can decompose the matrix returned by transposed large 2d load
+          // when threads per warp < column size. Otherwise we have to load one
+          // operand per inst.
+          // Note: the tileHeight and numOperandsPer2DLoadM are the column size
+          // now.
+          numOperandsPer2DLoadM =
+              (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
+        }
         // The transpose 2d load only support 1 operand per inst on column.
         // (vBlocks = 1)
         numOperandsPer2DloadN = 1;
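The decision above can be modeled in isolation as follows: a minimal sketch with a hypothetical free function and invented parameter names (the real logic lives inside LoadOpConversion and uses triton::tools::getBoolEnv). The env flag forces one DPAS B operand per load, while the default path keeps the larger block load whenever the warp's lanes fit within the loaded column.

```cpp
#include <cstdlib>
#include <cstring>

// Hypothetical stand-in for triton::tools::getBoolEnv.
static bool boolEnv(const char *name) {
  const char *v = std::getenv(name);
  return v != nullptr && std::strcmp(v, "0") != 0;
}

// Sketch of the numOperandsPer2DLoadM decision. After the
// tileHeight/tileWidth swap, tileHeight is the column size of the
// transposed load.
int operandsPerLoadM(int threadsPerWarp, int tileHeight,
                     int repClusterOuter) {
  if (boolEnv("TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B"))
    return 1; // only one DPAS B operand per 2D block load on the row
  // Decompose the larger load only when the warp fits within the column;
  // otherwise fall back to one operand per instruction.
  return (threadsPerWarp <= tileHeight) ? repClusterOuter : 1;
}
```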
