Skip to content

Commit 9837733

Browse files
[TritonGEN] Update unsupported SPV 2D block load (#4830)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 21933fb commit 9837733

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
9797
// CHECK: llvm.func spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv(i32, i32, i32, i32, !llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
9898

9999
module attributes {"ttg.threads-per-warp" = 16 : i32} {
100-
llvm.func @triton_gen.2Dblockload2(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
100+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
101101
// CHECK: [[C32:%.*]] = llvm.mlir.constant(32 : i32) : i32
102102
// CHECK-NEXT: [[DEST:%.*]] = llvm.alloca [[C32]] x i8 : (i32) -> !llvm.ptr
103103
// CHECK-NEXT: [[PTRTOINT:%.*]] = llvm.ptrtoint %arg0 : !llvm.ptr<1> to i64
@@ -162,6 +162,22 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
162162

163163
// -----
164164

165+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
166+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
167+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(8 : i32) : i32
168+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
169+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(16 : i32) : i32
170+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(4 : i32) : i32
171+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(false) : i1
172+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
173+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i8({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
174+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=16, v_blocks=4, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi8>
175+
llvm.return
176+
}
177+
}
178+
179+
// -----
180+
165181
module attributes {"ttg.threads-per-warp" = 16 : i32} {
166182
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
167183
// CHECK-COUNT-2: llvm.mlir.constant(1 : i32) : i32

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
140140
op.getTileWidth() == 16 && op.getVBlocks() == 2 && !op.getVnniTransform())
141141
return false;
142142

143+
// intel_sub_group_2d_block_read_8b_16r8x4c
144+
if (op.getElemSizeInBits() == 8 && op.getTileHeight() == 16 &&
145+
op.getTileWidth() == 8 && op.getVBlocks() == 4 && !op.getVnniTransform())
146+
return false;
147+
143148
// intel_sub_group_2d_block_read_8b_16r16x2c
144149
if (op.getElemSizeInBits() == 8 && op.getTileHeight() == 16 &&
145150
op.getTileWidth() == 16 && op.getVBlocks() == 2 && !op.getVnniTransform())

0 commit comments

Comments
 (0)