Skip to content

Commit b361546

Browse files
chengjunluetiottowhitneywhtsang
authored
[TritonGEN]: Use GenISA block store if the sub group size is not equal to 16. (#4743)
The SPV block store interface only support sub group size = 16. Use GenISA block store if the sub group size is not equal to 16. --------- Signed-off-by: Lu,Chengjun <[email protected]> Signed-off-by: Tiotto, Ettore <[email protected]> Co-authored-by: Tiotto, Ettore <[email protected]> Co-authored-by: Whitney Tsang <[email protected]>
1 parent 6b65fa3 commit b361546

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// RUN: env TRITON_INTEL_ADVANCED_PATH=1 triton-opt %s --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm --split-input-file | FileCheck %s
22

3-
module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 1 : i32} {
3+
module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32} {
44
// CHECK-DAG: llvm.func spir_funccc @_Z45__spirv_SubgroupMatrixMultiplyAccumulateINTELiDv8_sDv8_iDv8_fi(i32, vector<8xi16>, vector<8xi32>, vector<8xf32>, i32) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
55
// CHECK-DAG: llvm.func spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv(i32, i32, i32, i32, !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
66
// CHECK-DAG: llvm.func spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv(i32, i32, i32, i32, !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}

test/TritonGEN/tritongen-2Dblockstore-to-llvm.mlir

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
22

3+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
34
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
45
// CHECK: [[ONE0:%.*]] = llvm.mlir.constant(1 : i32) : i32
56
// CHECK: [[PTR:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
@@ -25,11 +26,11 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
2526
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
2627
llvm.return
2728
}
29+
}
2830

2931
// -----
3032

31-
// CHECK: llvm.func spir_funccc @_Z33__spirv_Subgroup2DBlockStoreINTELiiiiPvPU3AS1viiiDv2_i(i32, i32, i32, i32, !llvm.ptr {llvm.nonnull, llvm.readonly}, !llvm.ptr<1> {llvm.nonnull, llvm.writeonly}, i32, i32, i32, vector<2xi32>) attributes {no_unwind, will_return}
32-
33+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
3334
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
3435
// CHECK: llvm.func @triton_gen.2Dblockstore(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: vector<8xi8>) {
3536
// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
@@ -61,7 +62,7 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
6162
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, cache_control=L1UC_L3UC} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi8>)
6263
llvm.return
6364
}
64-
65+
}
6566
// -----
6667

6768
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
@@ -78,6 +79,7 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
7879

7980
// -----
8081

82+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
8183
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
8284
// CHECK-COUNT-2: llvm.mlir.constant(1 : i32) : i32
8385
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -88,6 +90,7 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
8890
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
8991
llvm.return
9092
}
93+
}
9194

9295
// -----
9396

@@ -133,6 +136,7 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
133136

134137
// -----
135138

139+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
136140
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi16>) {
137141
// CHECK: llvm.mlir.constant(2 : i32) : i32
138142
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(2 : i32) : i32
@@ -143,6 +147,7 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
143147
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=16, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi16>)
144148
llvm.return
145149
}
150+
}
146151

147152
// -----
148153

@@ -188,6 +193,7 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
188193

189194
// -----
190195

196+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
191197
llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi32>) {
192198
// CHECK: llvm.mlir.constant(4 : i32) : i32
193199
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(4 : i32) : i32
@@ -198,3 +204,4 @@ llvm.func @triton_gen.2Dblockstore(%ptr : !llvm.ptr<1>, %base_width : i32, %base
198204
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=32, tile_width=16, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
199205
llvm.return
200206
}
207+
}

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -181,16 +181,10 @@ static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockStoreOp op) {
181181
op.getTileWidth() == 8 && op.getVBlocks() == 1)
182182
return false;
183183

184-
// FIXME: The following signatures have correctness issue with SPV interface.
185-
186-
// intel_sub_group_2d_block_write_8b_1r32x1c
187-
if (op.getElemSizeInBits() == 8 && op.getTileHeight() == 1 &&
188-
op.getTileWidth() == 32 && op.getVBlocks() == 1)
189-
return false;
190-
191-
// intel_sub_group_2d_block_write_16b_2r16x1c
192-
if (op.getElemSizeInBits() == 16 && op.getTileHeight() == 2 &&
193-
op.getTileWidth() == 16 && op.getVBlocks() == 1)
184+
// FIXME: The SPV block store only support subgroup size 16.
185+
int subGroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
186+
op->getParentOfType<mlir::ModuleOp>());
187+
if (subGroupSize != 16)
194188
return false;
195189

196190
return true;

0 commit comments

Comments
 (0)