@@ -13857,6 +13857,52 @@ static SDValue getMad64_32(SelectionDAG &DAG, const SDLoc &SL, EVT VT,
1385713857 return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad);
1385813858}
1385913859
13860+ // Fold
13861+ // y = lshr i64 x, 32
13862+ // res = add (mul i64 y, Constant), x where "Constant" is a 32 bit
13863+ // negative value
13864+ // To
13865+ // res = mad_u64_u32 y.lo ,Constant.lo, x.lo
13866+ static SDValue tryFoldMADwithSRL(SelectionDAG &DAG, const SDLoc &SL,
13867+ SDValue MulLHS, SDValue MulRHS,
13868+ SDValue AddRHS) {
13869+
13870+ if (MulLHS.getValueType() != MVT::i64)
13871+ return SDValue();
13872+
13873+ ConstantSDNode *ConstOp;
13874+ SDValue ShiftOp;
13875+ if (MulLHS.getOpcode() == ISD::SRL && MulRHS.getOpcode() == ISD::Constant) {
13876+ ConstOp = cast<ConstantSDNode>(MulRHS.getNode());
13877+ ShiftOp = MulLHS;
13878+ } else if (MulRHS.getOpcode() == ISD::SRL &&
13879+ MulLHS.getOpcode() == ISD::Constant) {
13880+ ConstOp = cast<ConstantSDNode>(MulLHS.getNode());
13881+ ShiftOp = MulRHS;
13882+ } else
13883+ return SDValue();
13884+
13885+ if (ShiftOp.getOperand(1).getOpcode() != ISD::Constant ||
13886+ AddRHS != ShiftOp.getOperand(0))
13887+ return SDValue();
13888+
13889+ if (cast<ConstantSDNode>(ShiftOp->getOperand(1))->getAsZExtVal() != 32)
13890+ return SDValue();
13891+
13892+ APInt ConstVal = ConstOp->getAPIntValue();
13893+ if (!ConstVal.isNegative() || !ConstVal.isSignedIntN(33))
13894+ return SDValue();
13895+
13896+ SDValue Zero = DAG.getConstant(0, SL, MVT::i32);
13897+ SDValue ConstMul = DAG.getConstant(
13898+ ConstVal.getZExtValue() & 0x00000000FFFFFFFF, SL, MVT::i32);
13899+ AddRHS = DAG.getNode(ISD::AND, SL, MVT::i64, AddRHS,
13900+ DAG.getConstant(0x00000000FFFFFFFF, SL, MVT::i64));
13901+ return getMad64_32(DAG, SL, MVT::i64,
13902+ DAG.getNode(ISD::TRUNCATE, SL, MVT::i32, MulLHS), ConstMul,
13903+ AddRHS, false);
13904+ }
13905+
1386013906// Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z) plus high
1386113907// multiplies, if any.
1386213908//
@@ -13915,6 +13961,9 @@ SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
1391513961 SDValue MulRHS = LHS.getOperand(1);
1391613962 SDValue AddRHS = RHS;
1391713963
13964+ if (SDValue FoldedMAD = tryFoldMADwithSRL(DAG, SL, MulLHS, MulRHS, AddRHS))
13965+ return FoldedMAD;
13966+
1391813967 // Always check whether operands are small unsigned values, since that
1391913968 // knowledge is useful in more cases. Check for small signed values only if
1392013969 // doing so can unlock a shorter code sequence.
0 commit comments