@@ -55,50 +55,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5555// CHECK-LABEL: func @scaled_mfma_to_rocdl(
5656// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
5757func.func @scaled_mfma_to_rocdl (%arg0 : vector <16 xf32 >,
58- %arg1 : vector <4 xf32 >, %arg2 : vector <32 xf8 E4 M3 FN>,
59- %arg3 : vector <32 xf8 E5 M2 >, %arg4 : vector <32 xf6 E2 M3 FN>,
60- %arg5 : vector <32 xf6 E3 M2 FN>, %arg6 : vector <32 xf4 E2 M1 FN>,
61- %arg7 : vector <4 xf8 E8 M0 FNU>, %arg8 : f8E8M0FNU ) {
58+ %arg1 : vector <4 xf32 >, %arg2 : vector <32 xf8 E4 M3 FN>,
59+ %arg3 : vector <32 xf8 E5 M2 >, %arg4 : vector <32 xf6 E2 M3 FN>,
60+ %arg5 : vector <32 xf6 E3 M2 FN>, %arg6 : vector <32 xf4 E2 M1 FN>,
61+ %arg7 : vector <4 xf8 E8 M0 FNU>, %arg8 : f8E8M0FNU ) {
6262
6363 // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
6464 // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
6565 // CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
6666 // CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
6767
6868 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
69- amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
69+ amdgpu.scaled_mfma 32 x 32 x 64 (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg0 : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
7070 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
71- amdgpu.scaled_mfma (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
71+ amdgpu.scaled_mfma 16 x 16 x 128 (%arg7 [0 ] * %arg2 ) * (%arg8 [1 ] * %arg2 ) + %arg1 : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E4 M3 FN>, f8E8M0FNU , vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
7272
7373 // CHECK: llvm.bitcast
7474
7575 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
76- amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <16 xf32 >
76+ amdgpu.scaled_mfma 32 x 32 x 64 (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg0 : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <16 xf32 >
7777 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
78- amdgpu.scaled_mfma (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <4 xf32 >
78+ amdgpu.scaled_mfma 16 x 16 x 128 (%arg7 [0 ] * %arg3 ) * (%arg8 [1 ] * %arg3 ) + %arg1 : vector <4 xf8 E8 M0 FNU>, vector <32 xf8 E5 M2 >, f8E8M0FNU , vector <32 xf8 E5 M2 >, vector <4 xf32 >
7979
8080 // CHECK: llvm.bitcast
8181
8282 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
83- amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
83+ amdgpu.scaled_mfma 32 x 32 x 64 (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg0 : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
8484 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
85- amdgpu.scaled_mfma (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
85+ amdgpu.scaled_mfma 16 x 16 x 128 (%arg7 [0 ] * %arg4 ) * (%arg8 [1 ] * %arg4 ) + %arg1 : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E2 M3 FN>, f8E8M0FNU , vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
8686
8787 // CHECK: llvm.bitcast
8888 // CHECK: llvm.mlir.constant(3 : i32) : i32
8989
9090 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
91- amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
91+ amdgpu.scaled_mfma 32 x 32 x 64 (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg0 : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
9292 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
93- amdgpu.scaled_mfma (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
93+ amdgpu.scaled_mfma 16 x 16 x 128 (%arg7 [0 ] * %arg5 ) * (%arg8 [1 ] * %arg5 ) + %arg1 : vector <4 xf8 E8 M0 FNU>, vector <32 xf6 E3 M2 FN>, f8E8M0FNU , vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
9494
9595 // CHECK: llvm.bitcast
9696 // CHECK: llvm.mlir.constant(4 : i32) : i32
9797
9898 // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
99- amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
99+ amdgpu.scaled_mfma 32 x 32 x 64 (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg0 : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
100100 // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
101- amdgpu.scaled_mfma (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
101+ amdgpu.scaled_mfma 16 x 16 x 128 (%arg7 [0 ] * %arg6 ) * (%arg8 [1 ] * %arg6 ) + %arg1 : vector <4 xf8 E8 M0 FNU>, vector <32 xf4 E2 M1 FN>, f8E8M0FNU , vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
102102
103103 func.return
104104}
0 commit comments