Skip to content

Commit f4764a9

Browse files
[TritonGEN] Update 2D block verifier (#4644)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 21b69a0 commit f4764a9

File tree

3 files changed

+44
-80
lines changed

3 files changed

+44
-80
lines changed

test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path_invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
55
%c1_i64 = arith.constant 1 : i64
66
%c0_i32 = arith.constant 0 : i32
77
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xf32>>
8-
// expected-error @+2 {{tile_width for 32 bit elements should be equal to 8 or 16}}
8+
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
99
// expected-error @+1 {{failed to legalize operation 'ttig.prefetch'}}
1010
ttig.prefetch %22 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<tensor<2x32xf32>>
1111
tt.return
@@ -19,7 +19,7 @@ module attributes {"ttig.support_sg_2d_block", "ttig.support_dpas", "ttg.num-war
1919
%c1_i64 = arith.constant 1 : i64
2020
%c0_i32 = arith.constant 0 : i32
2121
%22 = tt.make_tensor_ptr %arg0, [%arg1, %arg1], [%arg1, %c1_i64], [%arg2, %c0_i32] {order = array<i32: 1, 0>} : <tensor<2x32xf32>>
22-
// expected-error @+2 {{expecting tile_width to be between 1 and 16}}
22+
// expected-error @+2 {{expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
2323
// expected-error @+1 {{failed to legalize operation 'tt.load'}}
2424
%res = tt.load %22 {DotIdx = 0 : i32, boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<2x32xf32>>
2525
tt.return

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 5 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
170170
// -----
171171

172172
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
173-
// expected-error @+1 {{'triton_gen.2Dblockload' op tile_width * v_blocks should be less than or equal to 64 for 8 bit elements}}
173+
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
174174
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=4, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<32xi16>
175175
llvm.return
176176
}
@@ -523,23 +523,15 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
523523
// -----
524524

525525
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
526-
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 16 bit elements should be equal to 16}}
527-
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
528-
llvm.return
529-
}
530-
531-
// -----
532-
533-
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
534-
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 8 bit elements should be equal to 16 or 32}}
535-
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
526+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting elem_size_in_bits * tile_width * v_blocks <= 512}}
527+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
536528
llvm.return
537529
}
538530

539531
// -----
540532

541533
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
542-
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op tile_width for 32 bit elements should be equal to 8 or 16}}
543-
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
534+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting tile_width to be between 4 and 64}}
535+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=1, tile_height=32, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
544536
llvm.return
545537
}

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 37 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,40 @@ static LogicalResult verify2DBlockAddressPayloadRestriction(Op op) {
8686
return success();
8787
}
8888

