
Commit 013c838

Add a configurable env var to disable the large block size when loading the transposed DPAS operand B. This can improve flash attention performance because it works around a register-spilling issue.
1 parent 6da6e84 commit 013c838

3 files changed: +35 -24 lines

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_INTEL_ADVANCED_PATH",
     "TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
     "TRITON_INTEL_DO_NOT_SINK_INSTR_ACROSS_RGN",
+    "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B",
     "TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT",
     "TRITON_INTEL_ENABLE_FIRST_LOAD_TO_SLM",
     "TRITON_INTEL_ENABLE_INSTR_SCHED",

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 21 additions & 17 deletions
@@ -1,4 +1,5 @@
-// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
+// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,LARGE-BLOCK-SIZE-TRANS-B
+// RUN: TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --check-prefixes=CHECK,SMALL-BLOCK-SIZE-TRANS-B
 
 // CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
 // CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
@@ -204,22 +205,25 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
     %c0_i32 = arith.constant 0 : i32
     %c32_i64 = arith.constant 32 : i64
     %21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
-    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-    // CHECK: %[[VAL_68:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-    // CHECK: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-    // CHECK: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
-    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-    // CHECK: %[[VAL_103:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-    // CHECK: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-    // CHECK: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
-    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-    // CHECK: %[[VAL_138:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-    // CHECK: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-    // CHECK: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
-    // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
-    // CHECK: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
-    // CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
-    // CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    // COM: One DPAS operand B per load instruction.
+    // SMALL-BLOCK-SIZE-TRANS-B-COUNT-8: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    // COM: Two interleaved DPAS operands B per load instruction; the loaded value is shuffled to decompose the VNNI-format DPAS operand B.
+    // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_68:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_103:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_138:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
+    // LARGE-BLOCK-SIZE-TRANS-B: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
     %45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
     tt.return
   }
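
The LARGE-BLOCK-SIZE-TRANS-B checks show what the large path does with each block read: one 32-row transposed read returns a vector<16xi32> holding two VNNI-packed B operands interleaved lane by lane, and the pair of shufflevectors with masks [0, 2, ..., 14] and [1, 3, ..., 15] splits them apart. A minimal C++ sketch of that even/odd split (the names here are illustrative, not from the Triton code base):

    #include <array>
    #include <cstdint>

    struct SplitB {
      std::array<uint32_t, 8> even; // lanes 0, 2, 4, ... -> first operand
      std::array<uint32_t, 8> odd;  // lanes 1, 3, 5, ... -> second operand
    };

    // Mirrors the two shufflevectors: deinterleave the dword lanes of one
    // 16 x i32 block read into the two packed DPAS B operands.
    static SplitB deinterleave(const std::array<uint32_t, 16> &loaded) {
      SplitB out;
      for (int i = 0; i < 8; ++i) {
        out.even[i] = loaded[2 * i];    // mask [0, 2, 4, 6, 8, 10, 12, 14]
        out.odd[i] = loaded[2 * i + 1]; // mask [1, 3, 5, 7, 9, 11, 13, 15]
      }
      return out;
    }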

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 13 additions & 7 deletions
@@ -627,13 +627,19 @@ struct LoadOpConversion
 
         std::swap(tileHeight, tileWidth);
 
-        // We can decompose the matrix returned by transposed large 2d load
-        // when threads per warp < column size. Otherwise we have to load one
-        // operand per inst.
-        // Note: the tileHeight and numOperandsPer2DLoadM are the column size
-        // now.
-        numOperandsPer2DLoadM =
-            (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
+        if (triton::tools::getBoolEnv(
+                "TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B")) {
+          // Only load one operand per instruction along the row.
+          numOperandsPer2DLoadM = 1;
+        } else {
+          // We can decompose the matrix returned by the transposed large 2d
+          // load when threads per warp <= column size. Otherwise we have to
+          // load one operand per instruction.
+          // Note: tileHeight and numOperandsPer2DLoadM now refer to the
+          // column size.
+          numOperandsPer2DLoadM =
+              (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
+        }
         // The transposed 2d load only supports one operand per instruction on
         // the column (vBlocks = 1).
         numOperandsPer2DloadN = 1;
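
The gate trades instruction count for register pressure: the large-block path packs repCluster[rank - 1] operands into a single transposed 2D load, issuing fewer instructions but keeping more values live at once, while the small-block path issues one load per operand. Per the commit message, the extra live values can spill registers in flash attention, so disabling the large path helps there. Below is a standalone sketch of the decision, assuming the surrounding values (threadsPerWarp, tileHeight after the swap, repCluster, rank) from LoadOpConversion; the function name is hypothetical.

    #include <vector>

    // Sketch of the operand-count decision introduced by this commit.
    static int operandsPerLoadM(bool disableLargeBlockIO, int threadsPerWarp,
                                int tileHeight,
                                const std::vector<int> &repCluster, int rank) {
      if (disableLargeBlockIO)
        return 1; // One DPAS operand B per 2D load instruction on the row.
      // Large path: pack repCluster[rank - 1] operands into one transposed
      // 2D load when the warp's threads fit within the (post-swap) column.
      return (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
    }

Setting TRITON_INTEL_DISABLE_LARGE_BLOCK_SIZE_IO_FOR_TRANS_DOT_B=1 in the environment (as the second RUN line in the test does) selects the one-operand path.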
