@@ -164,3 +164,48 @@ func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
164164 amdgpu.wmma 16 x16 x16 %arg0 * %arg0 + %arg1 : vector <8 xf16 >, vector <8 xf16 >, vector <8 xf32 >
165165 return
166166}
167+
168+ // -----
169+
170+ func.func @scaled_wmma_wrong_output_length (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <16 xf32 >,
171+ %arg2 : vector <4 xf8 E8 M0 FNU>) {
172+ // expected-error@below {{'amdgpu.scaled_wmma' op expected output vector of length 8 but got 16}}
173+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg2 [0 ] * %arg0 ) * (%arg2 [0 ] * %arg0 ) + %arg1 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <16 xf32 >
174+ return
175+ }
176+
177+ func.func @scaled_wmma_16x16_wrong_sourceA_length (%arg0 : vector <128 xf4 E2 M1 FN>, %arg1 : vector <64 xf4 E2 M1 FN>,
178+ %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>) {
179+ // expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceA must have 64 elements but got 128}}
180+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [0 ] * %arg0 ) * (%arg3 [0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E8 M0 FNU>, vector <64 xf4 E2 M1 FN>, vector <8 xf32 >
181+ return
182+ }
183+
184+ func.func @scaled_wmma_16x16_wrong_sourceB_length (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <128 xf4 E2 M1 FN>,
185+ %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>) {
186+ // expected-error@below {{'amdgpu.scaled_wmma' op for 16x16x128, sourceB must have 64 elements but got 128}}
187+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [0 ] * %arg0 ) * (%arg3 [0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E8 M0 FNU>, vector <128 xf4 E2 M1 FN>, vector <8 xf32 >
188+ return
189+ }
190+
191+ func.func @scaled_wmma_32x16_wrong_sourceA_length (%arg0 : vector <64 xf4 E2 M1 FN>, %arg1 : vector <64 xf4 E2 M1 FN>,
192+ %arg2 : vector <16 xf32 >, %arg3 : vector <4 xf8 E4 M3 FN>) {
193+ // expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceA must have 128 elements but got 64}}
194+ %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 [0 ] * %arg0 ) * (%arg3 [0 ] * %arg1 ) + %arg2 : vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf4 E2 M1 FN>, vector <16 xf32 >
195+ return
196+ }
197+
198+ func.func @scaled_wmma_32x16_wrong_sourceB_length (%arg0 : vector <128 xf4 E2 M1 FN>, %arg1 : vector <128 xf4 E2 M1 FN>,
199+ %arg2 : vector <16 xf32 >, %arg3 : vector <4 xf8 E4 M3 FN>) {
200+ // expected-error@below {{'amdgpu.scaled_wmma' op for 32x16x128, sourceB must have 64 elements but got 128}}
201+ %0 = amdgpu.scaled_wmma 32 x16 x128 (%arg3 [0 ] * %arg0 ) * (%arg3 [0 ] * %arg1 ) + %arg2 : vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <4 xf8 E4 M3 FN>, vector <128 xf4 E2 M1 FN>, vector <16 xf32 >
202+ return
203+ }
204+
205+ func.func @scaled_wmma_invalid_type_combination (%arg0 : vector <64 xf8 E4 M3 FN>, %arg1 : vector <64 xf6 E2 M3 FN>,
206+ %arg2 : vector <8 xf32 >, %arg3 : vector <4 xf8 E8 M0 FNU>,
207+ %arg4 : vector <4 xf8 E4 M3 FN>) {
208+ // expected-error@below {{'amdgpu.scaled_wmma' op invalid combination of matrix and scale types}}
209+ %0 = amdgpu.scaled_wmma 16 x16 x128 (%arg3 [0 ] * %arg0 ) * (%arg4 [0 ] * %arg1 ) + %arg2 : vector <4 xf8 E8 M0 FNU>, vector <64 xf8 E4 M3 FN>, vector <4 xf8 E4 M3 FN>, vector <64 xf6 E2 M3 FN>, vector <8 xf32 >
210+ return
211+ }
0 commit comments