@@ -16936,18 +16936,9 @@ struct NodeExtensionHelper {
1693616936 case RISCVISD::VWSUBU_W_VL:
1693716937 case RISCVISD::VFWADD_W_VL:
1693816938 case RISCVISD::VFWSUB_W_VL:
16939- if (OperandIdx == 1) {
16940- SupportsZExt =
16941- Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
16942- SupportsSExt =
16943- Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
16944- SupportsFPExt =
16945- Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
16946- // There's no existing extension here, so we don't have to worry about
16947- // making sure it gets removed.
16948- EnforceOneUse = false;
16939+ // Operand 1 can't be changed.
16940+ if (OperandIdx == 1)
1694916941 break;
16950- }
1695116942 [[fallthrough]];
1695216943 default:
1695316944 fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -16985,20 +16976,20 @@ struct NodeExtensionHelper {
1698516976 case RISCVISD::ADD_VL:
1698616977 case RISCVISD::MUL_VL:
1698716978 case RISCVISD::OR_VL:
16988- case RISCVISD::VWADD_W_VL:
16989- case RISCVISD::VWADDU_W_VL:
1699016979 case RISCVISD::FADD_VL:
1699116980 case RISCVISD::FMUL_VL:
16992- case RISCVISD::VFWADD_W_VL:
1699316981 case RISCVISD::VFMADD_VL:
1699416982 case RISCVISD::VFNMSUB_VL:
1699516983 case RISCVISD::VFNMADD_VL:
1699616984 case RISCVISD::VFMSUB_VL:
1699716985 return true;
16986+ case RISCVISD::VWADD_W_VL:
16987+ case RISCVISD::VWADDU_W_VL:
1699816988 case ISD::SUB:
1699916989 case RISCVISD::SUB_VL:
1700016990 case RISCVISD::VWSUB_W_VL:
1700116991 case RISCVISD::VWSUBU_W_VL:
16992+ case RISCVISD::VFWADD_W_VL:
1700216993 case RISCVISD::FSUB_VL:
1700316994 case RISCVISD::VFWSUB_W_VL:
1700416995 case ISD::SHL:
@@ -17117,6 +17108,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1711717108 Subtarget);
1711817109}
1711917110
17111+ /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17112+ ///
17113+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17114+ /// can be used to apply the pattern.
17115+ static std::optional<CombineResult>
17116+ canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17117+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17118+ const RISCVSubtarget &Subtarget) {
17119+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17120+ Subtarget);
17121+ }
17122+
17123+ /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17124+ ///
17125+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17126+ /// can be used to apply the pattern.
17127+ static std::optional<CombineResult>
17128+ canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
17129+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17130+ const RISCVSubtarget &Subtarget) {
17131+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17132+ Subtarget);
17133+ }
17134+
1712017135/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
1712117136///
1712217137/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17145,52 +17160,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
1714517160 return std::nullopt;
1714617161}
1714717162
17148- /// Check if \p Root follows a pattern Root(sext(LHS), sext( RHS) )
17163+ /// Check if \p Root follows a pattern Root(sext(LHS), RHS)
1714917164///
1715017165/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1715117166/// can be used to apply the pattern.
1715217167static std::optional<CombineResult>
1715317168canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1715417169 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1715517170 const RISCVSubtarget &Subtarget) {
17156- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
17157- Subtarget);
17171+ if (LHS.SupportsSExt)
17172+ return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
17173+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
17174+ /*RHSExt=*/std::nullopt);
17175+ return std::nullopt;
1715817176}
1715917177
17160- /// Check if \p Root follows a pattern Root(zext(LHS), zext( RHS) )
17178+ /// Check if \p Root follows a pattern Root(zext(LHS), RHS)
1716117179///
1716217180/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1716317181/// can be used to apply the pattern.
1716417182static std::optional<CombineResult>
1716517183canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1716617184 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1716717185 const RISCVSubtarget &Subtarget) {
17168- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17169- Subtarget);
17186+ if (LHS.SupportsZExt)
17187+ return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
17188+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
17189+ /*RHSExt=*/std::nullopt);
17190+ return std::nullopt;
1717017191}
1717117192
17172- /// Check if \p Root follows a pattern Root(fpext(LHS), fpext( RHS) )
17193+ /// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
1717317194///
1717417195/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1717517196/// can be used to apply the pattern.
1717617197static std::optional<CombineResult>
1717717198canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1717817199 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1717917200 const RISCVSubtarget &Subtarget) {
17180- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
17181- Subtarget);
17182- }
17183-
17184- /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17185- ///
17186- /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17187- /// can be used to apply the pattern.
17188- static std::optional<CombineResult>
17189- canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17190- const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17191- const RISCVSubtarget &Subtarget) {
17192- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17193- Subtarget);
17201+ if (LHS.SupportsFPExt)
17202+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
17203+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
17204+ /*RHSExt=*/std::nullopt);
17205+ return std::nullopt;
1719417206}
1719517207
1719617208/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17233,7 +17245,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1723317245 case RISCVISD::VFNMSUB_VL:
1723417246 Strategies.push_back(canFoldToVWWithSameExtension);
1723517247 if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17236- Strategies.push_back(canFoldToVWWithBF16EXT );
17248+ Strategies.push_back(canFoldToVWWithSameExtBF16 );
1723717249 break;
1723817250 case ISD::MUL:
1723917251 case RISCVISD::MUL_VL:
@@ -17245,7 +17257,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1724517257 case ISD::SHL:
1724617258 case RISCVISD::SHL_VL:
1724717259 // shl -> vwsll
17248- Strategies.push_back(canFoldToVWWithZEXT );
17260+ Strategies.push_back(canFoldToVWWithSameExtZEXT );
1724917261 break;
1725017262 case RISCVISD::VWADD_W_VL:
1725117263 case RISCVISD::VWSUB_W_VL:
0 commit comments