@@ -16946,7 +16946,7 @@ canFoldToVWWithSameExtensionImpl(SDNode *Root, const NodeExtensionHelper &LHS,
1694616946 Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
1694716947 /*RHSExt=*/{ExtKind::FPExt});
1694816948 if ((AllowExtMask & ExtKind::BF16Ext) && LHS.SupportsBF16Ext &&
16949- RHS.SupportsBF16Ext && Root->getOpcode() == RISCVISD::VFMADD_VL )
16949+ RHS.SupportsBF16Ext)
1695016950 return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
1695116951 Root, LHS, /*LHSExt=*/{ExtKind::BF16Ext}, RHS,
1695216952 /*RHSExt=*/{ExtKind::BF16Ext});
@@ -16963,10 +16963,9 @@ static std::optional<CombineResult>
1696316963canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1696416964 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1696516965 const RISCVSubtarget &Subtarget) {
16966- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS,
16967- ExtKind::ZExt | ExtKind::SExt |
16968- ExtKind::FPExt | ExtKind::BF16Ext,
16969- DAG, Subtarget);
16966+ return canFoldToVWWithSameExtensionImpl(
16967+ Root, LHS, RHS, ExtKind::ZExt | ExtKind::SExt | ExtKind::FPExt, DAG,
16968+ Subtarget);
1697016969}
1697116970
1697216971/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
@@ -17033,6 +17032,18 @@ canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1703317032 Subtarget);
1703417033}
1703517034
17035+ /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17036+ ///
17037+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17038+ /// can be used to apply the pattern.
17039+ static std::optional<CombineResult>
17040+ canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17041+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17042+ const RISCVSubtarget &Subtarget) {
17043+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17044+ Subtarget);
17045+ }
17046+
1703617047/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
1703717048///
1703817049/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17072,6 +17083,8 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1707217083 case RISCVISD::VFNMADD_VL:
1707317084 case RISCVISD::VFNMSUB_VL:
1707417085 Strategies.push_back(canFoldToVWWithSameExtension);
17086+ if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17087+ Strategies.push_back(canFoldToVWWithBF16EXT);
1707517088 break;
1707617089 case ISD::MUL:
1707717090 case RISCVISD::MUL_VL:
0 commit comments