Skip to content

Commit 9431243

Browse files
[LoadStoreOpToLLVM] Check pitch HW restriction before generating 2d block load (#4829)
Prevent the generation of 2D block loads when the pitch is known at compile time to violate the hardware restriction (the pitch must be at least 64 bytes). BMG CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/16685703178 Signed-off-by: Whitney Tsang <[email protected]>
1 parent 9837733 commit 9431243

File tree

5 files changed

+37
-33
lines changed

5 files changed

+37
-33
lines changed

python/test/unit/intel/test_block_load.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@
1616
@pytest.mark.xfail(not torch.xpu.get_device_capability()['has_subgroup_2d_block_io'],
1717
reason="Block loads not supported on this architecture")
1818
def test_block_load_dpas_layout(M, N, dtype_str, transpose, device, tmp_path: pathlib.Path):
19-
if transpose and N == 8:
20-
pytest.xfail("Pitch = 8 is not allowed by block IO")
21-
2219
# modify the layouts to ensure the correct OCL/SPIRV intrinsic is called for each datatype
2320
if dtype_str == "int8":
2421
A_width = 2

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
137137
// CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
138138
// CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
139139
// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
140-
// CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
140+
// CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_12]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
141141
// CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
142142
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
143143
// CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32
@@ -200,7 +200,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32,
200200
// CHECK: %[[OFFSET_1:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][1] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
201201
// CHECK: %[[WIDTH_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][2] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
202202
// CHECK: %[[HEIGHT_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][3] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
203-
// CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
203+
// CHECK: %[[ROW_STRIDE_i64:.*]] = llvm.extractvalue %[[VAL_11]][4] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
204204
// CHECK: %[[COL_STRIDE_i64:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][5] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
205205
// CHECK: %[[BASE:.*]] = llvm.extractvalue %[[BLOCK_POINTER]][6] : !llvm.struct<(i32, i32, i64, i64, i64, i64, ptr<1>)>
206206
// CHECK: %[[ROW_STRIDE_i32:.*]] = llvm.trunc %[[ROW_STRIDE_i64]] : i64 to i32

test/TritonIntelGPU/subgroup-2d-block-io.mlir

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
88
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
99
%0 = tt.get_program_id x : i32
1010
%M_i64 = arith.constant 16 : i64
11-
%N_i64 = arith.constant 16 : i64
11+
%N_i64 = arith.constant 64 : i64
1212
%c1_i64 = arith.constant 1 : i64
1313
%c0_i32 = arith.constant 0 : i32
1414

@@ -29,7 +29,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
2929
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
3030
%0 = tt.get_program_id x : i32
3131
%M_i64 = arith.constant 16 : i64
32-
%N_i64 = arith.constant 16 : i64
32+
%N_i64 = arith.constant 64 : i64
3333
%c1_i64 = arith.constant 1 : i64
3434
%c0_i32 = arith.constant 0 : i32
3535

@@ -50,7 +50,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
5050
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
5151
%0 = tt.get_program_id x : i32
5252
%M_i64 = arith.constant 16 : i64
53-
%N_i64 = arith.constant 16 : i64
53+
%N_i64 = arith.constant 64 : i64
5454
%c1_i64 = arith.constant 1 : i64
5555
%c0_i32 = arith.constant 0 : i32
5656

@@ -71,7 +71,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
7171
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
7272
%0 = tt.get_program_id x : i32
7373
%M_i64 = arith.constant 16 : i64
74-
%N_i64 = arith.constant 16 : i64
74+
%N_i64 = arith.constant 64 : i64
7575
%c1_i64 = arith.constant 1 : i64
7676
%c0_i32 = arith.constant 0 : i32
7777

@@ -92,7 +92,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
9292
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
9393
%0 = tt.get_program_id x : i32
9494
%M_i64 = arith.constant 32 : i64
95-
%N_i64 = arith.constant 16 : i64
95+
%N_i64 = arith.constant 64 : i64
9696
%c1_i64 = arith.constant 1 : i64
9797
%c0_i32 = arith.constant 0 : i32
9898

@@ -113,7 +113,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
113113
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
114114
%0 = tt.get_program_id x : i32
115115
%M_i64 = arith.constant 32 : i64
116-
%N_i64 = arith.constant 16 : i64
116+
%N_i64 = arith.constant 64 : i64
117117
%c1_i64 = arith.constant 1 : i64
118118
%c0_i32 = arith.constant 0 : i32
119119

@@ -134,7 +134,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
134134
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
135135
%0 = tt.get_program_id x : i32
136136
%M_i64 = arith.constant 32 : i64
137-
%N_i64 = arith.constant 16 : i64
137+
%N_i64 = arith.constant 64 : i64
138138
%c1_i64 = arith.constant 1 : i64
139139
%c0_i32 = arith.constant 0 : i32
140140

@@ -155,7 +155,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
155155
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
156156
%0 = tt.get_program_id x : i32
157157
%M_i64 = arith.constant 32 : i64
158-
%N_i64 = arith.constant 16 : i64
158+
%N_i64 = arith.constant 64 : i64
159159
%c1_i64 = arith.constant 1 : i64
160160
%c0_i32 = arith.constant 0 : i32
161161

@@ -176,7 +176,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
176176
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
177177
%0 = tt.get_program_id x : i32
178178
%M_i64 = arith.constant 64 : i64
179-
%N_i64 = arith.constant 16 : i64
179+
%N_i64 = arith.constant 64 : i64
180180
%c1_i64 = arith.constant 1 : i64
181181
%c0_i32 = arith.constant 0 : i32
182182

@@ -197,7 +197,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
197197
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
198198
%0 = tt.get_program_id x : i32
199199
%M_i64 = arith.constant 64 : i64
200-
%N_i64 = arith.constant 16 : i64
200+
%N_i64 = arith.constant 64 : i64
201201
%c1_i64 = arith.constant 1 : i64
202202
%c0_i32 = arith.constant 0 : i32
203203

@@ -218,7 +218,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
218218
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
219219
%0 = tt.get_program_id x : i32
220220
%M_i64 = arith.constant 64 : i64
221-
%N_i64 = arith.constant 16 : i64
221+
%N_i64 = arith.constant 64 : i64
222222
%c1_i64 = arith.constant 1 : i64
223223
%c0_i32 = arith.constant 0 : i32
224224

@@ -239,7 +239,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
239239
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
240240
%0 = tt.get_program_id x : i32
241241
%M_i64 = arith.constant 64 : i64
242-
%N_i64 = arith.constant 32 : i64
242+
%N_i64 = arith.constant 64 : i64
243243
%c1_i64 = arith.constant 1 : i64
244244
%c0_i32 = arith.constant 0 : i32
245245

@@ -260,7 +260,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
260260
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
261261
%0 = tt.get_program_id x : i32
262262
%M_i64 = arith.constant 64 : i64
263-
%N_i64 = arith.constant 32 : i64
263+
%N_i64 = arith.constant 64 : i64
264264
%c1_i64 = arith.constant 1 : i64
265265
%c0_i32 = arith.constant 0 : i32
266266

@@ -281,7 +281,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
281281
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
282282
%0 = tt.get_program_id x : i32
283283
%M_i64 = arith.constant 64 : i64
284-
%N_i64 = arith.constant 32 : i64
284+
%N_i64 = arith.constant 64 : i64
285285
%c1_i64 = arith.constant 1 : i64
286286
%c0_i32 = arith.constant 0 : i32
287287

@@ -302,7 +302,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
302302
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
303303
%0 = tt.get_program_id x : i32
304304
%M_i64 = arith.constant 128 : i64
305-
%N_i64 = arith.constant 32 : i64
305+
%N_i64 = arith.constant 64 : i64
306306
%c1_i64 = arith.constant 1 : i64
307307
%c0_i32 = arith.constant 0 : i32
308308

@@ -323,7 +323,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
323323
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
324324
%0 = tt.get_program_id x : i32
325325
%M_i64 = arith.constant 256 : i64
326-
%N_i64 = arith.constant 32 : i64
326+
%N_i64 = arith.constant 64 : i64
327327
%c1_i64 = arith.constant 1 : i64
328328
%c0_i32 = arith.constant 0 : i32
329329

@@ -344,7 +344,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
344344
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
345345
%0 = tt.get_program_id x : i32
346346
%M_i64 = arith.constant 256 : i64
347-
%N_i64 = arith.constant 32 : i64
347+
%N_i64 = arith.constant 64 : i64
348348
%c1_i64 = arith.constant 1 : i64
349349
%c0_i32 = arith.constant 0 : i32
350350

@@ -365,7 +365,7 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
365365
tt.func public @subgroup_2d_block_load(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16: i32}, %arg3: !tt.ptr<f16> {tt.divisibility = 16: i32}) attributes {noinline = false} {
366366
%0 = tt.get_program_id x : i32
367367
%M_i64 = arith.constant 256 : i64
368-
%N_i64 = arith.constant 32 : i64
368+
%N_i64 = arith.constant 64 : i64
369369
%c1_i64 = arith.constant 1 : i64
370370
%c0_i32 = arith.constant 0 : i32
371371

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
1515
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
16+
#include "intel/include/Utils/Utility.h"
1617
#include "triton/Tools/LinearLayout.h"
1718
#include <optional>
1819
#include <triton/Tools/Sys/GetEnv.hpp>
@@ -1540,6 +1541,21 @@ struct LoadOpToBlockIOConversion
15401541
pitch = b.trunc(i32_ty, colStride);
15411542
std::swap(baseWidth, baseHeight);
15421543
}
1544+
// HW requires the pitch to be at least 64 bytes.
1545+
std::function<Value(Value)> skipTrunc = [&](Value v) {
1546+
if (dyn_cast_or_null<LLVM::TruncOp>(v.getDefiningOp()))
1547+
return skipTrunc(v.getDefiningOp()->getOperand(0));
1548+
return v;
1549+
};
1550+
if (Operation *op = skipTrunc(pitch).getDefiningOp()) {
1551+
std::optional<int64_t> pitchConst =
1552+
mlir::triton::intel::getFoldedConstantValue(op);
1553+
if (pitchConst.has_value()) {
1554+
if ((*pitchConst * elemSizeInBits / 8) < 64)
1555+
return failure();
1556+
}
1557+
}
1558+
15431559
baseWidth = b.trunc(i32_ty, baseWidth);
15441560
baseHeight = b.trunc(i32_ty, baseHeight);
15451561

third_party/intel/lib/Utils/Utility.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,7 @@ std::optional<int64_t> getFoldedConstantValue(Operation *op) {
119119
if (results.size() != 1)
120120
return std::nullopt;
121121

122-
std::optional<int64_t> intAttr = getIntAttr(results[0]);
123-
if (intAttr.has_value())
124-
return intAttr.value();
125-
126-
auto val = cast<Value>(results[0]);
127-
auto constOp = val.getDefiningOp<arith::ConstantOp>();
128-
if (!constOp)
129-
return std::nullopt;
130-
131-
return getIntAttr(constOp.getValue());
122+
return getConstantIntValue(results[0]);
132123
}
133124

134125
bool isConstant(Value val, int64_t expected) {

0 commit comments

Comments
 (0)