@@ -17325,18 +17325,9 @@ struct NodeExtensionHelper {
1732517325 case RISCVISD::VWSUBU_W_VL:
1732617326 case RISCVISD::VFWADD_W_VL:
1732717327 case RISCVISD::VFWSUB_W_VL:
17328- if (OperandIdx == 1) {
17329- SupportsZExt =
17330- Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
17331- SupportsSExt =
17332- Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
17333- SupportsFPExt =
17334- Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
17335- // There's no existing extension here, so we don't have to worry about
17336- // making sure it gets removed.
17337- EnforceOneUse = false;
17328+ // Operand 1 can't be changed.
17329+ if (OperandIdx == 1)
1733817330 break;
17339- }
1734017331 [[fallthrough]];
1734117332 default:
1734217333 fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -17374,20 +17365,20 @@ struct NodeExtensionHelper {
1737417365 case RISCVISD::ADD_VL:
1737517366 case RISCVISD::MUL_VL:
1737617367 case RISCVISD::OR_VL:
17377- case RISCVISD::VWADD_W_VL:
17378- case RISCVISD::VWADDU_W_VL:
1737917368 case RISCVISD::FADD_VL:
1738017369 case RISCVISD::FMUL_VL:
17381- case RISCVISD::VFWADD_W_VL:
1738217370 case RISCVISD::VFMADD_VL:
1738317371 case RISCVISD::VFNMSUB_VL:
1738417372 case RISCVISD::VFNMADD_VL:
1738517373 case RISCVISD::VFMSUB_VL:
1738617374 return true;
17375+ case RISCVISD::VWADD_W_VL:
17376+ case RISCVISD::VWADDU_W_VL:
1738717377 case ISD::SUB:
1738817378 case RISCVISD::SUB_VL:
1738917379 case RISCVISD::VWSUB_W_VL:
1739017380 case RISCVISD::VWSUBU_W_VL:
17381+ case RISCVISD::VFWADD_W_VL:
1739117382 case RISCVISD::FSUB_VL:
1739217383 case RISCVISD::VFWSUB_W_VL:
1739317384 case ISD::SHL:
@@ -17506,6 +17497,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1750617497 Subtarget);
1750717498}
1750817499
17500+ /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17501+ ///
17502+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17503+ /// can be used to apply the pattern.
17504+ static std::optional<CombineResult>
17505+ canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17506+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17507+ const RISCVSubtarget &Subtarget) {
17508+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17509+ Subtarget);
17510+ }
17511+
17512+ /// Check if \p Root follows a pattern Root(ext(LHS), zext(RHS))
17513+ ///
17514+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17515+ /// can be used to apply the pattern.
17516+ static std::optional<CombineResult>
17517+ canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
17518+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17519+ const RISCVSubtarget &Subtarget) {
17520+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17521+ Subtarget);
17522+ }
17523+
1750917524/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
1751017525///
1751117526/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17534,52 +17549,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
1753417549 return std::nullopt;
1753517550}
1753617551
17537- /// Check if \p Root follows a pattern Root(sext(LHS), sext( RHS) )
17552+ /// Check if \p Root follows a pattern Root(sext(LHS), RHS)
1753817553///
1753917554/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1754017555/// can be used to apply the pattern.
1754117556static std::optional<CombineResult>
1754217557canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1754317558 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1754417559 const RISCVSubtarget &Subtarget) {
17545- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
17546- Subtarget);
17560+ if (LHS.SupportsSExt)
17561+ return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
17562+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
17563+ /*RHSExt=*/std::nullopt);
17564+ return std::nullopt;
1754717565}
1754817566
17549- /// Check if \p Root follows a pattern Root(zext(LHS), zext( RHS) )
17567+ /// Check if \p Root follows a pattern Root(zext(LHS), RHS)
1755017568///
1755117569/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1755217570/// can be used to apply the pattern.
1755317571static std::optional<CombineResult>
1755417572canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1755517573 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1755617574 const RISCVSubtarget &Subtarget) {
17557- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17558- Subtarget);
17575+ if (LHS.SupportsZExt)
17576+ return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
17577+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
17578+ /*RHSExt=*/std::nullopt);
17579+ return std::nullopt;
1755917580}
1756017581
17561- /// Check if \p Root follows a pattern Root(fpext(LHS), fpext( RHS) )
17582+ /// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
1756217583///
1756317584/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1756417585/// can be used to apply the pattern.
1756517586static std::optional<CombineResult>
1756617587canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1756717588 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1756817589 const RISCVSubtarget &Subtarget) {
17569- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
17570- Subtarget);
17571- }
17572-
17573- /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17574- ///
17575- /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17576- /// can be used to apply the pattern.
17577- static std::optional<CombineResult>
17578- canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17579- const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17580- const RISCVSubtarget &Subtarget) {
17581- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17582- Subtarget);
17590+ if (LHS.SupportsFPExt)
17591+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
17592+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
17593+ /*RHSExt=*/std::nullopt);
17594+ return std::nullopt;
1758317595}
1758417596
1758517597/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17622,7 +17634,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1762217634 case RISCVISD::VFNMSUB_VL:
1762317635 Strategies.push_back(canFoldToVWWithSameExtension);
1762417636 if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17625- Strategies.push_back(canFoldToVWWithBF16EXT );
17637+ Strategies.push_back(canFoldToVWWithSameExtBF16 );
1762617638 break;
1762717639 case ISD::MUL:
1762817640 case RISCVISD::MUL_VL:
@@ -17634,7 +17646,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1763417646 case ISD::SHL:
1763517647 case RISCVISD::SHL_VL:
1763617648 // shl -> vwsll
17637- Strategies.push_back(canFoldToVWWithZEXT );
17649+ Strategies.push_back(canFoldToVWWithSameExtZEXT );
1763817650 break;
1763917651 case RISCVISD::VWADD_W_VL:
1764017652 case RISCVISD::VWSUB_W_VL:
0 commit comments