Skip to content

Commit e423617

Browse files
[LoadStoreOpToLLVM] Add support for boundary checks (#4701)
Prior to this PR, `StoreOpToBlockIOConversion` did not account for boundary checks, assuming they were always provided. This PR modifies `StoreOpToBlockIOConversion` to properly incorporate boundary checks. To avoid triggering hardware boundary protection, the implementation expands the base shape (and rebases the offset to zero) in each dimension for which no boundary check is provided. Signed-off-by: Whitney Tsang <[email protected]>
1 parent 3088449 commit e423617

File tree

2 files changed

+75
-10
lines changed

2 files changed

+75
-10
lines changed

test/TritonIntelGPU/blockptr_store.mlir

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
// RUN: triton-opt %s -split-input-file --convert-triton-intel-gpu-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm
22

33
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
4-
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
5-
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
64
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
7-
tt.func public @matmul_no_scf_with_advance_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg7: i64) {
5+
tt.func public @matmul_no_scf_with_advance_kernel(%base: !tt.ptr<f16>, %width: i64, %height: i64, %rowStride: i64) {
86
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #dpas>
9-
%c32_i32 = arith.constant 32 : i32
10-
%c-64_i32 = arith.constant -64 : i32
11-
%c-32_i32 = arith.constant -32 : i32
12-
%c64_i32 = arith.constant 64 : i32
137
%c0_i32 = arith.constant 0 : i32
148
%c1_i64 = arith.constant 1 : i64
15-
%13 = tt.make_tensor_ptr %arg2, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>>
9+
%0 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>>
1610
// CHECK: %[[WARP_ID:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
1711
// CHECK: %[[offsetBaseY:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
1812
// CHECK: %[[offsetBaseX:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
@@ -42,7 +36,58 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32,
4236
// CHECK: llvm.mlir.undef : vector<8xf16>
4337
// CHECK-COUNT-8: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
4438
// CHECK: triton_gen.2Dblockstore {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[OFFSET_X]], {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
45-
tt.store %13, %cst {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #dpas>>
39+
tt.store %0, %cst {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #dpas>>
40+
tt.return
41+
}
42+
}
43+
44+
// -----
45+
46+
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
47+
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
48+
tt.func public @no_boundary_check(%base: !tt.ptr<f16>, %width: i64, %height: i64, %rowStride: i64) {
49+
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #dpas>
50+
%c0_i32 = arith.constant 0 : i32
51+
%c1_i64 = arith.constant 1 : i64
52+
%0 = tt.make_tensor_ptr %base, [%width, %height], [%rowStride, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #dpas>>
53+
54+
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i32) : i32
55+
// CHECK: %[[WARP_ID:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() {no_unwind, will_return} : () -> i32
56+
57+
// CHECK: %[[offsetBaseY:.*]] = llvm.extractvalue {{.*}}[0] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
58+
// CHECK: %[[offsetBaseX:.*]] = llvm.extractvalue {{.*}}[1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
59+
// CHECK: %[[baseHeight:.*]] = llvm.extractvalue {{.*}}[2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
60+
// CHECK: %[[baseWidth:.*]] = llvm.extractvalue {{.*}}[3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
61+
// CHECK: %[[rowStride:.*]] = llvm.extractvalue {{.*}}[4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
62+
// CHECK: %[[colStride:.*]] = llvm.extractvalue {{.*}}[5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
63+
// CHECK: %[[base:.*]] = llvm.extractvalue {{.*}}[6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
64+
65+
// CHECK: %[[rowStride_i32:.*]] = llvm.trunc %[[rowStride]] : i64 to i32
66+
// CHECK: %[[PITCH:.*]] = llvm.mul %[[rowStride_i32]], %[[C2]]
67+
// CHECK-COUNT-32: llvm.extractvalue {{.*}} : !llvm.struct<(f16, f16, {{.*}})>
68+
69+
// COM: Skip the register, lane, warp and block to the offset computation which should be covered by the LL tests.
70+
// CHECK: %[[OFFSET_X:.*]] = llvm.add %[[offsetBaseX]], {{.*}} : i32
71+
// CHECK: %[[OFFSET_Y:.*]] = llvm.add %[[offsetBaseY]], {{.*}} : i32
72+
73+
// COM: When boundary check is absent:
74+
// CHECK: %[[baseWidth:.*]] = llvm.mlir.constant(64 : i32)
75+
// CHECK: %[[base1:.*]] = llvm.getelementptr %[[base]][%[[OFFSET_X]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i16
76+
// CHECK: %[[OFFSET_X:.*]] = llvm.mlir.constant(0 : i32) : i32
77+
// CHECK: %[[baseHeight:.*]] = llvm.mlir.constant(8 : i32)
78+
// CHECK: %[[OFF:.*]] = llvm.mul %[[OFFSET_Y]], %[[PITCH]] : i32
79+
// CHECK: %[[base:.*]] = llvm.getelementptr %[[base1]][%[[OFF]]] : (!llvm.ptr<1>, i32) -> !llvm.ptr<1>, i8
80+
// CHECK: %[[OFFSET_Y:.*]] = llvm.mlir.constant(0 : i32) : i32
81+
82+
// CHECK: llvm.mlir.undef : vector<8xf16>
83+
// CHECK-COUNT-7: llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
84+
// CHECK: %[[VAL0:.*]] = llvm.insertelement %{{[0-9]+}}, %{{[0-9]+}}{{\[}}{{.*}} : i32] : vector<8xf16>
85+
// CHECK: %[[VAL:.*]] = llvm.bitcast %[[VAL0]] : vector<8xf16> to vector<8xi16>
86+
87+
// CHECK: triton_gen.2Dblockstore %[[base]], %[[baseWidth]], %[[baseHeight]], %[[PITCH]], %[[OFFSET_X]], %[[OFFSET_Y]], %[[VAL]] {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
88+
// CHECK-COUNT-3: triton_gen.2Dblockstore {{.*}} {elem_size_in_bits = 16, tile_width = 16, tile_height = 8, v_blocks = 1, cache_control = Default}
89+
90+
tt.store %0, %cst {ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #dpas>>
4691
tt.return
4792
}
4893
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2628,6 +2628,7 @@ struct StoreOpToBlockIOConversion
26282628

26292629
width = b.trunc(i32_ty, width);
26302630
rowStride = b.trunc(i32_ty, rowStride);
2631+
Value addrElem = base;
26312632
// encoded as bytes.
26322633
Value baseWidth = b.mul(width, elemSizeInBytes);
26332634
Value baseHeight = b.trunc(i32_ty, height);
@@ -2674,6 +2675,25 @@ struct StoreOpToBlockIOConversion
26742675
Value offsetX = b.add(offsetBaseX, offsets[colDim].second);
26752676
Value offsetY = b.add(offsetBaseY, offsets[rowDim].second);
26762677

2678+
// To prevent triggering hardware boundary protection, expand the base
2679+
// shape sufficiently when boundary check is absent.
2680+
SetVector<unsigned> boundaryCheck(op.getBoundaryCheck().begin(),
2681+
op.getBoundaryCheck().end());
2682+
if (!boundaryCheck.contains(colDim)) {
2683+
baseWidth = b.i32_val(
2684+
std::max(64u, vBlocks * tileWidth * (elemSizeInBits / 8)));
2685+
// Use opaqueType as offsetX is in number of elements.
2686+
addrElem = b.gep(ptr_ty(ctx, 1), opaqueType, addrElem, offsetX);
2687+
offsetX = b.i32_val(0);
2688+
}
2689+
if (!boundaryCheck.contains(rowDim)) {
2690+
baseHeight = b.i32_val(tileHeight);
2691+
// Use i8_ty as pitch is in number of bytes.
2692+
Value off = b.mul(offsetY, pitch);
2693+
addrElem = b.gep(ptr_ty(ctx, 1), i8_ty, addrElem, off);
2694+
offsetY = b.i32_val(0);
2695+
}
2696+
26772697
// Compose the matrix by stacking the name into vector.
26782698
Value storeVal = rewriter.create<LLVM::UndefOp>(
26792699
loc,
@@ -2684,7 +2704,7 @@ struct StoreOpToBlockIOConversion
26842704

26852705
auto newOp = rewriter.create<TritonGEN::Matrix2DBlockStoreOp>(
26862706
loc,
2687-
/*ptr*/ base,
2707+
/*ptr*/ addrElem,
26882708
/*base_width*/ baseWidth,
26892709
/*base_height*/ baseHeight,
26902710
/*base_pitch*/ pitch,

0 commit comments

Comments
 (0)