@@ -16619,6 +16619,25 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1661916619}
1662016620} // End anonymous namespace.
1662116621
16622+ static SDValue simplifyOp_VL(SDNode *N) {
16623+ // TODO: Extend this to other binops using generic identity logic
16624+ assert(N->getOpcode() == RISCVISD::ADD_VL);
16625+ SDValue A = N->getOperand(0);
16626+ SDValue B = N->getOperand(1);
16627+ SDValue Passthru = N->getOperand(2);
16628+ if (!Passthru.isUndef())
16629+ // TODO:This could be a vmerge instead
16630+ return SDValue();
16631+ ;
16632+ if (ISD::isConstantSplatVectorAllZeros(B.getNode()))
16633+ return A;
16634+ // Peek through fixed to scalable
16635+ if (B.getOpcode() == ISD::INSERT_SUBVECTOR && B.getOperand(0).isUndef() &&
16636+ ISD::isConstantSplatVectorAllZeros(B.getOperand(1).getNode()))
16637+ return A;
16638+ return SDValue();
16639+ }
16640+
1662216641/// Combine a binary or FMA operation to its equivalent VW or VW_W form.
1662316642/// The supported combines are:
1662416643/// add | add_vl | or disjoint | or_vl disjoint -> vwadd(u) | vwadd(u)_w
@@ -18515,20 +18534,10 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
1851518534 return SDValue();
1851618535
1851718536 SDValue AccumOp = DotOp.getOperand(2);
18518- bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
18519- // Peek through fixed to scalable
18520- if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
18521- AccumOp.getOperand(0).isUndef())
18522- IsNullAdd =
18523- ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
18524-
1852518537 SDLoc DL(N);
1852618538 EVT VT = N->getValueType(0);
18527- // The manual constant folding is required, this case is not constant folded
18528- // or combined.
18529- if (!IsNullAdd)
18530- Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
18531- DAG.getUNDEF(VT), AddMask, AddVL);
18539+ Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, Addend, AccumOp,
18540+ DAG.getUNDEF(VT), AddMask, AddVL);
1853218541
1853318542 SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
1853418543 DotOp.getOperand(3), DotOp->getOperand(4)};
@@ -19657,6 +19666,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1965719666 break;
1965819667 }
1965919668 case RISCVISD::ADD_VL:
19669+ if (SDValue V = simplifyOp_VL(N))
19670+ return V;
1966019671 if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
1966119672 return V;
1966219673 if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
0 commit comments