Skip to content

Commit f4495f7

Browse files
[TritonGEN] Reduce GenISA usage (#4294)
The SPV variant `16b_16r16x2c` is available, so there is no need to use `GenISA` for it. Benchmark CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15216714545 (no geomean regression). Flex Decoding CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/15218710307 (no regression). --------- Signed-off-by: Whitney Tsang <[email protected]>
1 parent 9e6f975 commit f4495f7

File tree

3 files changed

+45
-16
lines changed

3 files changed

+45
-16
lines changed

test/TritonGEN/tritongen-2Dblockload-to-llvm.mlir

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,13 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
338338
// -----
339339

340340
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
341-
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32i16({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg5, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) -> vector<32xi16>
341+
// CHECK: llvm.mlir.constant(2 : i32) : i32
342+
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(2 : i32) : i32
343+
// CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32
344+
// CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(16 : i32) : i32
345+
// CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(2 : i32) : i32
346+
// CHECK-NEXT: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], %arg0, [[ADD_0]], %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
347+
// CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<32xi16>
342348
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=2, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<32xi16>
343349
llvm.return
344350
}
@@ -469,7 +475,13 @@ llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_
469475
// -----
470476

471477
llvm.func @triton_gen.2Dblockload(%ptr : !llvm.ptr<1>, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
472-
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v16i32({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %arg5, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) -> vector<16xi32>
478+
// CHECK: llvm.mlir.constant(2 : i32) : i32
479+
// CHECK: [[ElemSize:%.*]] = llvm.mlir.constant(2 : i32) : i32
480+
// CHECK-NEXT: [[TileWidth:%.*]] = llvm.mlir.constant(16 : i32) : i32
481+
// CHECK-NEXT: [[TileHeight:%.*]] = llvm.mlir.constant(16 : i32) : i32
482+
// CHECK-NEXT: [[VBlocks:%.*]] = llvm.mlir.constant(2 : i32) : i32
483+
// CHECK-NEXT: llvm.call spir_funccc @_Z41__spirv_Subgroup2DBlockLoadTransformINTELiiiiPU3AS1viiiDv2_iPv([[ElemSize]], [[TileWidth]], [[TileHeight]], [[VBlocks]], %arg0, [[ADD_0]], %arg2, %arg3, {{.*}}, [[DEST:%.*]]) {{.*}} : (i32, i32, i32, i32, !llvm.ptr<1>{{.*}}, i32, i32, i32, vector<2xi32>, !llvm.ptr{{.*}}) -> ()
484+
// CHECK-NEXT: llvm.load [[DEST]] : !llvm.ptr -> vector<16xi32>
473485
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=16, tile_height=16, v_blocks=2, transpose=false, vnni_transform=true, cache_control=Default} : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi32>
474486
llvm.return
475487
}

test/TritonIntelGPU/tensor-pointer-load-block-2d.mlir

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,14 @@ module attributes {ttig.support_sg_2d_block, "ttg.num-warps" = 8 : i32} {
158158
// CHECK: [[C8:%.*]] = llvm.mlir.constant(8 : i32) : i32
159159
// CHECK: [[C16:%.*]] = llvm.mlir.constant(16 : i32) : i32
160160

161-
// CHECK-COUNT-4: llvm.call spir_funccc @llvm.genx.GenISA.LSC2DBlockRead.v32f16
161+
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
162+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
163+
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
164+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
165+
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
166+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
167+
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32
168+
// CHECK: llvm.call spir_funccc @_Z32__spirv_Subgroup2DBlockLoadINTELiiiiPU3AS1viiiDv2_iPv([[C2]], [[C16]], [[C16]], [[C2]], {{.*}})
162169
%0 = tt.load %arg0 {ttig.block_io = "row_major"} : tensor<256x64x!tt.ptr<f16>, #mma>
163170

164171
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : i32

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,28 @@ loadCacheControlToCacheControls(Builder &builder,
110110
return builder.getAttr<TritonGEN::DecorationCacheControlAttr>(decorations);
111111
}
112112

113-
static bool isOCLBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
114-
// The following signature is not valid in OCL interface.
115-
// _Z42intel_sub_group_2d_block_read_16b_16r16x2cPU3AS1viiiDv2_iPDh
116-
if (op.getElemSizeInBits() == 16 && op.getTileHeight() == 16 &&
117-
op.getTileWidth() == 16 && op.getVBlocks() == 2) {
118-
return false;
119-
}
120-
121-
if (op.getElemSizeInBits() == 8 && op.getTileWidth() == 16 &&
122-
op.getVBlocks() != 4 && !op.getVnniTransform()) {
123-
// TODO: add ocl builtin/spirv intrinsics for 8b 16 column 1 vBlock & 2
124-
// vBlock reads
113+
static bool isSPVBuiltinAvailable(TritonGEN::Matrix2DBlockLoadOp op) {
114+
// FIXME: The following signatures are not valid in SPV interface.
115+
// intel_sub_group_2d_block_read_8b_32r16x1c
116+
// intel_sub_group_2d_block_read_8b_32r16x2c
117+
// intel_sub_group_2d_block_read_8b_16r16x2c
118+
// intel_sub_group_2d_block_read_8b_8r16x1c
119+
// intel_sub_group_2d_block_read_8b_8r16x2c
120+
if ((op.getElemSizeInBits() == 8 && op.getTileHeight() == 32 &&
121+
op.getTileWidth() == 16 && op.getVBlocks() == 1 &&
122+
!op.getVnniTransform()) ||
123+
(op.getElemSizeInBits() == 8 && op.getTileHeight() == 32 &&
124+
op.getTileWidth() == 16 && op.getVBlocks() == 2 &&
125+
!op.getVnniTransform()) ||
126+
(op.getElemSizeInBits() == 8 && op.getTileHeight() == 16 &&
127+
op.getTileWidth() == 16 && op.getVBlocks() == 2 &&
128+
!op.getVnniTransform()) ||
129+
(op.getElemSizeInBits() == 8 && op.getTileHeight() == 8 &&
130+
op.getTileWidth() == 16 && op.getVBlocks() == 1 &&
131+
!op.getVnniTransform()) ||
132+
(op.getElemSizeInBits() == 8 && op.getTileHeight() == 8 &&
133+
op.getTileWidth() == 16 && op.getVBlocks() == 2 &&
134+
!op.getVnniTransform())) {
125135
return false;
126136
}
127137

@@ -490,7 +500,7 @@ struct TritonMatrix2DBlockLoadLowering
490500
LogicalResult
491501
matchAndRewrite(TritonGEN::Matrix2DBlockLoadOp op, OpAdaptor adaptor,
492502
ConversionPatternRewriter &rewriter) const override {
493-
if (!isOCLBuiltinAvailable(op)) {
503+
if (!isSPVBuiltinAvailable(op)) {
494504
// Fallback to GenISA interface.
495505
rewriter.replaceOp(op, createGenISA2DBlockRead(op, rewriter));
496506
return success();

0 commit comments

Comments
 (0)