89+
/// Checks the tile-shape limits shared by the 2D block load and 2D block
/// prefetch operations.
///
/// An op error is emitted when any of the following holds:
///   - elem_size_in_bits * tile_width * v_blocks is greater than 512;
///   - tile_width is outside [4, 64] for 8-bit elements, outside [2, 32] for
///     16-bit elements, or outside [1, 16] for 32-bit elements;
///   - v_blocks is 4 for 32-bit elements.
///
/// \param op the operation to verify; only Matrix2DBlockLoadOp and
///        Matrix2DBlockPrefetchOp are accepted (enforced at compile time).
/// \returns success() when every check passes, otherwise a failure carrying
///          the emitted diagnostic.
template <typename Op> static LogicalResult verify2DBlockHWRestriction(Op op) {
  static_assert(llvm::is_one_of<Op, TritonGEN::Matrix2DBlockLoadOp,
                                TritonGEN::Matrix2DBlockPrefetchOp>::value,
                "Unexpected template parameter");

  unsigned elemSizeInBits = op.getElemSizeInBits();
  uint32_t tileWidth = op.getTileWidth();
  uint32_t vBlocks = op.getVBlocks();

  // Widen before multiplying: in 32-bit arithmetic a very large tile_width or
  // v_blocks could wrap the product back below the limit and slip through.
  const uint64_t totalBits =
      static_cast<uint64_t>(elemSizeInBits) * tileWidth * vBlocks;
  if (totalBits > 512)
    return op.emitOpError(
        "expecting elem_size_in_bits * tile_width * v_blocks <= 512");

  // Each case enforces both bounds named in its diagnostic, so the range
  // checks do not rely on the product check above for their upper limit.
  switch (elemSizeInBits) {
  case 8:
    if (tileWidth < 4 || tileWidth > 64)
      return op.emitOpError("expecting tile_width to be between 4 and 64");
    break;
  case 16:
    if (tileWidth < 2 || tileWidth > 32)
      return op.emitOpError("expecting tile_width to be between 2 and 32");
    break;
  case 32:
    if (tileWidth < 1 || tileWidth > 16)
      return op.emitOpError("expecting tile_width to be between 1 and 16");
    if (vBlocks == 4)
      return op.emitOpError("v_blocks for 32 bit elements should be 1 or 2");
    break;
  default:
    // NOTE(review): any other element size (including 64) is treated as
    // unreachable here; presumably callers reject it earlier - confirm.
    llvm_unreachable("unexpected element size");
  }

  return success();
}
122+
89123
//===----------------------------------------------------------------------===//
90124
// gen.matrix.dpas
91125
//===----------------------------------------------------------------------===//
@@ -202,49 +236,8 @@ verify2DBlockLoadHWRestriction(TritonGEN::Matrix2DBlockLoadOp op) {
202236
return op.emitOpError(
203237
"transpose and vnni_transform are mutually exclusive");
204238

205-
if (!op.getTranspose() && !op.getVnniTransform()) {
206-
uint32_t tileWidth = op.getTileWidth();
207-
uint32_t vBlocks = op.getVBlocks();
208-
switch (op.getElemSizeInBits()) {
209-
case 8:
210-
if (tileWidth < 4 || tileWidth > 64)
211-
return op.emitOpError("expecting tile_width to be between 4 and 64");
212-
if (tileWidth * vBlocks > 64)
213-
return op.emitOpError(
214-
"tile_width * v_blocks should be less than or equal "
215-
"to 64 for 8 bit elements");
216-
break;
217-
case 16:
218-
if (tileWidth < 2 || tileWidth > 32)
219-
return op.emitOpError("expecting tile_width to be between 2 and 32");
220-
if (tileWidth * vBlocks > 32)
221-
return op.emitOpError(
222-
"tile_width * v_blocks should be less than or equal "
223-
"to 32 for 16 bit elements");
224-
break;
225-
case 32:
226-
if (tileWidth < 1 || tileWidth > 16)
227-
return op.emitOpError("expecting tile_width to be between 1 and 16");
228-
if (vBlocks != 1 && vBlocks != 2)
229-
return op.emitOpError("expecting v_blocks to be 1 or 2");
230-
if (tileWidth * vBlocks > 16)
231-
return op.emitOpError(
232-
"tile_width * v_blocks should be less than or equal "
233-
"to 16 for 32 bit elements");
234-
break;
235-
case 64:
236-
if (tileWidth < 1 || tileWidth > 8)
237-
return op.emitOpError("expecting tile_width to be between 1 and 8");
238-
if (vBlocks != 1)
239-
return op.emitOpError("expecting v_blocks to be 1");
240-
break;
241-
default:
242-
return op.emitOpError(
243-
"expecting elem_size_in_bits to be 8, 16, 32, or 64");
244-
}
245-
246-
return success();
247-
}
239+
if (!op.getTranspose() && !op.getVnniTransform())
240+
return verify2DBlockHWRestriction(op);
248241

249242
if (op.getTranspose()) {
250243
assert(!op.getVnniTransform() &&
@@ -411,26 +404,5 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {
411404
if (verify2DBlockAddressPayloadRestriction(*this).failed())
412405
return failure();
413406

414-
uint32_t tileWidth = getTileWidth();
415-
switch (getElemSizeInBits()) {
416-
case 8:
417-
if (tileWidth != 16 && tileWidth != 32)
418-
return emitOpError("tile_width for 8 bit elements should be equal to "
419-
"16 or 32");
420-
break;
421-
case 16:
422-
if (tileWidth != 16)
423-
return emitOpError("tile_width for 16 bit elements should be equal "
424-
"to 16");
425-
break;
426-
case 32:
427-
if (tileWidth != 8 && tileWidth != 16)
428-
return emitOpError(
429-
"tile_width for 32 bit elements should be equal to 8 or 16");
430-
break;
431-
default:
432-
llvm_unreachable("unexpected element size");
433-
}
434-
435-
return success();
407+
return verify2DBlockHWRestriction(*this);
436408
}

0 commit comments

Comments
 (0)