Commit 26af8c1

[LoadStoreOpToLLVM] Check SupportSG2DBlockAttr for BlockIO lowering (#4647)
Signed-off-by: Whitney Tsang <[email protected]>
Parent: bd694d9

3 files changed: +30 -20 lines

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 8 additions & 8 deletions
@@ -5,7 +5,7 @@
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   tt.func public @matmul_no_scf_with_advance_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg7: i64) {
     %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
     %c0_i32 = arith.constant 0 : i32
@@ -29,7 +29,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   tt.func public @matmul_no_scf_with_add_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
     %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
     %c0_i32 = arith.constant 0 : i32
@@ -57,7 +57,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   tt.func public @matmul_no_scf_with_add_transpose_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
     %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
     %c0_i32 = arith.constant 0 : i32
@@ -83,7 +83,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=1}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   tt.func public @matmul_no_scf_with_advance_kernel(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i64, %arg3: i64, %arg4: i64, %arg5: i64, %arg7: i64) {
     %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
     %c0_i32 = arith.constant 0 : i32
@@ -105,7 +105,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   // CHECK-LABEL: llvm.func spir_kernelcc @dot_op_a_2d_load(
   // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<1>,
   // CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i64, %[[VAL_4:.*]]: i64, %[[PTR_1:.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
@@ -168,7 +168,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   // CHECK-LABEL: llvm.func spir_kernelcc @dot_op_b_2d_load(
   // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<1>,
   // CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, %[[VAL_3:.*]]: i64, %[[PTR_1:.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
@@ -230,7 +230,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}

 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 2], A = [8, 16], B = [16, 32], C = [8, 32]}>
 #dot_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
-module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   // CHECK-LABEL: llvm.func spir_kernelcc @column_major_dot_b
   tt.func public @column_major_dot_b(%arg0: !tt.ptr<f16>, %col_stride: i64) {
     %c64_i32 = arith.constant 64 : i32
@@ -263,7 +263,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32}

 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #dot_b = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>
-module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   // CHECK-LABEL: llvm.func spir_kernelcc @column_major_dot_b
   tt.func public @column_major_dot_b(%arg0: !tt.ptr<f16>, %col_stride: i64) {
     %c64_i64 = arith.constant 64 : i64

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 4 additions & 4 deletions
@@ -3,7 +3,7 @@
 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   tt.func public @matmul_no_scf_with_advance_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64) {
     %cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
     %c32_i32 = arith.constant 32 : i32
@@ -48,15 +48,15 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
     // CHECK: llvm.mlir.undef : vector<8xf16>
     // CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
     // CHECK: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
-    tt.store %13, %12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x64xf16, #dpas>>
+    tt.store %13, %12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #dpas>>
     tt.return
   }
 }

 // -----

 #dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
-module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
   // CHECK-LABEL: llvm.func spir_kernelcc @dpas_layout_2d_store_rep_cluster_4_2(
   // CHECK-SAME: %[[base:.*]]: !llvm.ptr<1>,
   // CHECK-SAME: %[[width:.*]]: i64, %[[height:.*]]: i64, %[[rowStride:.*]]: i64, %[[PTR_1:.*]]: !llvm.ptr<1>) attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>} {
@@ -315,7 +315,7 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
     // CHECK: %[[VAL_406:.*]] = llvm.bitcast %[[VAL_405]] : vector<8xf16> to vector<8xi16>
     // CHECK: triton_gen.2Dblockstore %[[BASE_PTR]], %[[baseWidth]], %[[baseHeight]], %[[basePitch]], {{.*}}, %[[VAL_406]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}

-    tt.store %13, %cst {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x32xf16, #dpas>>
+    tt.store %13, %cst {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x32xf16, #dpas>>
     tt.return
   }
 }

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 18 additions & 8 deletions
@@ -292,16 +292,26 @@ struct BlockIOConversionBase : public LoadStoreConversionBase {
       const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
       : LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}

-  // Determine whether the given LoadOp can be lowered to using block IO
+  // Determine whether the given operation can be lowered to using block IO
   // instructions.
-  static bool isLoadCandidate(triton::LoadOp op) {
+  template <typename OpTy,
+            std::enable_if_t<
+                llvm::is_one_of<OpTy, triton::LoadOp, triton::StoreOp>::value,
+                bool> = true>
+  static bool isBlockIOCandidate(OpTy op) {
+    ModuleOp mod = op->template getParentOfType<ModuleOp>();
+    if (!mod->hasAttr(triton::gpu::intel::TritonIntelGPUDialect::
+                          getSupportSG2DBlockAttrName()))
+      return false;
+
     Attribute blockIOAttr =
         op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
     if (!blockIOAttr)
       return false;

-    // Only lower loadOp with dpas layout encoding.
-    auto tensorTy = cast<RankedTensorType>(op.getType());
+    // Only lower operation with dpas layout encoding.
+    auto tensorTy =
+        cast<RankedTensorType>(getPointeeType(op.getPtr().getType()));
     return hasDpasEncoding(tensorTy) || hasDotDpasEncoding(tensorTy);
   }

@@ -773,9 +783,6 @@ struct LoadOpToBlockIOConversion
     assert(isTensorPointerType(ptr.getType()) &&
            "Expecting tensor pointer type");

-    if (!isLoadCandidate(op))
-      return failure();
-
     Location loc = op.getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     Value mask = op.getMask();
@@ -1565,7 +1572,7 @@
   LogicalResult
   matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    if (!isLoadCandidate(op))
+    if (!isBlockIOCandidate(op))
       return failure();

     // 2D block io lowering steps:
@@ -2573,6 +2580,9 @@ struct StoreOpToBlockIOConversion
   LogicalResult
   matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
+    if (!isBlockIOCandidate(op))
+      return failure();
+
     if (isTensorPointerType(op.getPtr().getType()))
       return rewriteTensorPointerStore(op, adaptor, rewriter);
     return failure();
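
The core of the change is the new isBlockIOCandidate helper: block-IO lowering of both tt.load and tt.store now bails out early unless the enclosing module carries the ttig.support_sg_2d_block attribute, and only then consults the per-op ttig.block_io attribute. The standalone C++17 sketch below illustrates that pattern (a SFINAE-constrained helper restricted to exactly two op types, with the module-level capability check first). It deliberately avoids the MLIR APIs: MockModule, MockLoadOp, and MockStoreOp are hypothetical stand-ins, is_one_of mimics llvm::is_one_of, and the DPAS-layout check from the real helper is omitted.

#include <iostream>
#include <set>
#include <string>
#include <type_traits>

// Stand-in for llvm::is_one_of: true iff T is exactly one of Ts.
template <typename T, typename... Ts>
struct is_one_of : std::disjunction<std::is_same<T, Ts>...> {};

// Hypothetical mocks for the MLIR types the real helper touches.
struct MockModule {
  std::set<std::string> attrs;
  bool hasAttr(const std::string &name) const { return attrs.count(name) > 0; }
};
struct MockLoadOp {
  const MockModule *mod; // the op's enclosing module
  bool hasBlockIOAttr;   // whether the op carries ttig.block_io
};
struct MockStoreOp {
  const MockModule *mod;
  bool hasBlockIOAttr;
};

// Same shape as the new isBlockIOCandidate: the template only instantiates
// for the two supported op types, and the module-level capability check runs
// before any per-op check. (The real helper additionally requires a DPAS
// layout on the pointee tensor type; that part is omitted here.)
template <typename OpTy,
          std::enable_if_t<is_one_of<OpTy, MockLoadOp, MockStoreOp>::value,
                           bool> = true>
bool isBlockIOCandidate(const OpTy &op) {
  // No hardware support for subgroup 2D block io: never lower to block io.
  if (!op.mod->hasAttr("ttig.support_sg_2d_block"))
    return false;
  // Per-op opt-in: the op itself must be tagged for block io.
  return op.hasBlockIOAttr;
}

int main() {
  const MockModule withSupport{{"ttig.support_sg_2d_block"}};
  const MockModule noSupport{};
  std::cout << isBlockIOCandidate(MockLoadOp{&withSupport, true})  // 1
            << isBlockIOCandidate(MockStoreOp{&noSupport, true})   // 0
            << isBlockIOCandidate(MockLoadOp{&withSupport, false}) // 0
            << '\n';
  return 0;
}

In the actual pass the same guard now runs at the top of both matchAndRewrite overloads, which is why the tests above add the module attribute (and, for stores, the ttig.block_io = "row_major" hint) to keep exercising the 2D block-IO paths.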
