@@ -51,3 +51,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
5151
5252 func.return
5353}
54+
// Lowering of amdgpu.scaled_mfma to rocdl.mfma.scale.f32.*.f8f6f4 for each
// source element type accepted by this function (f8E4M3FN, f8E5M2, f6E2M3FN,
// f6E3M2FN, f4E2M1FN). Every type is exercised twice: once with
// m = n = 32, k = 64 accumulating into vector<16xf32>, and once with
// m = n = 16, k = 128 accumulating into vector<4xf32>.
//
// All ops use scaleA = opselA = 1 and scaleB = opselB = 2, so the last four
// i32 operands expected on every rocdl op are c1, c1, c2, c2. Only the two i32
// operands in front of them vary with the source element type (presumably the
// per-source format selectors; confirm against the ROCDL op definition).
//
// The label below anchors the following checks to this function's output so
// they cannot match IR produced for an earlier function in the file.
// CHECK-LABEL: func @scaled_mfma_to_rocdl
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
                                %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
                                %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
                                %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>) {

  // Constants shared by every expected rocdl op below.
  // CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
  // CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
  // CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32

  // f8E4M3FN sources: expected as vector<8xi32>, leading i32 pair is constant 0.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg2 * %arg2 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg2 * %arg2 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>

  // A bitcast is expected in the output before the next group of ops.
  // CHECK: llvm.bitcast

  // f8E5M2 sources: expected as vector<8xi32>, leading i32 pair is constant 1.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg3 * %arg3 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg3 * %arg3 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>

  // A bitcast is expected in the output before the next group of ops.
  // CHECK: llvm.bitcast

  // f6E2M3FN sources: expected as vector<6xi32>, leading i32 pair is constant 2.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg4 * %arg4 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg4 * %arg4 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>

  // A bitcast and a new constant 3 are expected before the next group of ops.
  // CHECK: llvm.bitcast
  // CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32

  // f6E3M2FN sources: expected as vector<6xi32>, leading i32 pair is constant 3.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg5 * %arg5 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg5 * %arg5 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>

  // A bitcast and a new constant 4 are expected before the next group of ops.
  // CHECK: llvm.bitcast
  // CHECK: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32

  // f4E2M1FN sources: expected as vector<4xi32>, leading i32 pair is constant 4.
  // CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
  amdgpu.scaled_mfma %arg6 * %arg6 + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
  // CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c1]], %[[c1]], %[[c2]], %[[c2]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
  amdgpu.scaled_mfma %arg6 * %arg6 + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32, scaleA = 1 : i32, opselA = 1 : i32, scaleB = 2 : i32, opselB = 2 : i32 } : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>

  func.return
}
0 commit comments