Skip to content

Commit dcc6db9

Browse files
[TritonGEN] Update 2D block address payload restrictions (#4607)
This PR finishes adding verifiers for 2D block address payload restrictions. Signed-off-by: Whitney Tsang <[email protected]>
1 parent 74d5059 commit dcc6db9

File tree

2 files changed

+110
-73
lines changed

2 files changed

+110
-73
lines changed

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 69 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -155,31 +155,15 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
155155

156156
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
157157
// expected-error @+1 {{'triton_gen.2Dblockload' op transpose and vnni_transform are mutually exclusive}}
158-
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=16, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<4xi32>
159-
llvm.return
160-
}
161-
162-
// -----
163-
164-
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
165-
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_height to be between 1 and 32}}
166-
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<64xi16>
158+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=16, tile_height=8, v_blocks=1, transpose=true, vnni_transform=true, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<2xi32>
167159
llvm.return
168160
}
169161

170162
// -----
171163

172164
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
173165
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_width to be between 4 and 64}}
174-
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=128, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi32>
175-
llvm.return
176-
}
177-
178-
// -----
179-
180-
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
181-
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting v_blocks to be 1, 2, or 4}}
182-
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=6, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<48xi16>
166+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=1, tile_height=32, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<1xi16>
183167
llvm.return
184168
}
185169

@@ -210,7 +194,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
210194
// -----
211195

212196
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
213-
%base_width = llvm.mlir.constant(16777216 : i32) : i32
197+
%base_width = llvm.mlir.constant(16777217 : i32) : i32
214198
// expected-error @+1 {{'triton_gen.2Dblockload' op 2nd operand (base width) should be <= 24 bits}}
215199
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
216200
llvm.return
@@ -227,8 +211,17 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %base_pitch
227211

228212
// -----
229213

214+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
215+
%base_width = llvm.mlir.constant(65 : i32) : i32
216+
// expected-error @+1 {{'triton_gen.2Dblockload' op 2nd operand (base width) should be aligned to MAX(4, element_size)}}
217+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
218+
llvm.return
219+
}
220+
221+
// -----
222+
230223
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_pitch : i32, %x : i32, %y : i32) {
231-
%base_height = llvm.mlir.constant(16777216 : i32) : i32
224+
%base_height = llvm.mlir.constant(16777217 : i32) : i32
232225
// expected-error @+1 {{'triton_gen.2Dblockload' op 3rd operand (base height) should be <= 24 bits}}
233226
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
234227
llvm.return
@@ -237,7 +230,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_pitch :
237230
// -----
238231

239232
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32) {
240-
%base_pitch = llvm.mlir.constant(16777216 : i32) : i32
233+
%base_pitch = llvm.mlir.constant(16777217 : i32) : i32
241234
// expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be <= 24 bits}}
242235
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
243236
llvm.return
@@ -254,6 +247,15 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
254247

255248
// -----
256249

250+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32) {
251+
%base_pitch = llvm.mlir.constant(65 : i32) : i32
252+
// expected-error @+1 {{'triton_gen.2Dblockload' op 4th operand (base pitch) should be a multiple of 16 bytes}}
253+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
254+
llvm.return
255+
}
256+
257+
// -----
258+
257259
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y : i32) {
258260
%base_width = llvm.mlir.constant(68 : i32) : i32
259261
%base_pitch = llvm.mlir.constant(64 : i32) : i32
@@ -264,6 +266,24 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_height : i32, %x : i32, %y
264266

265267
// -----
266268

269+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %y : i32) {
270+
%x = llvm.mlir.constant(1 : i32) : i32
271+
// expected-error @+1 {{'triton_gen.2Dblockload' op 5th operand (x) should be a multiple of 4 for 8 bit elements}}
272+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
273+
llvm.return
274+
}
275+
276+
// -----
277+
278+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %y : i32) {
279+
%x = llvm.mlir.constant(1 : i32) : i32
280+
// expected-error @+1 {{'triton_gen.2Dblockload' op 5th operand (x) should be a multiple of 2 for 16 bit elements}}
281+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=16, tile_width=32, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<16xi16>
282+
llvm.return
283+
}
284+
285+
// -----
286+
267287
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
268288
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting 'elem_size_in_bits' to be 8, 16, or 32}}
269289
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=64, tile_width=4, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -273,8 +293,24 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
273293
// -----
274294

275295
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
276-
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}}
277-
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=24, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<24xi16>
296+
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile shape to be power of two}}
297+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=48, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<12xi16>
298+
llvm.return
299+
}
300+
301+
// -----
302+
303+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
304+
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_width to be between 1-64}}
305+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=128, tile_height=8, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<32xi16>
306+
llvm.return
307+
}
308+
309+
// -----
310+
311+
llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
312+
// expected-error @+1 {{'triton_gen.2Dblockload' op expecting tile_height to be between 1-32}}
313+
%0 = triton_gen.2Dblockload %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=64, v_blocks=1, transpose=false, vnni_transform=false, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32) -> vector<64xi16>
278314
llvm.return
279315
}
280316

