@@ -44958,19 +44958,18 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4495844958 }
4495944959 case X86ISD::VPMADD52L:
4496044960 case X86ISD::VPMADD52H: {
44961- KnownBits OrigKnownOp0, OrigKnownOp1;
4496244961 KnownBits KnownOp0, KnownOp1, KnownOp2;
4496344962 SDValue Op0 = Op.getOperand(0);
4496444963 SDValue Op1 = Op.getOperand(1);
4496544964 SDValue Op2 = Op.getOperand(2);
4496644965 // Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
4496744966 // operand 2).
4496844967 APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
44969- if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, OrigKnownOp0 ,
44968+ if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0 ,
4497044969 TLO, Depth + 1))
4497144970 return true;
4497244971
44973- if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, OrigKnownOp1 ,
44972+ if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1 ,
4497444973 TLO, Depth + 1))
4497544974 return true;
4497644975
@@ -44979,8 +44978,8 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4497944978 return true;
4498044979
4498144980 KnownBits KnownMul;
44982- KnownOp0 = OrigKnownOp0 .trunc(52);
44983- KnownOp1 = OrigKnownOp1 .trunc(52);
44981+ KnownOp0 = KnownOp0 .trunc(52);
44982+ KnownOp1 = KnownOp1 .trunc(52);
4498444983 KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
4498544984 : KnownBits::mulhu(KnownOp0, KnownOp1);
4498644985 KnownMul = KnownMul.zext(64);
@@ -44992,20 +44991,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4499244991 return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
4499344992 }
4499444993
44995- // C * X --> X * C
44996- if (KnownOp0.isConstant()) {
44997- std::swap(OrigKnownOp0, OrigKnownOp1);
44998- std::swap(KnownOp0, KnownOp1);
44999- std::swap(Op0, Op1);
45000- }
45001-
45002- // lo(X * 1) + Z --> lo(X) + Z --> X iff X == lo(X)
45003- if (Opc == X86ISD::VPMADD52L && KnownOp1.isConstant() &&
45004- KnownOp1.getConstant().isOne() &&
45005- OrigKnownOp0.countMinLeadingZeros() >= 12) {
45006- return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, Op0, Op2));
45007- }
45008-
4500944994 Known = KnownBits::add(KnownMul, KnownOp2);
4501044995 return false;
4501144996 }
@@ -60201,8 +60186,37 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
6020160186static SDValue combineVPMADD52LH(SDNode *N, SelectionDAG &DAG,
6020260187 TargetLowering::DAGCombinerInfo &DCI) {
6020360188 MVT VT = N->getSimpleValueType(0);
60204- unsigned NumEltBits = VT.getScalarSizeInBits();
60189+
60190+ bool AddLow = N->getOpcode() == X86ISD::VPMADD52L;
60191+ SDValue Op0 = N->getOperand(0);
60192+ SDValue Op1 = N->getOperand(1);
60193+ SDValue Op2 = N->getOperand(2);
60194+ SDLoc DL(N);
60195+
60196+ APInt C0, C1;
60197+ bool HasC0 = X86::isConstantSplat(Op0, C0),
60198+ HasC1 = X86::isConstantSplat(Op1, C1);
60199+
60200+ // lo/hi(C * X) + Z --> lo/hi(X * C) + Z
60201+ if (HasC0 && !HasC1)
60202+ return DAG.getNode(N->getOpcode(), DL, VT, Op1, Op0, Op2);
60203+
60204+ // Only keep the low 52 bits of C1
60205+ if (HasC1 && C1.countLeadingZeros() < 12) {
60206+ C1.clearBits(52, 64);
60207+ SDValue LowC1 = DAG.getConstant(C1, DL, VT);
60208+ return DAG.getNode(N->getOpcode(), DL, VT, Op0, LowC1, Op2);
60209+ }
60210+
60211+ // lo(X * 1) + Z --> lo(X) + Z iff X == lo(X)
60212+ if (AddLow && HasC1 && C1.isOne()) {
60213+ KnownBits KnownOp0 = DAG.computeKnownBits(Op0);
60214+ if (KnownOp0.countMinLeadingZeros() >= 12)
60215+ return DAG.getNode(ISD::ADD, DL, VT, Op0, Op2);
60216+ }
60217+
6020560218 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
60219+ unsigned NumEltBits = VT.getScalarSizeInBits();
6020660220 if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
6020760221 DCI))
6020860222 return SDValue(N, 0);
0 commit comments