Skip to content

Commit 80255de

Browse files
[TritonGEN] Lower to GenISA for 2d_block_read_transpose_32b_8r[2|4|8]x1c (#5083)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent dcb468d commit 80255de

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,54 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
638638

639639
// -----
640640

641+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
642+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
643+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(32 : i32) : i32
644+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(2 : i32) : i32
645+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
646+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
647+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(true) : i1
648+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
649+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v1i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
650+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=2, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<1xi32>
651+
llvm.return
652+
}
653+
}
654+
655+
// -----
656+
657+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
658+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
659+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(32 : i32) : i32
660+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(4 : i32) : i32
661+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
662+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
663+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(true) : i1
664+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
665+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v2i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
666+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=4, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<2xi32>
667+
llvm.return
668+
}
669+
}
670+
671+
// -----
672+
673+
module attributes {"ttg.threads-per-warp" = 16 : i32} {
674+
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
675+
// CHECK: [[ELEM_BITS:%.*]] = llvm.mlir.constant(32 : i32) : i32
676+
// CHECK: [[TILE_WIDTH:%.*]] = llvm.mlir.constant(8 : i32) : i32
677+
// CHECK: [[TILE_HEIGHT:%.*]] = llvm.mlir.constant(8 : i32) : i32
678+
// CHECK: [[VBLOCKS:%.*]] = llvm.mlir.constant(1 : i32) : i32
679+
// CHECK: [[TRANSPOSE:%.*]] = llvm.mlir.constant(true) : i1
680+
// CHECK: [[VNNI:%.*]] = llvm.mlir.constant(false) : i1
681+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v4i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, [[ELEM_BITS]], [[TILE_WIDTH]], [[TILE_HEIGHT]], [[VBLOCKS]], [[TRANSPOSE]], [[VNNI]], {{.*}})
682+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=8, tile_height=8, v_blocks=1, transpose=true, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<4xi32>
683+
llvm.return
684+
}
685+
}
686+
687+
// -----
688+
641689
module attributes {"ttg.threads-per-warp" = 16 : i32} {
642690
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
643691
// CHECK: llvm.mlir.constant(4 : i32) : i32

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,21 @@ static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
180180
op.getTileWidth() == 8 && op.getVBlocks() == 4 && !op.getVnniTransform())
181181
return false;
182182

183+
// intel_sub_group_2d_block_read_transpose_32b_8r2x1c
184+
if (op.getElemSizeInBits() == 32 && op.getTileHeight() == 8 &&
185+
op.getTileWidth() == 2 && op.getVBlocks() == 1 && op.getTranspose())
186+
return false;
187+
188+
// intel_sub_group_2d_block_read_transpose_32b_8r4x1c
189+
if (op.getElemSizeInBits() == 32 && op.getTileHeight() == 8 &&
190+
op.getTileWidth() == 4 && op.getVBlocks() == 1 && op.getTranspose())
191+
return false;
192+
193+
// intel_sub_group_2d_block_read_transpose_32b_8r8x1c
194+
if (op.getElemSizeInBits() == 32 && op.getTileHeight() == 8 &&
195+
op.getTileWidth() == 8 && op.getVBlocks() == 1 && op.getTranspose())
196+
return false;
197+
183198
// FIXME: The SPV block load only supports subgroup size 16.
184199
int subGroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
185200
op->getParentOfType<mlir::ModuleOp>());

0 commit comments

Comments
 (0)