@@ -17,8 +17,8 @@ using namespace mlir;
1717using namespace mlir ::rock;
1818
1919// The static initialization will follow the defined ordering
20- // of the below lambdas
21- auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
20+ // of the below lambda
21+ static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
2222 static llvm::StringMap<MfmaInsnInfo> insnInfo{
2323 // fp32
2424 {ROCDL::mfma_f32_32x32x1f32::getOperationName (),
@@ -37,8 +37,12 @@ auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
3737 {MfmaTypeId::Fp16TyId, 32 , 4 , 2 }},
3838 {ROCDL::mfma_f32_32x32x8f16::getOperationName (),
3939 {MfmaTypeId::Fp16TyId, 32 , 8 , 1 }},
40+ {ROCDL::mfma_f32_32x32x16_f16::getOperationName (),
41+ {MfmaTypeId::Fp16TyId, 32 , 16 , 1 }},
4042 {ROCDL::mfma_f32_16x16x4f16::getOperationName (),
4143 {MfmaTypeId::Fp16TyId, 16 , 4 , 4 }},
44+ {ROCDL::mfma_f32_16x16x32_f16::getOperationName (),
45+ {MfmaTypeId::Fp16TyId, 16 , 32 , 1 }},
4246 {ROCDL::mfma_f32_16x16x16f16::getOperationName (),
4347 {MfmaTypeId::Fp16TyId, 16 , 16 , 1 }},
4448 {ROCDL::mfma_f32_4x4x4f16::getOperationName (),
@@ -47,10 +51,14 @@ auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
4751 // bf16
4852 {ROCDL::mfma_f32_32x32x2bf16::getOperationName (),
4953 {MfmaTypeId::Bf16TyId, 32 , 2 , 2 }},
54+ {ROCDL::mfma_f32_32x32x16_bf16::getOperationName (),
55+ {MfmaTypeId::Bf16TyId, 32 , 16 , 1 }},
5056 {ROCDL::mfma_f32_32x32x4bf16::getOperationName (),
5157 {MfmaTypeId::Bf16TyId, 32 , 4 , 1 }},
5258 {ROCDL::mfma_f32_16x16x2bf16::getOperationName (),
5359 {MfmaTypeId::Bf16TyId, 16 , 2 , 4 }},
60+ {ROCDL::mfma_f32_16x16x32_bf16::getOperationName (),
61+ {MfmaTypeId::Bf16TyId, 16 , 32 , 1 }},
5462 {ROCDL::mfma_f32_16x16x8bf16::getOperationName (),
5563 {MfmaTypeId::Bf16TyId, 16 , 8 , 1 }},
5664 {ROCDL::mfma_f32_4x4x2bf16::getOperationName (),
@@ -77,8 +85,12 @@ auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
7785 // i8 (new)
7886 {ROCDL::mfma_i32_32x32x16_i8::getOperationName (),
7987 {MfmaTypeId::I8TyId, 32 , 16 , 1 }},
88+ {ROCDL::mfma_i32_32x32x32_i8::getOperationName (),
89+ {MfmaTypeId::I8TyId, 32 , 32 , 1 }},
8090 {ROCDL::mfma_i32_16x16x32_i8::getOperationName (),
8191 {MfmaTypeId::I8TyId, 16 , 32 , 1 }},
92+ {ROCDL::mfma_i32_16x16x64_i8::getOperationName (),
93+ {MfmaTypeId::I8TyId, 16 , 64 , 1 }},
8294
8395 // fp8
8496 {ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName (),
@@ -178,7 +190,7 @@ static MfmaInsnAttr deriveAttr(MfmaInsnInfo info) {
178190 isKReduction};
179191}
180192
181- auto getMfmaInsnAttrMap = []() -> const llvm::StringMap<MfmaInsnAttr> & {
193+ static auto getMfmaInsnAttrMap = []() -> const llvm::StringMap<MfmaInsnAttr> & {
182194 static llvm::StringMap<MfmaInsnAttr> insnDb;
183195 static std::once_flag once;
184196 std::call_once (once, [&]() {
@@ -194,7 +206,7 @@ auto getMfmaInsnAttrMap = []() -> const llvm::StringMap<MfmaInsnAttr> & {
// Table type shared by the instruction-group lambdas in this file: an
// llvm::DenseMap from MfmaInsnGroupSelectKey to MfmaInsnGroupAttr, with
// MfmaInsnGroupSelectKeyInfo supplied as the DenseMap key-info argument.
194206using MfmaInsnGroupMap =
195207 llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnGroupAttr,
196208 MfmaInsnGroupSelectKeyInfo>;
197- auto getMfmaInsnGroupAttrMapAllArch = []() -> const MfmaInsnGroupMap & {
209+ static auto getMfmaInsnGroupAttrMapAllArch = []() -> const MfmaInsnGroupMap & {
198210 using amdgpu::MFMAPermB;
199211 static MfmaInsnGroupMap
200212 // f32
@@ -242,7 +254,8 @@ auto getMfmaInsnGroupAttrMapAllArch = []() -> const MfmaInsnGroupMap & {
242254 return groupAttrMap;
243255};
244256
245- auto getMfmaInsnGroupAttrMapGfx908Bf16 = []() -> const MfmaInsnGroupMap & {
257+ static auto getMfmaInsnGroupAttrMapGfx908Bf16 =
258+ []() -> const MfmaInsnGroupMap & {
246259 using amdgpu::MFMAPermB;
247260 static MfmaInsnGroupMap
248261 // bf16
@@ -269,7 +282,7 @@ auto getMfmaInsnGroupAttrMapGfx908Bf16 = []() -> const MfmaInsnGroupMap & {
269282 return groupAttrMap;
270283};
271284
272- auto getMfmaInsnGroupAttrMapGfx90aPlusBf16 = []() {
285+ static auto getMfmaInsnGroupAttrMapGfx90aPlusBf16 = []() {
273286 using amdgpu::MFMAPermB;
274287 static llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnGroupAttr,
275288 MfmaInsnGroupSelectKeyInfo>
@@ -297,7 +310,7 @@ auto getMfmaInsnGroupAttrMapGfx90aPlusBf16 = []() {
297310 return groupAttrMap;
298311};
299312
300- auto getMfmaInsnGroupAttrMapPreGfx942Int8 = []() {
313+ static auto getMfmaInsnGroupAttrMapPreGfx942Int8 = []() {
301314 using amdgpu::MFMAPermB;
302315 static llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnGroupAttr,
303316 MfmaInsnGroupSelectKeyInfo>
@@ -321,7 +334,7 @@ auto getMfmaInsnGroupAttrMapPreGfx942Int8 = []() {
321334};
322335
323336// New I8 and all Float8
324- auto getMfmaInsnGroupAttrMapGfx942Plus = []() {
337+ static auto getMfmaInsnGroupAttrMapGfx942 = []() {
325338 using amdgpu::MFMAPermB;
326339 static MfmaInsnGroupMap
327340 // Int8
@@ -407,6 +420,28 @@ auto getMfmaInsnGroupAttrMapGfx942Plus = []() {
407420 return groupAttrMap;
408421};
409422
423+ static auto getMfmaInsnGroupAttrMapGfx950 = []() {
424+ static MfmaInsnGroupMap groupAttrMap{
425+ // fp16 double rate
426+ {{MfmaTypeId::Fp16TyId, 16 , 16 },
427+ {ROCDL::mfma_f32_16x16x32_f16::getOperationName ()}},
428+ {{MfmaTypeId::Fp16TyId, 32 , 32 },
429+ {ROCDL::mfma_f32_32x32x16_f16::getOperationName ()}},
430+ // bfp16 double rate
431+ {{MfmaTypeId::Bf16TyId, 16 , 16 },
432+ {ROCDL::mfma_f32_16x16x32_bf16::getOperationName ()}},
433+ {{MfmaTypeId::Bf16TyId, 32 , 32 },
434+ {ROCDL::mfma_f32_32x32x16_bf16::getOperationName ()}},
435+ // i8 double rate
436+ {{MfmaTypeId::I8TyId, 16 , 16 },
437+ {ROCDL::mfma_i32_16x16x64_i8::getOperationName ()}},
438+ {{MfmaTypeId::I8TyId, 32 , 32 },
439+ {ROCDL::mfma_i32_32x32x32_i8::getOperationName ()}}
440+
441+ };
442+ return groupAttrMap;
443+ };
444+
410445FailureOr<MfmaInsn> MfmaInsn::select (StringRef mfmaInsn) {
411446 auto mfmaInsnAttrMap = getMfmaInsnAttrMap ();
412447 auto it = mfmaInsnAttrMap.find (mfmaInsn);
@@ -546,13 +581,35 @@ FailureOr<MfmaInsnGroup> MfmaInsnGroup::select(Type elementTypeA,
546581 result = MfmaInsnGroup (elementTypeA, elementTypeB, *maybeInsn, groupAttr);
547582 }
548583 };
549- bool hasOldBf16 = arch.contains (" gfx908" );
550- bool isPreGfx942 = arch.contains (" gfx908" ) || arch.contains (" gfx90a" );
551- if (elementTypeA.isBF16 ())
552- selectFrom (hasOldBf16 ? getMfmaInsnGroupAttrMapGfx908Bf16 ()
553- : getMfmaInsnGroupAttrMapGfx90aPlusBf16 ());
554- selectFrom (isPreGfx942 ? getMfmaInsnGroupAttrMapPreGfx942Int8 ()
555- : getMfmaInsnGroupAttrMapGfx942Plus ());
584+ bool isGfx908 = arch.contains (" gfx908" );
585+ bool isGfx90a = arch.contains (" gfx908" ) || arch.contains (" gfx90a" );
586+ bool isGfx94x = arch.contains (" gfx942" ) || arch.contains (" gfx940" );
587+ bool isGfx95x = arch.contains (" gfx950" );
588+ // TODO: refactor this later to not keep multiple maps for different arches
589+ if (elementTypeA.isBF16 ()) {
590+ if (isGfx908) {
591+ selectFrom (getMfmaInsnGroupAttrMapGfx908Bf16 ());
592+ } else if (isGfx94x || isGfx90a) {
593+ selectFrom (getMfmaInsnGroupAttrMapGfx90aPlusBf16 ());
594+ } else {
595+ // gfx950 has double rate instructions. Select from those first.
596+ selectFrom (getMfmaInsnGroupAttrMapGfx950 ());
597+ selectFrom (getMfmaInsnGroupAttrMapGfx90aPlusBf16 ());
598+ }
599+ }
600+
601+ if (isGfx908 || isGfx90a) {
602+ selectFrom (getMfmaInsnGroupAttrMapPreGfx942Int8 ());
603+ } else if (isGfx94x) {
604+ selectFrom (getMfmaInsnGroupAttrMapGfx942 ());
605+ } else if (isGfx95x) {
606+ // select from new double rate instructions first
607+ selectFrom (getMfmaInsnGroupAttrMapGfx950 ());
608+ // all previous instructions are still valid for gfx950
609+ selectFrom (getMfmaInsnGroupAttrMapGfx942 ());
610+ }
611+ // select from all available instructions on all architectures if one has not
612+ // been selected yet
556613 selectFrom (getMfmaInsnGroupAttrMapAllArch ());
557614 if (failed (result)) {
558615 LLVM_DEBUG (llvm::dbgs () << " No match found in MFMA database\n " );
0 commit comments