@@ -298,7 +334,7 @@ llvm.func @matrix_2Dblockload(%ptr : !llvm.ptr, %base_width : i32, %base_height
298334

299335
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<64xi8>) {
300336
// expected-error @+1 {{'triton_gen.2Dblockstore' op expecting tile_height to be between 1 and 8}}
301-
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=64, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<64xi8>)
337+
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=32, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<64xi8>)
302338
llvm.return
303339
}
304340

@@ -322,7 +358,7 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height
322358
// -----
323359

324360
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
325-
%base_width = llvm.mlir.constant(16777216 : i32) : i32
361+
%base_width = llvm.mlir.constant(16777217 : i32) : i32
326362
// expected-error @+1 {{'triton_gen.2Dblockstore' op 2nd operand (base width) should be <= 24 bits}}
327363
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
328364
llvm.return
@@ -340,7 +376,7 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_height : i32, %base_pitch
340376
// -----
341377

342378
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
343-
%base_height = llvm.mlir.constant(16777216 : i32) : i32
379+
%base_height = llvm.mlir.constant(16777217 : i32) : i32
344380
// expected-error @+1 {{'triton_gen.2Dblockstore' op 3rd operand (base height) should be <= 24 bits}}
345381
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
346382
llvm.return
@@ -349,7 +385,7 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_pitch
349385
// -----
350386

351387
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32, %stored_val : vector<8xi8>) {
352-
%base_pitch = llvm.mlir.constant(16777216 : i32) : i32
388+
%base_pitch = llvm.mlir.constant(16777217 : i32) : i32
353389
// expected-error @+1 {{'triton_gen.2Dblockstore' op 4th operand (base pitch) should be <= 24 bits}}
354390
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<8xi8>)
355391
llvm.return
@@ -383,14 +419,6 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height
383419

384420
// -----
385421

386-
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<64xi8>) {
387-
// expected-error @+1 {{'triton_gen.2Dblockstore' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}}
388-
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=32, tile_height=6, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<64xi8>)
389-
llvm.return
390-
}
391-
392-
// -----
393-
394422
llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32, %stored_val : vector<4xi8>) {
395423
// expected-error @+1 {{'triton_gen.2Dblockstore' op tile_width for 8 bit elements should be equal to 16 or 32}}
396424
triton_gen.2Dblockstore %ptr, %base_width, %base_height, %base_pitch, %x, %y, %stored_val {elem_size_in_bits=8, tile_width=8, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32, vector<4xi8>)
@@ -416,7 +444,7 @@ llvm.func @matrix_2Dblockstore(%ptr : !llvm.ptr, %base_width : i32, %base_height
416444
// -----
417445

418446
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
419-
%base_width = llvm.mlir.constant(16777216 : i32) : i32
447+
%base_width = llvm.mlir.constant(16777217 : i32) : i32
420448
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 2nd operand (base width) should be <= 24 bits}}
421449
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
422450
llvm.return
@@ -434,7 +462,7 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_height : i32, %base_pi
434462
// -----
435463

436464
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_pitch : i32, %x : i32, %y : i32) {
437-
%base_height = llvm.mlir.constant(16777216 : i32) : i32
465+
%base_height = llvm.mlir.constant(16777217 : i32) : i32
438466
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 3rd operand (base height) should be <= 24 bits}}
439467
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
440468
llvm.return
@@ -443,7 +471,7 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_pit
443471
// -----
444472

445473
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %x : i32, %y : i32) {
446-
%base_pitch = llvm.mlir.constant(16777216 : i32) : i32
474+
%base_pitch = llvm.mlir.constant(16777217 : i32) : i32
447475
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op 4th operand (base pitch) should be <= 24 bits}}
448476
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
449477
llvm.return
@@ -479,16 +507,16 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
479507
// -----
480508

481509
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
482-
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting tile_height to be 1, 2, 4, 8, 16, or 32}}
510+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting tile_height to be between 1-32}}
483511
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=64, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
484512
llvm.return
485513
}
486514

487515
// -----
488516

489517
llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_height : i32, %base_pitch : i32, %x : i32, %y : i32) {
490-
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting v_blocks to be 1, 2, 4, or 8}}
491-
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=6, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
518+
// expected-error @+1 {{'triton_gen.2Dblockprefetch' op expecting v_blocks to be between 1-4}}
519+
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=8, tile_width=32, tile_height=8, v_blocks=8, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
492520
llvm.return
493521
}
494522

0 commit comments

Comments
 (0)