Skip to content

Commit ac42933

Browse files
committed
To load single DPAS B matrix per 2D block io instruction from the column major matrix in memory gets better performance for flash attention.
1 parent eff7b30 commit ac42933

File tree

2 files changed

+3
-23
lines changed

2 files changed

+3
-23
lines changed

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -204,22 +204,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
204204
%c0_i32 = arith.constant 0 : i32
205205
%c32_i64 = arith.constant 32 : i64
206206
%21 = tt.make_tensor_ptr %arg0, [%c64_i64, %c64_i64], [%c1_i64, %col_stride], [%c0_i32, %c0_i32] {order = array<i32: 0, 1>} : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
207-
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
208-
// CHECK: %[[VAL_68:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
209-
// CHECK: %[[VAL_69:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
210-
// CHECK: %[[VAL_71:.*]] = llvm.shufflevector %[[VAL_68]], %[[VAL_68]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
211-
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
212-
// CHECK: %[[VAL_103:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
213-
// CHECK: %[[VAL_104:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
214-
// CHECK: %[[VAL_106:.*]] = llvm.shufflevector %[[VAL_103]], %[[VAL_103]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
215-
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
216-
// CHECK: %[[VAL_138:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
217-
// CHECK: %[[VAL_139:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
218-
// CHECK: %[[VAL_141:.*]] = llvm.shufflevector %[[VAL_138]], %[[VAL_138]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
219-
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_32r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
220-
// CHECK: %[[VAL_173:.*]] = llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
221-
// CHECK: %[[VAL_174:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [0, 2, 4, 6, 8, 10, 12, 14] : vector<16xi32>
222-
// CHECK: %[[VAL_176:.*]] = llvm.shufflevector %[[VAL_173]], %[[VAL_173]] [1, 3, 5, 7, 9, 11, 13, 15] : vector<16xi32>
207+
// CHECK-COUNT-8: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj({{.*}}, [[DEST:%.*]]) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
223208
%45 = tt.load %21 {triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
224209
tt.return
225210
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -621,13 +621,8 @@ struct LoadOpConversion
621621

622622
std::swap(tileHeight, tileWidth);
623623

624-
// We can decompose the matrix returned by transposed large 2d load
625-
// when threads per warp < column size. Otherwise we have to load one
626-
// operand per inst.
627-
// Note: the tileHeight and numOperandsPer2DLoadM are the column size
628-
// now.
629-
numOperandsPer2DLoadM =
630-
(threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
624+
// Only load 1 operand per inst on row.
625+
numOperandsPer2DLoadM = 1;
631626
// The transpose 2d load only support 1 operand per inst on column.
632627
// (vBlocks = 1)
633628
numOperandsPer2DloadN = 1;

0 commit comments

Comments
 (0)