@@ -55,46 +55,47 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5555func.func @scaled_mfma_to_rocdl (%arg0 : vector <16 xf32 >,
5656 %arg1 : vector <4 xf32 >, %arg2 : vector <32 xf8 E4 M3 FN>,
5757 %arg3 : vector <32 xf8 E5 M2 >, %arg4 : vector <32 xf6 E2 M3 FN>,
58- %arg5 : vector <32 xf6 E3 M2 FN>, %arg6 : vector <32 xf4 E2 M1 FN>) {
58+ %arg5 : vector <32 xf6 E3 M2 FN>, %arg6 : vector <32 xf4 E2 M1 FN>,
59+ %arg7 : vector <4 xi8 >, %arg8 : i8 ) {
5960
60- // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
61- // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
6261 // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
62+ // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
63+ // CHECK: llvm.bitcast
6364
64- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
65- amdgpu.scaled_mfma % arg2 * % arg2 + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf8 E4 M3 FN>, vector <32 xf8 E4 M3 FN>, vector <16 xf32 >
66- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
67- amdgpu.scaled_mfma % arg2 * % arg2 + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf8 E4 M3 FN>, vector <32 xf8 E4 M3 FN>, vector <4 xf32 >
65+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
66+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg2 ) * ( %arg8 [ 1 ] * % arg2 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf8 E4 M3 FN>, vector <4 x i8 >, vector < 32 xf8 E4 M3 FN>, i8 , vector <16 xf32 >
67+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
68+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg2 ) * ( %arg8 [ 1 ] * % arg2 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf8 E4 M3 FN>, vector <4 x i8 >, vector < 32 xf8 E4 M3 FN>, i8 , vector <4 xf32 >
6869
6970 // CHECK: llvm.bitcast
7071
71- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1 ]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
72- amdgpu.scaled_mfma % arg3 * % arg3 + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf8 E5 M2 >, vector <32 xf8 E5 M2 >, vector <16 xf32 >
73- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1 ]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
74- amdgpu.scaled_mfma % arg3 * % arg3 + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf8 E5 M2 >, vector <32 xf8 E5 M2 >, vector <4 xf32 >
72+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
73+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg3 ) * ( %arg8 [ 1 ] * % arg3 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf8 E5 M2 >, vector <4 x i8 >, vector < 32 xf8 E5 M2 >, i8 , vector <16 xf32 >
74+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
75+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg3 ) * ( %arg8 [ 1 ] * % arg3 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf8 E5 M2 >, vector <4 x i8 >, vector < 32 xf8 E5 M2 >, i8 , vector <4 xf32 >
7576
7677 // CHECK: llvm.bitcast
7778
78- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2 ]], %[[c2]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
79- amdgpu.scaled_mfma % arg4 * % arg4 + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf6 E2 M3 FN>, vector <32 xf6 E2 M3 FN>, vector <16 xf32 >
80- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2 ]], %[[c2]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
81- amdgpu.scaled_mfma % arg4 * % arg4 + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf6 E2 M3 FN>, vector <32 xf6 E2 M3 FN>, vector <4 xf32 >
79+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
80+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg4 ) * ( %arg8 [ 1 ] * % arg4 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf6 E2 M3 FN>, vector <4 x i8 >, vector < 32 xf6 E2 M3 FN>, i8 , vector <16 xf32 >
81+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
82+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg4 ) * ( %arg8 [ 1 ] * % arg4 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf6 E2 M3 FN>, vector <4 x i8 >, vector < 32 xf6 E2 M3 FN>, i8 , vector <4 xf32 >
8283
8384 // CHECK: llvm.bitcast
8485 // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
8586
86- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3 ]], %[[c3]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
87- amdgpu.scaled_mfma % arg5 * % arg5 + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf6 E3 M2 FN>, vector <32 xf6 E3 M2 FN>, vector <16 xf32 >
88- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3 ]], %[[c3]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
89- amdgpu.scaled_mfma % arg5 * % arg5 + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf6 E3 M2 FN>, vector <32 xf6 E3 M2 FN>, vector <4 xf32 >
87+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
88+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg5 ) * ( %arg8 [ 1 ] * % arg5 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf6 E3 M2 FN>, vector <4 x i8 >, vector < 32 xf6 E3 M2 FN>, i8 , vector <16 xf32 >
89+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
90+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg5 ) * ( %arg8 [ 1 ] * % arg5 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf6 E3 M2 FN>, vector <4 x i8 >, vector < 32 xf6 E3 M2 FN>, i8 , vector <4 xf32 >
9091
9192 // CHECK: llvm.bitcast
9293 // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
9394
94- // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4 ]], %[[c4]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32 , i32, i32) -> vector<16xf32>
95- amdgpu.scaled_mfma % arg6 * % arg6 + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf4 E2 M1 FN>, vector <32 xf4 E2 M1 FN>, vector <16 xf32 >
96- // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4 ]], %[[c4]], %[[ c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32 , i32, i32) -> vector<4xf32>
97- amdgpu.scaled_mfma % arg6 * % arg6 + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 , scaleA = 1 : i32 , opselA = 1 : i32 , scaleB = 2 : i32 , opselB = 2 : i32 } : vector <32 xf4 E2 M1 FN>, vector <32 xf4 E2 M1 FN>, vector <4 xf32 >
95+ // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i8 , i32, i32) -> vector<16xf32>
96+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg6 ) * ( %arg8 [ 1 ] * % arg6 ) + %arg0 { k = 64 : i32 , m = 32 : i32 , n = 32 : i32 } : vector <32 xf4 E2 M1 FN>, vector <4 x i8 >, vector < 32 xf4 E2 M1 FN>, i8 , vector <16 xf32 >
97+ // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0 ]], %[[c1]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i8 , i32, i32) -> vector<4xf32>
98+ amdgpu.scaled_mfma ( %arg7 [ 0 ] * % arg6 ) * ( %arg8 [ 1 ] * % arg6 ) + %arg1 { k = 128 : i32 , m = 16 : i32 , n = 16 : i32 } : vector <32 xf4 E2 M1 FN>, vector <4 x i8 >, vector < 32 xf4 E2 M1 FN>, i8 , vector <4 xf32 >
9899
99100 func.return
100101}
0 commit comments