Skip to content

Commit f744146

Browse files
[TritonGEN] Verify 2DBLOCKAddressPayload restriction on 2D block IO (#4526)
Ensure the restrictions below are satisfied when they are constant values: - Only 24 bits are supported for surface width/height/pitch field. Bits [31:24] are ignored by the hardware. - Surface width/pitch (encoded_value + 1) must be equal or greater than 64B. Signed-off-by: Whitney Tsang <[email protected]>
1 parent a29f867 commit f744146

File tree

3 files changed

+161
-9
lines changed

3 files changed

+161
-9
lines changed

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 141 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,54 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
209209

210210
// -----
211211

212+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
213+
%base_width = llvm.mlir.constant(16777216 : i32) : i32
214+
// expected-error @+1 {{'triton_gen.2Dblockload' op 2nd operand (base width) should be <= 24 bits}}
215+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
216+
llvm.return
217+
}
218+
219+
// -----
220+
221+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
222+
%base_width = llvm.mlir.constant(0 : i32) : i32
223+
// expected-error @+1 {{'triton_gen.2Dblockload' op 2nd operand (base width) should be >= 64}}
224+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
225+
llvm.return
226+
}
227+
228+
// -----
229+
230+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_pitch : i32, %x : i32, %y : i32) {
231+
%base_height = llvm.mlir.constant(16777216 : i32) : i32
232+
// expected-error @+1 {{'triton_gen.2Dblockload' op 3rd operand (base height) should be <= 24 bits}}
233+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
234+
llvm.return
235+
}
236+
237+
// -----
238+
239+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32) {
240+
%base_pitch = llvm.mlir.constant(16777216 : i32) : i32
241+
// expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be <= 24 bits}}
242+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
243+
llvm.return
244+
}
245+
246+
// -----
247+
248+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32) {
249+
%base_pitch = llvm.mlir.constant(0 : i32) : i32
250+
// expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be >= 64}}
251+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
252+
llvm.return
253+
}
254+
255+
// -----
256+
212257
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32) {
213-
%base_width = llvm.mlir.constant(4 : i32) : i32
214-
%base_pitch = llvm.mlir.constant(2 : i32) : i32
258+
%base_width = llvm.mlir.constant(68 : i32) : i32
259+
%base_pitch = llvm.mlir.constant(64 : i32) : i32
215260
// expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be >= 2nd operand (base width)}}
216261
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
217262
llvm.return
@@ -273,11 +318,56 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height
273318
llvm.return
274319
}
275320

321+
322+
// -----
323+
324+
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
325+
%base_width = llvm.mlir.constant(16777216 : i32) : i32
326+
// expected-error @+1 {{'triton_gen.2Dblockstore' op 2nd operand (base width) should be <= 24 bits}}
327+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
328+
llvm.return
329+
}
330+
331+
// -----
332+
333+
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
334+
%base_width = llvm.mlir.constant(0 : i32) : i32
335+
// expected-error @+1 {{'triton_gen.2Dblockstore' op 2nd operand (base width) should be >= 64}}
336+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
337+
llvm.return
338+
}
339+
340+
// -----
341+
342+
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
343+
%base_height = llvm.mlir.constant(16777216 : i32) : i32
344+
// expected-error @+1 {{'triton_gen.2Dblockstore' op 3rd operand (base height) should be <= 24 bits}}
345+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
346+
llvm.return
347+
}
348+
349+
// -----
350+
351+
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
352+
%base_pitch = llvm.mlir.constant(16777216 : i32) : i32
353+
// expected-error @+1 {{'triton_gen.2Dblockstore' op 4th operand (base pitch) should be <= 24 bits}}
354+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
355+
llvm.return
356+
}
357+
358+
// -----
359+
360+
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
361+
%base_pitch = llvm.mlir.constant(0 : i32) : i32
362+
// expected-error @+1 {{'triton_gen.2Dblockstore' op 4th operand (base pitch) should be >= 64}}
363+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
364+
llvm.return
365+
}
276366
// -----
277367

