1- // RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 --allow-unregistered-dialect | FileCheck %s
1+ // RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1250 \
2+ // RUN: --split-input-file --verify-diagnostics | FileCheck %s
23
// Lowering of a 16x16x4 WMMA whose multiplicands are vector<2xf32> and whose
// accumulator is vector<8xf32>.
// CHECK-LABEL: @wmma_k4
func.func @wmma_k4(%arg0 : vector<2xf32>, %arg1 : vector<8xf32>) {
  // %arg0 is used for both multiplicands; %arg1 is the accumulator. The
  // expected ROCDL op receives the operands in the same order.
  // CHECK: rocdl.wmma.f32.16x16x4.f32 %arg0, %arg0, %arg1
  amdgpu.wmma 16x16x4 %arg0 * %arg0 + %arg1 : vector<2xf32>, vector<2xf32>, vector<8xf32>
  return
}
910
1011// CHECK-LABEL: @wmma_k32
@@ -22,7 +23,7 @@ func.func @wmma_k32(%arg0 : vector<16xf16>, %arg1 : vector<16xbf16>, %arg2 : vec
2223 // CHECK: rocdl.wmma.bf16.16x16x32.bf16 {{.*}}, {{.*}}, {{.*}}, {{.*}} : (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1)
2324 amdgpu.wmma 16 x16 x32 %arg1 * %arg1 + %arg4 : vector <16 xbf16 >, vector <16 xbf16 >, vector <8 xbf16 >
2425
25- func. return
26+ return
2627}
2728
2829// CHECK-LABEL: @wmma_k64
@@ -55,7 +56,7 @@ func.func @wmma_k64(%arg0 : vector<32xi8>, %arg1 : vector<32xf8E4M3FN>, %arg2 :
5556 // CHECK: rocdl.wmma.f16.16x16x64.bf8_fp8 {{.*}}, {{.*}}, %arg5, {{.*}} : (vector<8xi32>, vector<8xi32>, vector<8xf16>, i1)
5657 amdgpu.wmma 16 x16 x64 %arg2 * %arg1 + %arg5 : vector <32 xf8 E5 M2 >, vector <32 xf8 E4 M3 FN>, vector <8 xf16 >
5758
58- func. return
59+ return
5960}
6061
6162// CHECK-LABEL: @wmma_k128
@@ -85,5 +86,14 @@ func.func @wmma_k128(%arg0 : vector<64xf8E4M3FN>, %arg1 : vector<64xf8E5M2>,
8586 // CHECK: rocdl.wmma.f16.16x16x128.bf8_fp8 {{.*}}, {{.*}}, %arg3, {{.*}} : (vector<16xi32>, vector<16xi32>, vector<8xf16>, i1)
8687 amdgpu.wmma 16 x16 x128 %arg1 * %arg0 + %arg3 : vector <64 xf8 E5 M2 >, vector <64 xf8 E4 M3 FN>, vector <8 xf16 >
8788
88- func.return
89+ return
90+ }
91+
92+ // -----
93+
// Negative test: a 16x16x16 WMMA with vector<8xf16> multiplicands and a
// vector<8xf32> accumulator is expected to be rejected by the conversion for
// the chipset selected in the RUN line. Both diagnostics below are expected on
// the amdgpu.wmma op; there is intentionally no FileCheck label here because
// the function is not expected to legalize.
func.func @wmma_unsupported_k(%arg0 : vector<8xf16>, %arg1 : vector<8xf32>) {
  // expected-error@below {{'amdgpu.wmma' op no intrinsic matching WMMA on the given chipset}}
  // expected-error@below {{failed to legalize operation 'amdgpu.wmma'}}
  amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<8xf16>, vector<8xf16>, vector<8xf32>
  return
}
0 commit comments