@@ -17365,18 +17365,9 @@ struct NodeExtensionHelper {
1736517365 case RISCVISD::VWSUBU_W_VL:
1736617366 case RISCVISD::VFWADD_W_VL:
1736717367 case RISCVISD::VFWSUB_W_VL:
17368- if (OperandIdx == 1) {
17369- SupportsZExt =
17370- Opc == RISCVISD::VWADDU_W_VL || Opc == RISCVISD::VWSUBU_W_VL;
17371- SupportsSExt =
17372- Opc == RISCVISD::VWADD_W_VL || Opc == RISCVISD::VWSUB_W_VL;
17373- SupportsFPExt =
17374- Opc == RISCVISD::VFWADD_W_VL || Opc == RISCVISD::VFWSUB_W_VL;
17375- // There's no existing extension here, so we don't have to worry about
17376- // making sure it gets removed.
17377- EnforceOneUse = false;
17368+ // Operand 1 can't be changed.
17369+ if (OperandIdx == 1)
1737817370 break;
17379- }
1738017371 [[fallthrough]];
1738117372 default:
1738217373 fillUpExtensionSupport(Root, DAG, Subtarget);
@@ -17414,20 +17405,20 @@ struct NodeExtensionHelper {
1741417405 case RISCVISD::ADD_VL:
1741517406 case RISCVISD::MUL_VL:
1741617407 case RISCVISD::OR_VL:
17417- case RISCVISD::VWADD_W_VL:
17418- case RISCVISD::VWADDU_W_VL:
1741917408 case RISCVISD::FADD_VL:
1742017409 case RISCVISD::FMUL_VL:
17421- case RISCVISD::VFWADD_W_VL:
1742217410 case RISCVISD::VFMADD_VL:
1742317411 case RISCVISD::VFNMSUB_VL:
1742417412 case RISCVISD::VFNMADD_VL:
1742517413 case RISCVISD::VFMSUB_VL:
1742617414 return true;
17415+ case RISCVISD::VWADD_W_VL:
17416+ case RISCVISD::VWADDU_W_VL:
1742717417 case ISD::SUB:
1742817418 case RISCVISD::SUB_VL:
1742917419 case RISCVISD::VWSUB_W_VL:
1743017420 case RISCVISD::VWSUBU_W_VL:
17421+ case RISCVISD::VFWADD_W_VL:
1743117422 case RISCVISD::FSUB_VL:
1743217423 case RISCVISD::VFWSUB_W_VL:
1743317424 case ISD::SHL:
@@ -17546,6 +17537,30 @@ canFoldToVWWithSameExtension(SDNode *Root, const NodeExtensionHelper &LHS,
1754617537 Subtarget);
1754717538}
1754817539
17540+ /// Check if \p Root follows a pattern Root(zext(LHS), zext(RHS))
17541+ ///
17542+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17543+ /// can be used to apply the pattern.
17544+ static std::optional<CombineResult>
17545+ canFoldToVWWithSameExtZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
17546+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17547+ const RISCVSubtarget &Subtarget) {
17548+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17549+ Subtarget);
17550+ }
17551+
17552+ /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17553+ ///
17554+ /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17555+ /// can be used to apply the pattern.
17556+ static std::optional<CombineResult>
17557+ canFoldToVWWithSameExtBF16(SDNode *Root, const NodeExtensionHelper &LHS,
17558+ const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17559+ const RISCVSubtarget &Subtarget) {
17560+ return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17561+ Subtarget);
17562+ }
17563+
1754917564/// Check if \p Root follows a pattern Root(LHS, ext(RHS))
1755017565///
1755117566/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
@@ -17574,52 +17589,49 @@ canFoldToVW_W(SDNode *Root, const NodeExtensionHelper &LHS,
1757417589 return std::nullopt;
1757517590}
1757617591
17577- /// Check if \p Root follows a pattern Root(sext(LHS), sext( RHS) )
17592+ /// Check if \p Root follows a pattern Root(sext(LHS), RHS)
1757817593///
1757917594/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1758017595/// can be used to apply the pattern.
1758117596static std::optional<CombineResult>
1758217597canFoldToVWWithSEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1758317598 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1758417599 const RISCVSubtarget &Subtarget) {
17585- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::SExt, DAG,
17586- Subtarget);
17600+ if (LHS.SupportsSExt)
17601+ return CombineResult(NodeExtensionHelper::getSExtOpcode(Root->getOpcode()),
17602+ Root, LHS, /*LHSExt=*/{ExtKind::SExt}, RHS,
17603+ /*RHSExt=*/std::nullopt);
17604+ return std::nullopt;
1758717605}
1758817606
17589- /// Check if \p Root follows a pattern Root(zext(LHS), zext( RHS) )
17607+ /// Check if \p Root follows a pattern Root(zext(LHS), RHS)
1759017608///
1759117609/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1759217610/// can be used to apply the pattern.
1759317611static std::optional<CombineResult>
1759417612canFoldToVWWithZEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1759517613 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1759617614 const RISCVSubtarget &Subtarget) {
17597- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::ZExt, DAG,
17598- Subtarget);
17615+ if (LHS.SupportsZExt)
17616+ return CombineResult(NodeExtensionHelper::getZExtOpcode(Root->getOpcode()),
17617+ Root, LHS, /*LHSExt=*/{ExtKind::ZExt}, RHS,
17618+ /*RHSExt=*/std::nullopt);
17619+ return std::nullopt;
1759917620}
1760017621
17601- /// Check if \p Root follows a pattern Root(fpext(LHS), fpext( RHS) )
17622+ /// Check if \p Root follows a pattern Root(fpext(LHS), RHS)
1760217623///
1760317624/// \returns std::nullopt if the pattern doesn't match or a CombineResult that
1760417625/// can be used to apply the pattern.
1760517626static std::optional<CombineResult>
1760617627canFoldToVWWithFPEXT(SDNode *Root, const NodeExtensionHelper &LHS,
1760717628 const NodeExtensionHelper &RHS, SelectionDAG &DAG,
1760817629 const RISCVSubtarget &Subtarget) {
17609- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::FPExt, DAG,
17610- Subtarget);
17611- }
17612-
17613- /// Check if \p Root follows a pattern Root(bf16ext(LHS), bf16ext(RHS))
17614- ///
17615- /// \returns std::nullopt if the pattern doesn't match or a CombineResult that
17616- /// can be used to apply the pattern.
17617- static std::optional<CombineResult>
17618- canFoldToVWWithBF16EXT(SDNode *Root, const NodeExtensionHelper &LHS,
17619- const NodeExtensionHelper &RHS, SelectionDAG &DAG,
17620- const RISCVSubtarget &Subtarget) {
17621- return canFoldToVWWithSameExtensionImpl(Root, LHS, RHS, ExtKind::BF16Ext, DAG,
17622- Subtarget);
17630+ if (LHS.SupportsFPExt)
17631+ return CombineResult(NodeExtensionHelper::getFPExtOpcode(Root->getOpcode()),
17632+ Root, LHS, /*LHSExt=*/{ExtKind::FPExt}, RHS,
17633+ /*RHSExt=*/std::nullopt);
17634+ return std::nullopt;
1762317635}
1762417636
1762517637/// Check if \p Root follows a pattern Root(sext(LHS), zext(RHS))
@@ -17662,7 +17674,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1766217674 case RISCVISD::VFNMSUB_VL:
1766317675 Strategies.push_back(canFoldToVWWithSameExtension);
1766417676 if (Root->getOpcode() == RISCVISD::VFMADD_VL)
17665- Strategies.push_back(canFoldToVWWithBF16EXT );
17677+ Strategies.push_back(canFoldToVWWithSameExtBF16 );
1766617678 break;
1766717679 case ISD::MUL:
1766817680 case RISCVISD::MUL_VL:
@@ -17674,7 +17686,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1767417686 case ISD::SHL:
1767517687 case RISCVISD::SHL_VL:
1767617688 // shl -> vwsll
17677- Strategies.push_back(canFoldToVWWithZEXT );
17689+ Strategies.push_back(canFoldToVWWithSameExtZEXT );
1767817690 break;
1767917691 case RISCVISD::VWADD_W_VL:
1768017692 case RISCVISD::VWSUB_W_VL:
0 commit comments