@@ -53,52 +53,52 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5353}
5454
5555// CHECK-LABEL: func @scaled_mfma_to_rocdl(
56- // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xi8 >, %[[ARG8:.*]]: i8
56+ // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU >, %[[ARG8:.*]]: f8E8M0FNU
5757func.func @scaled_mfma_to_rocdl (%arg0 : vector <16 xf32 >,
5858 %arg1 : vector <4 xf32 >, %arg2 : vector <32 xf8 E4 M3 FN>,
5959 %arg3 : vector <32 xf8 E5 M2 >, %arg4 : vector <32 xf6 E2 M3 FN>,
6060 %arg5 : vector <32 xf6 E3 M2 FN>, %arg6 : vector <32 xf4 E2 M1 FN>,
61- %arg7 : vector <4 x i8 >, %arg8 : i8 ) {
61+ %arg7 : vector <4 xf 8 E 8 M 0 FNU >, %arg8 : f8E8M0FNU ) {
6262
6363 // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
6464 // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
65- // CHECK: %[[b0:.+]] = llvm.bitcast %[[ARG7]] : vector<4xi8> to i32
66- // CHECK: %[[z0:.+]] = llvm.zext %[[ARG8]] : i8 to i32
65+ // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
66+ // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
6767
6868 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
69- amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf8 E4 M3 FN>, i8 , vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
69+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
7070 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
71- amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf8 E4 M3 FN>, i8 , vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
71+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
7272
7373 // CHECK: llvm.bitcast
7474
7575 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
76- amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf8 E5 M2 >, i8 , vector <32 xf8 E5 M2 >, vector <16 xf32 >
76+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <16 xf32 >
7777 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
78- amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf8 E5 M2 >, i8 , vector <32 xf8 E5 M2 >, vector <4 xf32 >
78+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <4 xf32 >
7979
8080 // CHECK: llvm.bitcast
8181
8282 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
83- amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf6 E2 M3 FN>, i8 , vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
83+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
8484 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
85- amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf6 E2 M3 FN>, i8 , vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
85+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
8686
8787 // CHECK: llvm.bitcast
8888 // CHECK: llvm.mlir.constant(3 : i32) : i32
8989
9090 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
91- amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf6 E3 M2 FN>, i8 , vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
91+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
9292 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
93- amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf6 E3 M2 FN>, i8 , vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
93+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
9494
9595 // CHECK: llvm.bitcast
9696 // CHECK: llvm.mlir.constant(4 : i32) : i32
9797
9898 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
99- amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf4 E2 M1 FN>, i8 , vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
99+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
100100 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
101- amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf4 E2 M1 FN>, i8 , vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
101+ amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf 8 E 8 M 0 FNU >, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
102102
103103 func.return
104104}
0 commit comments