278368
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
279-
%base_width = llvm.mlir.constant(4 : i32) : i32
280-
%base_pitch = llvm.mlir.constant(2 : i32) : i32
369+
%base_width = llvm.mlir.constant(68 : i32) : i32
370+
%base_pitch = llvm.mlir.constant(64 : i32) : i32
281371
// expected-error @+1 {{'triton_gen.2Dblockstore' op 4th operand (base pitch) should be >= 2nd operand (base width)}}
282372
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
283373
llvm.return
@@ -325,9 +415,54 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height
325415

326416
// -----
327417

418+
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
419+
%base_width = llvm.mlir.constant(16777216 : i32) : i32
420+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 2nd operand (base width) should be <= 24 bits}}
421+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
422+
llvm.return
423+
}
424+
425+
// -----
426+
427+
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
428+
%base_width = llvm.mlir.constant(0 : i32) : i32
429+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 2nd operand (base width) should be >= 64}}
430+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
431+
llvm.return
432+
}
433+
434+
// -----
435+
436+
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_pitch : i32, %x : i32, %y : i32) {
437+
%base_height = llvm.mlir.constant(16777216 : i32) : i32
438+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 3rd operand (base height) should be <= 24 bits}}
439+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
440+
llvm.return
441+
}
442+
443+
// -----
444+
445+
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32) {
446+
%base_pitch = llvm.mlir.constant(16777216 : i32) : i32
447+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 4th operand (base pitch) should be <= 24 bits}}
448+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
449+
llvm.return
450+
}
451+
452+
// -----
453+
454+
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32) {
455+
%base_pitch = llvm.mlir.constant(0 : i32) : i32
456+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 4th operand (base pitch) should be >= 64}}
457+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
458+
llvm.return
459+
}
460+
461+
// -----
462+
328463
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32) {
329-
%base_width = llvm.mlir.constant(4 : i32) : i32
330-
%base_pitch = llvm.mlir.constant(2 : i32) : i32
464+
%base_width = llvm.mlir.constant(68 : i32) : i32
465+
%base_pitch = llvm.mlir.constant(64 : i32) : i32
331466
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 4th operand (base pitch) should be >= 2nd operand (base width)}}
332467
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
333468
llvm.return

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,23 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
2727
"Unexpected template parameter");
2828

2929
std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
30+
if (width) {
31+
if (*width > (1 << 24) - 1)
32+
return op->emitOpError("2nd operand (base width) should be <= 24 bits");
33+
if (*width < 64)
34+
return op->emitOpError("2nd operand (base width) should be >= 64");
35+
}
36+
std::optional<int64_t> height = getConstantIntValue(op.getBaseHeight());
37+
if (height)
38+
if (*height > (1 << 24) - 1)
39+
return op->emitOpError("3rd operand (base height) should be <= 24 bits");
3040
std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
41+
if (pitch) {
42+
if (*pitch > (1 << 24) - 1)
43+
return op->emitOpError("4th operand (base pitch) should be <= 24 bits");
44+
if (*pitch < 64)
45+
return op->emitOpError("4th operand (base pitch) should be >= 64");
46+
}
3147
if (pitch && width && *pitch < *width)
3248
return op->emitOpError(
3349
"4th operand (base pitch) should be >= 2nd operand (base width)");

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,8 @@ struct PrefetchOpConversion
678678
masks[offset] = maskElems[i];
679679
}
680680

681-
Value baseWidth =
682-
b.i32_val(vBlocks * tileWidthInElem * (elemSizeInBits / 8));
681+
Value baseWidth = b.i32_val(
682+
std::max(64u, vBlocks * tileWidthInElem * (elemSizeInBits / 8)));
683683
Value rowStrideInBytes =
684684
getPitch(rewriter, op.getPtr(), baseAddrs, baseWidth, elemSizeInBits);
685685
if (!rowStrideInBytes)
@@ -1134,7 +1134,8 @@ struct LoadOpToBlockIOConversion
11341134
break;
11351135
}
11361136

1137-
Value baseWidth = b.i32_val(vBlocks * tileWidth * (elemSizeInBits / 8));
1137+
Value baseWidth =
1138+
b.i32_val(std::max(64u, vBlocks * tileWidth * (elemSizeInBits / 8)));
11381139
Value pitch = getPitch(rewriter, ptr, ptrs, baseWidth, elemSizeInBits);
11391140
if (!pitch)
11401141
return failure();

0 commit comments

Comments
 (0)