@@ -17759,6 +17759,83 @@ static SDValue combineScalarCTPOPToVCPOP(SDNode *N, SelectionDAG &DAG,
1775917759 return DAG.getZExtOrTrunc(Pop, DL, VT);
1776017760}
1776117761
17762+ static SDValue performSHLCombine(SDNode *N,
17763+ TargetLowering::DAGCombinerInfo &DCI,
17764+ const RISCVSubtarget &Subtarget) {
17765+ // (shl (zext x), y) -> (vwsll x, y)
17766+ if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
17767+ return V;
17768+
17769+ // (shl (sext x), C) -> (vwmulsu x, 1u << C)
17770+ // (shl (zext x), C) -> (vwmulu x, 1u << C)
17771+
17772+ if (!DCI.isAfterLegalizeDAG())
17773+ return SDValue();
17774+
17775+ SDValue LHS = N->getOperand(0);
17776+ if (!LHS.hasOneUse())
17777+ return SDValue();
17778+ unsigned Opcode;
17779+ switch (LHS.getOpcode()) {
17780+ case ISD::SIGN_EXTEND:
17781+ case RISCVISD::VSEXT_VL:
17782+ Opcode = RISCVISD::VWMULSU_VL;
17783+ break;
17784+ case ISD::ZERO_EXTEND:
17785+ case RISCVISD::VZEXT_VL:
17786+ Opcode = RISCVISD::VWMULU_VL;
17787+ break;
17788+ default:
17789+ return SDValue();
17790+ }
17791+
17792+ SDValue RHS = N->getOperand(1);
17793+ APInt ShAmt;
17794+ uint64_t ShAmtInt;
17795+ if (ISD::isConstantSplatVector(RHS.getNode(), ShAmt))
17796+ ShAmtInt = ShAmt.getZExtValue();
17797+ else if (RHS.getOpcode() == RISCVISD::VMV_V_X_VL &&
17798+ RHS.getOperand(1).getOpcode() == ISD::Constant)
17799+ ShAmtInt = RHS.getConstantOperandVal(1);
17800+ else
17801+ return SDValue();
17802+
17803+ // Better foldings:
17804+ // (shl (sext x), 1) -> (vwadd x, x)
17805+ // (shl (zext x), 1) -> (vwaddu x, x)
17806+ if (ShAmtInt <= 1)
17807+ return SDValue();
17808+
17809+ SDValue NarrowOp = LHS.getOperand(0);
17810+ MVT NarrowVT = NarrowOp.getSimpleValueType();
17811+ uint64_t NarrowBits = NarrowVT.getScalarSizeInBits();
17812+ if (ShAmtInt >= NarrowBits)
17813+ return SDValue();
17814+ MVT VT = N->getSimpleValueType(0);
17815+ if (NarrowBits * 2 != VT.getScalarSizeInBits())
17816+ return SDValue();
17817+
17818+ SelectionDAG &DAG = DCI.DAG;
17819+ SDLoc DL(N);
17820+ SDValue Passthru, Mask, VL;
17821+ switch (N->getOpcode()) {
17822+ case ISD::SHL:
17823+ Passthru = DAG.getUNDEF(VT);
17824+ std::tie(Mask, VL) = getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
17825+ break;
17826+ case RISCVISD::SHL_VL:
17827+ Passthru = N->getOperand(2);
17828+ Mask = N->getOperand(3);
17829+ VL = N->getOperand(4);
17830+ break;
17831+ default:
17832+ llvm_unreachable("Expected SHL");
17833+ }
17834+ return DAG.getNode(Opcode, DL, VT, NarrowOp,
17835+ DAG.getConstant(1ULL << ShAmtInt, SDLoc(RHS), NarrowVT),
17836+ Passthru, Mask, VL);
17837+ }
17838+
1776217839SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1776317840 DAGCombinerInfo &DCI) const {
1776417841 SelectionDAG &DAG = DCI.DAG;
@@ -18392,7 +18469,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1839218469 break;
1839318470 }
1839418471 case RISCVISD::SHL_VL:
18395- if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
18472+ if (SDValue V = performSHLCombine (N, DCI, Subtarget))
1839618473 return V;
1839718474 [[fallthrough]];
1839818475 case RISCVISD::SRA_VL:
@@ -18417,7 +18494,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1841718494 case ISD::SRL:
1841818495 case ISD::SHL: {
1841918496 if (N->getOpcode() == ISD::SHL) {
18420- if (SDValue V = combineOp_VLToVWOp_VL (N, DCI, Subtarget))
18497+ if (SDValue V = performSHLCombine (N, DCI, Subtarget))
1842118498 return V;
1842218499 }
1842318500 SDValue ShAmt = N->getOperand(1);
0 commit comments