Skip to content

Commit ee12c53

Browse files
committed
[mlir][amx] Prevent crash on invalid tile element type
Fixes AMX tile type parser to prevent crashes on invalid element type.
1 parent c0f7091 commit ee12c53

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

mlir/lib/Dialect/AMX/IR/AMXDialect.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) {
271271
if (parser.parseGreater())
272272
return nullptr;
273273

274-
return TileType::get(shape, elementType);
274+
return TileType::getChecked(
275+
[&] { return parser.emitError(parser.getNameLoc()); }, shape,
276+
elementType);
275277
}
276278

277279
void amx::TileType::print(AsmPrinter &os) const {

mlir/test/Dialect/AMX/invalid.mlir

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,22 @@ func.func @tile_col_width() {
1616

1717
// -----
1818

19+
func.func @tile_element_type() {
20+
// expected-error@+1 {{failed to verify 'elementType'}}
21+
%0 = amx.tile_zero : !amx.tile<8x8xi16>
22+
return
23+
}
24+
25+
// -----
26+
27+
func.func @tile_rank() {
28+
// expected-error@+1 {{'amx.tile_zero' op result #0 must be tile of}}
29+
%0 = amx.tile_zero : !amx.tile<32xi8>
30+
return
31+
}
32+
33+
// -----
34+
1935
func.func @tile_col_4_byte_multiple() {
2036
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
2137
%0 = amx.tile_zero : !amx.tile<16x5xi8>
@@ -24,7 +40,7 @@ func.func @tile_col_4_byte_multiple() {
2440

2541
// -----
2642

27-
func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
43+
func.func @load_base_tile_size(%arg0: memref<?x?xf32>) {
2844
%0 = arith.constant 0 : index
2945
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
3046
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
@@ -33,7 +49,7 @@ func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
3349

3450
// -----
3551

36-
func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
52+
func.func @store_base_tile_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
3753
%0 = arith.constant 0 : index
3854
// expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
3955
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
@@ -42,7 +58,7 @@ func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf3
4258

4359
// -----
4460

45-
func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
61+
func.func @load_base_index_size(%arg0: memref<?x?xf32>) {
4662
%0 = arith.constant 0 : index
4763
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
4864
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
@@ -51,7 +67,7 @@ func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
5167

5268
// -----
5369

54-
func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
70+
func.func @store_base_index_size(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
5571
%0 = arith.constant 0 : index
5672
// expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
5773
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>

0 commit comments

Comments
 (0)