Skip to content

Commit 790a132

Browse files
authored
[mlir][amx] Increase op verifier test coverage (#155264)
Refactors and adds more test cases for invalid AMX operations.
1 parent db6a8f1 commit 790a132

File tree

1 file changed

+102
-8
lines changed

1 file changed

+102
-8
lines changed

mlir/test/Dialect/AMX/invalid.mlir

Lines changed: 102 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,142 @@
11
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
22

3-
// -----
4-
5-
func.func @rowheight() {
3+
func.func @tile_row_height() {
64
// expected-error@+1 {{'amx.tile_zero' op bad row height: 17}}
75
%0 = amx.tile_zero : !amx.tile<17x16xbf16>
6+
return
87
}
98

109
// -----
1110

12-
func.func @colwidth() {
11+
func.func @tile_col_width() {
1312
// expected-error@+1 {{'amx.tile_zero' op bad column width: 65}}
1413
%0 = amx.tile_zero : !amx.tile<16x65xi8>
14+
return
1515
}
1616

1717
// -----
1818

19-
func.func @col4bytemultiple() {
19+
func.func @tile_col_4_byte_multiple() {
2020
// expected-error@+1 {{'amx.tile_zero' op bad column width: 5}}
2121
%0 = amx.tile_zero : !amx.tile<16x5xi8>
22+
return
2223
}
2324

2425
// -----
2526

26-
func.func @memtilesize(%arg0: memref<?x?xf32>) {
27+
func.func @load_base_tilesize(%arg0: memref<?x?xf32>) {
2728
%0 = arith.constant 0 : index
2829
// expected-error@+1 {{'amx.tile_load' op bad column width: 68}}
2930
%1 = amx.tile_load %arg0[%0, %0] : memref<?x?xf32> into !amx.tile<16x17xf32>
31+
return
32+
}
33+
34+
// -----
35+
36+
func.func @store_base_tilesize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x17xf32>) {
37+
%0 = arith.constant 0 : index
38+
// expected-error@+1 {{'amx.tile_store' op bad column width: 68}}
39+
amx.tile_store %arg0[%0, %0], %arg1 : memref<?x?xf32>, !amx.tile<16x17xf32>
40+
return
3041
}
3142

3243
// -----
3344

34-
func.func @memindexsize(%arg0: memref<?x?xf32>) {
45+
func.func @load_base_indexsize(%arg0: memref<?x?xf32>) {
3546
%0 = arith.constant 0 : index
3647
// expected-error@+1 {{'amx.tile_load' op requires 2 indices}}
3748
%1 = amx.tile_load %arg0[%0] : memref<?x?xf32> into !amx.tile<16x16xf32>
49+
return
50+
}
51+
52+
// -----
53+
54+
func.func @store_base_indexsize(%arg0: memref<?x?xf32>, %arg1: !amx.tile<16x16xf32>) {
55+
%0 = arith.constant 0 : index
56+
// expected-error@+1 {{'amx.tile_store' op requires 2 indices}}
57+
amx.tile_store %arg0[%0], %arg1 : memref<?x?xf32>, !amx.tile<16x16xf32>
58+
return
3859
}
3960

4061
// -----
4162

42-
func.func @multsize() {
63+
func.func @load_base_rank(%arg0: memref<?xf32>) {
64+
%0 = arith.constant 0 : index
65+
// expected-error@+1 {{'amx.tile_load' op requires at least 2D memref}}
66+
%1 = amx.tile_load %arg0[%0] : memref<?xf32> into !amx.tile<16x16xf32>
67+
return
68+
}
69+
70+
// -----
71+
72+
func.func @store_base_rank(%arg0: memref<?xf32>, %arg1: !amx.tile<16x16xf32>) {
73+
%0 = arith.constant 0 : index
74+
// expected-error@+1 {{'amx.tile_store' op requires at least 2D memref}}
75+
amx.tile_store %arg0[%0], %arg1 : memref<?xf32>, !amx.tile<16x16xf32>
76+
return
77+
}
78+
79+
// -----
80+
81+
func.func @load_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>) {
82+
%0 = arith.constant 0 : index
83+
// expected-error@+1 {{'amx.tile_load' op requires memref with unit innermost stride}}
84+
%1 = amx.tile_load %arg0[%0, %0]
85+
: memref<?x?xf32, strided<[?, ?]>> into !amx.tile<16x16xf32>
86+
return
87+
}
88+
89+
// -----
90+
91+
func.func @store_base_non_unit_stride(%arg0: memref<?x?xf32, strided<[?, ?]>>,
92+
%arg1: !amx.tile<16x16xf32>) {
93+
%0 = arith.constant 0 : index
94+
// expected-error@+1 {{'amx.tile_store' op requires memref with unit innermost stride}}
95+
amx.tile_store %arg0[%0, %0], %arg1
96+
: memref<?x?xf32, strided<[?, ?]>>, !amx.tile<16x16xf32>
97+
return
98+
}
99+
100+
// -----
101+
102+
func.func @mulf_shape() {
43103
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
44104
%1 = amx.tile_zero : !amx.tile<8x8xbf16>
45105
%2 = amx.tile_zero : !amx.tile<4x4xf32>
46106
// expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}}
47107
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<8x8xbf16>, !amx.tile<4x4xf32>
108+
return
109+
}
110+
111+
// -----
112+
113+
func.func @mulf_type_combination() {
114+
%0 = amx.tile_zero : !amx.tile<8x8xbf16>
115+
%1 = amx.tile_zero : !amx.tile<4x8xf16>
116+
%2 = amx.tile_zero : !amx.tile<8x4xf32>
117+
// expected-error@+1 {{'amx.tile_mulf' op unsupported type combination}}
118+
%3 = amx.tile_mulf %0, %1, %2 : !amx.tile<8x8xbf16>, !amx.tile<4x8xf16>, !amx.tile<8x4xf32>
119+
return
120+
}
121+
122+
// -----
123+
124+
func.func @muli_shape() {
125+
%0 = amx.tile_zero : !amx.tile<8x8xi8>
126+
%1 = amx.tile_zero : !amx.tile<8x8xi8>
127+
%2 = amx.tile_zero : !amx.tile<4x4xi32>
128+
// expected-error@+1 {{'amx.tile_muli' op bad mult shape: 4 x 4 x 2}}
129+
%3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x8xi8>, !amx.tile<8x8xi8>, !amx.tile<4x4xi32>
130+
return
131+
}
132+
133+
// -----
134+
135+
func.func @muli_type_combination() {
136+
%0 = amx.tile_zero : !amx.tile<8x16xi8>
137+
%1 = amx.tile_zero : !amx.tile<8x16xi32>
138+
%2 = amx.tile_zero : !amx.tile<2x2xi32>
139+
// expected-error@+1 {{'amx.tile_muli' op operand #1 must be tile of 8-bit signless integer values}}
140+
%3 = amx.tile_muli %0, %1, %2 : !amx.tile<8x16xi8>, !amx.tile<8x16xi32>, !amx.tile<2x2xi32>
141+
return
48142
}

0 commit comments

Comments
 (0)