@@ -989,26 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
989989 smfma.getN (), smfma.getK (), 1u , chipset);
990990}
991991
992- // / Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993- // / if one exists. This includes checking to ensure the intrinsic is supported
994- // / on the architecture you are compiling for.
995- static std::optional<StringRef> wmmaOpToIntrinsic (WMMAOp wmma,
996- Chipset chipset) {
997- auto sourceVectorType = cast<VectorType>(wmma.getSourceA ().getType ());
998- auto sourceBVectorType = cast<VectorType>(wmma.getSourceB ().getType ());
999- auto destVectorType = cast<VectorType>(wmma.getDestC ().getType ());
1000- Type elemSourceType = sourceVectorType.getElementType ();
1001- Type elemBSourceType = sourceBVectorType.getElementType ();
1002- Type elemDestType = destVectorType.getElementType ();
1003-
1004- const uint32_t k = wmma.getK ();
1005- const bool isRDNA3 = chipset.majorVersion == 11 ;
1006- const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0 ;
992+ // / Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993+ // / for RDNA3/4 architectures.
994+ static std::optional<StringRef>
995+ wmmaOpToIntrinsicRDNA (Type elemSourceType, Type elemBSourceType,
996+ Type elemDestType, uint32_t k, bool isRDNA3) {
997+ using fp8 = Float8E4M3FNType;
998+ using bf8 = Float8E5M2Type;
1007999
1000+ // Handle k == 16 for RDNA3/4.
10081001 if (k == 16 ) {
1009- if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
1010- return std::nullopt ;
1011-
1002+ // Common patterns for RDNA3 and RDNA4.
10121003 if (elemSourceType.isF16 () && elemDestType.isF32 ())
10131004 return ROCDL::wmma_f32_16x16x16_f16::getOperationName ();
10141005 if (elemSourceType.isBF16 () && elemDestType.isF32 ())
@@ -1019,22 +1010,15 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10191010 return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName ();
10201011 if (elemSourceType.isInteger (8 ) && elemDestType.isInteger (32 ))
10211012 return ROCDL::wmma_i32_16x16x16_iu8::getOperationName ();
1022- if (chipset.majorVersion == 11 ) {
1013+
1014+ // RDNA3 specific patterns.
1015+ if (isRDNA3) {
10231016 if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
10241017 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
1025- }
1026- }
1027- if (isRDNA3)
1028- return std::nullopt ;
1029-
1030- using fp8 = Float8E4M3FNType;
1031- using bf8 = Float8E5M2Type;
1032-
1033- // gfx12+
1034- if (k == 16 ) {
1035- if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
10361018 return std::nullopt ;
1019+ }
10371020
1021+ // RDNA4 specific patterns (fp8/bf8).
10381022 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
10391023 elemDestType.isF32 ())
10401024 return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName ();
@@ -1047,20 +1031,38 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10471031 if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
10481032 elemDestType.isF32 ())
10491033 return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName ();
1050-
10511034 if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
10521035 return ROCDL::wmma_i32_16x16x16_iu4::getOperationName ();
10531036
10541037 return std::nullopt ;
10551038 }
1056- if (k == 32 ) {
1057- if (isRDNA4) {
1058- if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
1059- return ROCDL::wmma_i32_16x16x32_iu4::getOperationName ();
1060- return std::nullopt ;
1061- }
10621039
1063- // gfx1250
1040+ // Handle k == 32 for RDNA4.
1041+ if (k == 32 && !isRDNA3) {
1042+ if (elemSourceType.isInteger (4 ) && elemDestType.isInteger (32 ))
1043+ return ROCDL::wmma_i32_16x16x32_iu4::getOperationName ();
1044+ }
1045+
1046+ llvm_unreachable (" Unsupported k value" );
1047+ }
1048+
1049+ // / Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1050+ // / for the gfx1250 architecture.
1051+ static std::optional<StringRef> wmmaOpToIntrinsicGfx1250 (Type elemSourceType,
1052+ Type elemBSourceType,
1053+ Type elemDestType,
1054+ uint32_t k) {
1055+ using fp8 = Float8E4M3FNType;
1056+ using bf8 = Float8E5M2Type;
1057+
1058+ if (k == 4 ) {
1059+ if (elemSourceType.isF32 () && elemDestType.isF32 ())
1060+ return ROCDL::wmma_f32_16x16x4_f32::getOperationName ();
1061+
1062+ return std::nullopt ;
1063+ }
1064+
1065+ if (k == 32 ) {
10641066 if (elemSourceType.isF16 () && elemDestType.isF32 ())
10651067 return ROCDL::wmma_f32_16x16x32_f16::getOperationName ();
10661068 if (elemSourceType.isBF16 () && elemDestType.isF32 ())
@@ -1073,16 +1075,6 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10731075 return std::nullopt ;
10741076 }
10751077
1076- if (isRDNA4)
1077- return std::nullopt ;
1078-
1079- // gfx1250
1080- if (k == 4 ) {
1081- if (elemSourceType.isF32 () && elemDestType.isF32 ())
1082- return ROCDL::wmma_f32_16x16x4_f32::getOperationName ();
1083- return std::nullopt ;
1084- }
1085-
10861078 if (k == 64 ) {
10871079 if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
10881080 if (elemDestType.isF32 ())
@@ -1142,6 +1134,36 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
11421134
11431135 return std::nullopt ;
11441136 }
1137+
1138+ llvm_unreachable (" Unsupported k value" );
1139+ }
1140+
1141+ // / Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1142+ // / if one exists. This includes checking to ensure the intrinsic is supported
1143+ // / on the architecture you are compiling for.
1144+ static std::optional<StringRef> wmmaOpToIntrinsic (WMMAOp wmma,
1145+ Chipset chipset) {
1146+ auto sourceVectorType = cast<VectorType>(wmma.getSourceA ().getType ());
1147+ auto sourceBVectorType = cast<VectorType>(wmma.getSourceB ().getType ());
1148+ auto destVectorType = cast<VectorType>(wmma.getDestC ().getType ());
1149+ Type elemSourceType = sourceVectorType.getElementType ();
1150+ Type elemBSourceType = sourceBVectorType.getElementType ();
1151+ Type elemDestType = destVectorType.getElementType ();
1152+
1153+ const uint32_t k = wmma.getK ();
1154+ const bool isRDNA3 = chipset.majorVersion == 11 ;
1155+ const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0 ;
1156+
1157+ // Handle RDNA3 and RDNA4.
1158+ if (isRDNA3 || isRDNA4)
1159+ return wmmaOpToIntrinsicRDNA (elemSourceType, elemBSourceType, elemDestType,
1160+ k, isRDNA3);
1161+
1162+ // Handle gfx1250.
1163+ if (chipset == Chipset{12 , 5 , 0 })
1164+ return wmmaOpToIntrinsicGfx1250 (elemSourceType, elemBSourceType,
1165+ elemDestType, k);
1166+
11451167 llvm_unreachable (" unhandled WMMA case" );
11461168}
11471169
0 commit comments