|
1 | | -// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s |
| 1 | +// RUN: mlir-opt %s --convert-amdgpu-to-rocdl=chipset=gfx1100 --allow-unregistered-dialect | FileCheck %s |
| 2 | + |
2 | 3 | // CHECK-LABEL: @wmma_to_rocdl |
3 | 4 | func.func @wmma_to_rocdl(%arg0 : vector<16xf16>, %arg1 : vector<8xf32>, %arg2 : vector<4xf32>, |
4 | 5 | %arg3 : vector<16xbf16>, %arg4 : vector<8xf16>, %arg5 : vector<8xbf16>, |
5 | 6 | %arg6 : vector<16xi8>, %arg7 : vector<8xi32>, %arg8 : vector<4xi32>, |
6 | 7 | %arg9 : vector<16xui8>, %arg10 : vector<16xi4>, %arg11 : vector<8xi4>) { |
7 | 8 | // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32> |
8 | | - amdgpu.wmma %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> |
| 9 | + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg1 : vector<16xf16>, vector<16xf16>, vector<8xf32> |
9 | 10 | // CHECK: rocdl.wmma.f32.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<4xf32>) -> vector<4xf32> |
10 | | - amdgpu.wmma %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> |
| 11 | + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg2 : vector<16xf16>, vector<16xf16>, vector<4xf32> |
11 | 12 | // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xf32>) -> vector<8xf32> |
12 | | - amdgpu.wmma %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> |
| 13 | + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg1 : vector<16xbf16>, vector<16xbf16>, vector<8xf32> |
13 | 14 | // CHECK: rocdl.wmma.f32.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<4xf32>) -> vector<4xf32> |
14 | | - amdgpu.wmma %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32> |
| 15 | + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg2 : vector<16xbf16>, vector<16xbf16>, vector<4xf32> |
15 | 16 | // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<16xf16>, i1) -> vector<16xf16> |
16 | | - amdgpu.wmma %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> |
| 17 | + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg0 {subwordOffset = 1 : i32}: vector<16xf16>, vector<16xf16>, vector<16xf16> |
17 | 18 | // CHECK: rocdl.wmma.f16.16x16x16.f16{{.*}}: (vector<16xf16>, vector<16xf16>, vector<8xf16>, i1) -> vector<8xf16> |
18 | | - amdgpu.wmma %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> |
| 19 | + amdgpu.wmma 16x16x16 %arg0 * %arg0 + %arg4 {subwordOffset = 0 : i32}: vector<16xf16>, vector<16xf16>, vector<8xf16> |
19 | 20 | // CHECK: %[[raw_bf16x16:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<16xi16>, i1) -> vector<16xi16> |
20 | 21 | // CHECK-NEXT: llvm.bitcast %[[raw_bf16x16]] : vector<16xi16> to vector<16xbf16> |
21 | | - amdgpu.wmma %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> |
| 22 | + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg3 {subwordOffset = 1 : i32}: vector<16xbf16>, vector<16xbf16>, vector<16xbf16> |
22 | 23 | // CHECK: %[[raw_bf16x8:.+]] = rocdl.wmma.bf16.16x16x16.bf16{{.*}}: (vector<16xi16>, vector<16xi16>, vector<8xi16>, i1) -> vector<8xi16> |
23 | 24 | // CHECK-NEXT: llvm.bitcast %[[raw_bf16x8]] : vector<8xi16> to vector<8xbf16> |
24 | | - amdgpu.wmma %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> |
| 25 | + amdgpu.wmma 16x16x16 %arg3 * %arg3 + %arg5 {subwordOffset = 0 : i32}: vector<16xbf16>, vector<16xbf16>, vector<8xbf16> |
25 | 26 | // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<8xi32>, i1) -> vector<8xi32> |
26 | | - amdgpu.wmma %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32> |
| 27 | + amdgpu.wmma 16x16x16 %arg6 * %arg6 + %arg7 {clamp}: vector<16xi8>, vector<16xi8>, vector<8xi32> |
27 | 28 | // CHECK: rocdl.wmma.i32.16x16x16.iu8{{.*}}: (i1, vector<4xi32>, i1, vector<4xi32>, vector<4xi32>, i1) -> vector<4xi32> |
28 | | - amdgpu.wmma %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32> |
| 29 | + amdgpu.wmma 16x16x16 %arg9 * %arg9 + %arg8 {unsignedA, unsignedB, clamp}: vector<16xui8>, vector<16xui8>, vector<4xi32> |
29 | 30 | // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, vector<2xi32>, i1, vector<2xi32>, vector<8xi32>, i1) -> vector<8xi32> |
30 | | - amdgpu.wmma %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32> |
| 31 | + amdgpu.wmma 16x16x16 %arg10 * %arg10 + %arg7 {clamp}: vector<16xi4>, vector<16xi4>, vector<8xi32> |
31 | 32 | // CHECK: rocdl.wmma.i32.16x16x16.iu4{{.*}}: (i1, i32, i1, i32, vector<4xi32>, i1) -> vector<4xi32> |
32 | | - amdgpu.wmma %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32> |
| 33 | + amdgpu.wmma 16x16x16 %arg11 * %arg11 + %arg8 {clamp}: vector<8xi4>, vector<8xi4>, vector<4xi32> |
33 | 34 |
|
34 | 35 | func.return |
35 | 36 | } |
0 commit comments