@@ -60,42 +60,43 @@ func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
6060
6161 // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
6262 // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
63- // CHECK: llvm.bitcast
63+ // CHECK: %[[c2:.+]] = llvm.bitcast{{.*}} : vector<4xi8> to i32
64+ // CHECK: %[[c3:.+]] = llvm.zext{{.*}} : i8 to i32
6465
65- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
66- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf 8 E 4 M 3 FN >, vector <4 x i8 >, vector <32 xf8 E4 M3 FN>, i8 , vector <16 xf32 >
67- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
68- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * ( %arg8 [ 1 ] * %arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf 8 E 4 M 3 FN >, vector <4 x i8 >, vector <32 xf8 E4 M3 FN>, i8 , vector <4 xf32 >
66+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
67+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * (%arg8 [ 1 ] * %arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf 8 E 4 M 3 FN >, i8 , vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
68+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
69+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg2 ) * (%arg8 [ 1 ] * %arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf 8 E 4 M 3 FN >, i8 , vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
6970
7071 // CHECK: llvm.bitcast
7172
72- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
73- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf 8 E 5 M 2 >, vector <4 x i8 >, vector <32 xf8 E5 M2 >, i8 , vector <16 xf32 >
74- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
75- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * ( %arg8 [ 1 ] * %arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf 8 E 5 M 2 >, vector <4 x i8 >, vector <32 xf8 E5 M2 >, i8 , vector <4 xf32 >
73+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
74+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * (%arg8 [ 1 ] * %arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf 8 E 5 M 2 >, i8 , vector <32 xf8 E5 M2 >, vector <16 xf32 >
75+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
76+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg3 ) * (%arg8 [ 1 ] * %arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf 8 E 5 M 2 >, i8 , vector <32 xf8 E5 M2 >, vector <4 xf32 >
7677
7778 // CHECK: llvm.bitcast
7879
79- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
80- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf 6 E 2 M 3 FN >, vector <4 x i8 >, vector <32 xf6 E2 M3 FN>, i8 , vector <16 xf32 >
81- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
82- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * ( %arg8 [ 1 ] * %arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf 6 E 2 M 3 FN >, vector <4 x i8 >, vector <32 xf6 E2 M3 FN>, i8 , vector <4 xf32 >
80+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
81+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * (%arg8 [ 1 ] * %arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf 6 E 2 M 3 FN >, i8 , vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
82+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
83+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg4 ) * (%arg8 [ 1 ] * %arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf 6 E 2 M 3 FN >, i8 , vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
8384
8485 // CHECK: llvm.bitcast
85- // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
86+ // CHECK: llvm.mlir.constant(3 : i32) : i32
8687
87- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
88- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf 6 E 3 M 2 FN >, vector <4 x i8 >, vector <32 xf6 E3 M2 FN>, i8 , vector <16 xf32 >
89- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
90- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * ( %arg8 [ 1 ] * %arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf 6 E 3 M 2 FN >, vector <4 x i8 >, vector <32 xf6 E3 M2 FN>, i8 , vector <4 xf32 >
88+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
89+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * (%arg8 [ 1 ] * %arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf 6 E 3 M 2 FN >, i8 , vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
90+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
91+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg5 ) * (%arg8 [ 1 ] * %arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf 6 E 3 M 2 FN >, i8 , vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
9192
9293 // CHECK: llvm.bitcast
93- // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
94+ // CHECK: llvm.mlir.constant(4 : i32) : i32
9495
95- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
96- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf 4 E 2 M 1 FN >, vector <4 x i8 >, vector <32 xf4 E2 M1 FN>, i8 , vector <16 xf32 >
97- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
98- amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * ( %arg8 [ 1 ] * %arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf 4 E 2 M 1 FN >, vector <4 x i8 >, vector <32 xf4 E2 M1 FN>, i8 , vector <4 xf32 >
96+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
97+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * (%arg8 [ 1 ] * %arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <4 x i8 >, vector <32 xf 4 E 2 M 1 FN >, i8 , vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
98+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c2]], %[[ c1]], %[[c3]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
99+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * %arg6 ) * (%arg8 [ 1 ] * %arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <4 x i8 >, vector <32 xf 4 E 2 M 1 FN >, i8 , vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
99100
100101 func.return
101102}
0 commit comments