@@ -989,21 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
989989 smfma.getN (), smfma.getK (), 1u , chipset);
990990}
991991
992- // / Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993- // / if one exists. This includes checking to ensure the intrinsic is supported
994- // / on the architecture you are compiling for.
995- static std::optional<StringRef> wmmaOpToIntrinsic (WMMAOp wmma,
996- Chipset chipset) {
997- auto sourceVectorType = cast<VectorType>(wmma.getSourceA ().getType ());
998- auto sourceBVectorType = cast<VectorType>(wmma.getSourceB ().getType ());
999- auto destVectorType = cast<VectorType>(wmma.getDestC ().getType ());
1000- Type elemSourceType = sourceVectorType.getElementType ();
1001- Type elemBSourceType = sourceBVectorType.getElementType ();
1002- Type elemDestType = destVectorType.getElementType ();
1003-
1004- const uint32_t k = wmma.getK ();
1005-
992+ // / Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993+ // / for RDNA3/4 architectures.
994+ static std::optional<StringRef>
995+ wmmaOpToIntrinsicRDNA (Type elemSourceType, Type elemBSourceType,
996+ Type elemDestType, uint32_t k, bool isRDNA3) {
997+ using fp8 = Float8E4M3FNType;
998+ using bf8 = Float8E5M2Type;
999+
1000+ // Handle k == 16 for RDNA3/4.
10061001 if (k == 16 ) {
1002+ // Common patterns for RDNA3 and RDNA4.
10071003 if (elemSourceType.isF16 () && elemDestType.isF32 ())
10081004 return ROCDL::wmma_f32_16x16x16_f16::getOperationName ();
10091005 if (elemSourceType.isBF16 () && elemDestType.isF32 ())
@@ -1014,39 +1010,160 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10141010 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName ();
10151011 if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 ))
10161012 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName ();
1017- if (chipset.majorVersion == 11 ) {
1013+
1014+ // RDNA3 specific patterns.
1015+ if (isRDNA3) {
10181016 if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
10191017 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
1018+ return std::nullopt ;
10201019 }
1021- }
1022- if (chipset.majorVersion < 12 )
1023- return std::nullopt ;
10241020
1025- // gfx12+
1026- if (k == 16 ) {
1027- if (isa<Float8E4M3FNType>(elemSourceType) &&
1028- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32 ())
1021+ // RDNA4 specific patterns (fp8/bf8).
1022+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
1023+ elemDestType.isF32 ())
10291024 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName ();
1030- if (isa<Float8E4M3FNType >(elemSourceType) &&
1031- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32 ())
1025+ if (isa<fp8 >(elemSourceType) && isa<bf8>(elemBSourceType ) &&
1026+ elemDestType.isF32 ())
10321027 return ROCDL::wmma_f32_16x16x16_fp8_bf8::getOperationName ();
1033- if (isa<Float8E5M2Type >(elemSourceType) &&
1034- isa<Float8E5M2Type>(elemBSourceType) && elemDestType.isF32 ())
1028+ if (isa<bf8 >(elemSourceType) && isa<bf8>(elemBSourceType ) &&
1029+ elemDestType.isF32 ())
10351030 return ROCDL::wmma_f32_16x16x16_bf8_bf8::getOperationName ();
1036- if (isa<Float8E5M2Type >(elemSourceType) &&
1037- isa<Float8E4M3FNType>(elemBSourceType) && elemDestType.isF32 ())
1031+ if (isa<bf8 >(elemSourceType) && isa<fp8>(elemBSourceType ) &&
1032+ elemDestType.isF32 ())
10381033 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName ();
10391034 if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
10401035 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
10411036
10421037 return std::nullopt ;
10431038 }
1044- if (k == 32 ) {
1039+
1040+ // Handle k == 32 for RDNA4.
1041+ if (k == 32 && !isRDNA3) {
10451042 if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
10461043 return ROCDL::wmma_i32_16x16x32_iu4::getOperationName ();
1044+ }
1045+
1046+ llvm_unreachable (" Unsupported k value" );
1047+ }
1048+
1049+ // / Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1050+ // / for the gfx1250 architecture.
1051+ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250 (Type elemSourceType,
1052+ Type elemBSourceType,
1053+ Type elemDestType,
1054+ uint32_t k) {
1055+ using fp8 = Float8E4M3FNType;
1056+ using bf8 = Float8E5M2Type;
1057+
1058+ if (k == 4 ) {
1059+ if (elemSourceType.isF32 () && elemDestType.isF32 ())
1060+ return ROCDL::wmma_f32_16x16x4_f32::getOperationName ();
1061+
10471062 return std::nullopt ;
10481063 }
10491064
1065+ if (k == 32 ) {
1066+ if (elemSourceType.isF16 () && elemDestType.isF32 ())
1067+ return ROCDL::wmma_f32_16x16x32_f16::getOperationName ();
1068+ if (elemSourceType.isBF16 () && elemDestType.isF32 ())
1069+ return ROCDL::wmma_f32_16x16x32_bf16::getOperationName ();
1070+ if (elemSourceType.isF16 () && elemDestType.isF16 ())
1071+ return ROCDL::wmma_f16_16x16x32_f16::getOperationName ();
1072+ if (elemSourceType.isBF16 () && elemDestType.isBF16 ())
1073+ return ROCDL::wmma_bf16_16x16x32_bf16::getOperationName ();
1074+
1075+ return std::nullopt ;
1076+ }
1077+
1078+ if (k == 64 ) {
1079+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1080+ if (elemDestType.isF32 ())
1081+ return ROCDL::wmma_f32_16x16x64_fp8_fp8::getOperationName ();
1082+ if (elemDestType.isF16 ())
1083+ return ROCDL::wmma_f16_16x16x64_fp8_fp8::getOperationName ();
1084+ }
1085+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1086+ if (elemDestType.isF32 ())
1087+ return ROCDL::wmma_f32_16x16x64_fp8_bf8::getOperationName ();
1088+ if (elemDestType.isF16 ())
1089+ return ROCDL::wmma_f16_16x16x64_fp8_bf8::getOperationName ();
1090+ }
1091+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1092+ if (elemDestType.isF32 ())
1093+ return ROCDL::wmma_f32_16x16x64_bf8_bf8::getOperationName ();
1094+ if (elemDestType.isF16 ())
1095+ return ROCDL::wmma_f16_16x16x64_bf8_bf8::getOperationName ();
1096+ }
1097+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1098+ if (elemDestType.isF32 ())
1099+ return ROCDL::wmma_f32_16x16x64_bf8_fp8::getOperationName ();
1100+ if (elemDestType.isF16 ())
1101+ return ROCDL::wmma_f16_16x16x64_bf8_fp8::getOperationName ();
1102+ }
1103+ if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 ))
1104+ return ROCDL::wmma_i32_16x16x64_iu8::getOperationName ();
1105+
1106+ return std::nullopt ;
1107+ }
1108+
1109+ if (k == 128 ) {
1110+ if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1111+ if (elemDestType.isF32 ())
1112+ return ROCDL::wmma_f32_16x16x128_fp8_fp8::getOperationName ();
1113+ if (elemDestType.isF16 ())
1114+ return ROCDL::wmma_f16_16x16x128_fp8_fp8::getOperationName ();
1115+ }
1116+ if (isa<fp8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1117+ if (elemDestType.isF32 ())
1118+ return ROCDL::wmma_f32_16x16x128_fp8_bf8::getOperationName ();
1119+ if (elemDestType.isF16 ())
1120+ return ROCDL::wmma_f16_16x16x128_fp8_bf8::getOperationName ();
1121+ }
1122+ if (isa<bf8>(elemSourceType) && isa<bf8>(elemBSourceType)) {
1123+ if (elemDestType.isF32 ())
1124+ return ROCDL::wmma_f32_16x16x128_bf8_bf8::getOperationName ();
1125+ if (elemDestType.isF16 ())
1126+ return ROCDL::wmma_f16_16x16x128_bf8_bf8::getOperationName ();
1127+ }
1128+ if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
1129+ if (elemDestType.isF32 ())
1130+ return ROCDL::wmma_f32_16x16x128_bf8_fp8::getOperationName ();
1131+ if (elemDestType.isF16 ())
1132+ return ROCDL::wmma_f16_16x16x128_bf8_fp8::getOperationName ();
1133+ }
1134+
1135+ return std::nullopt ;
1136+ }
1137+
1138+ llvm_unreachable (" Unsupported k value" );
1139+ }
1140+
1141+ // / Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1142+ // / if one exists. This includes checking to ensure the intrinsic is supported
1143+ // / on the architecture you are compiling for.
1144+ static std::optional<StringRef> wmmaOpToIntrinsic (WMMAOp wmma,
1145+ Chipset chipset) {
1146+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA ().getType ());
1147+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB ().getType ());
1148+ auto destVectorType = cast<VectorType>(wmma.getDestC ().getType ());
1149+ Type elemSourceType = sourceVectorType.getElementType ();
1150+ Type elemBSourceType = sourceBVectorType.getElementType ();
1151+ Type elemDestType = destVectorType.getElementType ();
1152+
1153+ const uint32_t k = wmma.getK ();
1154+ const bool isRDNA3 = chipset.majorVersion == 11 ;
1155+ const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0 ;
1156+
1157+ // Handle RDNA3 and RDNA4.
1158+ if (isRDNA3 || isRDNA4)
1159+ return wmmaOpToIntrinsicRDNA (elemSourceType, elemBSourceType, elemDestType,
1160+ k, isRDNA3);
1161+
1162+ // Handle gfx1250.
1163+ if (chipset == Chipset{12 , 5 , 0 })
1164+ return wmmaOpToIntrinsicGfx1250 (elemSourceType, elemBSourceType,
1165+ elemDestType, k);
1166+
10501167 llvm_unreachable (" unhandled WMMA case" );
10511168}
10521169
0 commit comments