Skip to content

Commit 574fe4c

Browse files
committed
combine
1 parent a1bd288 commit 574fe4c

File tree

1 file changed

+34
-20
lines changed

1 file changed

+34
-20
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44958,19 +44958,18 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4495844958
}
4495944959
case X86ISD::VPMADD52L:
4496044960
case X86ISD::VPMADD52H: {
44961-
KnownBits OrigKnownOp0, OrigKnownOp1;
4496244961
KnownBits KnownOp0, KnownOp1, KnownOp2;
4496344962
SDValue Op0 = Op.getOperand(0);
4496444963
SDValue Op1 = Op.getOperand(1);
4496544964
SDValue Op2 = Op.getOperand(2);
4496644965
// Only demand the lower 52-bits of operands 0 / 1 (and all 64-bits of
4496744966
// operand 2).
4496844967
APInt Low52Bits = APInt::getLowBitsSet(BitWidth, 52);
44969-
if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, OrigKnownOp0,
44968+
if (SimplifyDemandedBits(Op0, Low52Bits, OriginalDemandedElts, KnownOp0,
4497044969
TLO, Depth + 1))
4497144970
return true;
4497244971

44973-
if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, OrigKnownOp1,
44972+
if (SimplifyDemandedBits(Op1, Low52Bits, OriginalDemandedElts, KnownOp1,
4497444973
TLO, Depth + 1))
4497544974
return true;
4497644975

@@ -44979,8 +44978,8 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4497944978
return true;
4498044979

4498144980
KnownBits KnownMul;
44982-
KnownOp0 = OrigKnownOp0.trunc(52);
44983-
KnownOp1 = OrigKnownOp1.trunc(52);
44981+
KnownOp0 = KnownOp0.trunc(52);
44982+
KnownOp1 = KnownOp1.trunc(52);
4498444983
KnownMul = Opc == X86ISD::VPMADD52L ? KnownBits::mul(KnownOp0, KnownOp1)
4498544984
: KnownBits::mulhu(KnownOp0, KnownOp1);
4498644985
KnownMul = KnownMul.zext(64);
@@ -44992,20 +44991,6 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode(
4499244991
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, C, Op2));
4499344992
}
4499444993

44995-
// C * X --> X * C
44996-
if (KnownOp0.isConstant()) {
44997-
std::swap(OrigKnownOp0, OrigKnownOp1);
44998-
std::swap(KnownOp0, KnownOp1);
44999-
std::swap(Op0, Op1);
45000-
}
45001-
45002-
// lo(X * 1) + Z --> lo(X) + Z --> X iff X == lo(X)
45003-
if (Opc == X86ISD::VPMADD52L && KnownOp1.isConstant() &&
45004-
KnownOp1.getConstant().isOne() &&
45005-
OrigKnownOp0.countMinLeadingZeros() >= 12) {
45006-
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::ADD, DL, VT, Op0, Op2));
45007-
}
45008-
4500944994
Known = KnownBits::add(KnownMul, KnownOp2);
4501044995
return false;
4501144996
}
@@ -60201,8 +60186,37 @@ static SDValue combineVPMADD(SDNode *N, SelectionDAG &DAG,
6020160186
static SDValue combineVPMADD52LH(SDNode *N, SelectionDAG &DAG,
6020260187
TargetLowering::DAGCombinerInfo &DCI) {
6020360188
MVT VT = N->getSimpleValueType(0);
60204-
unsigned NumEltBits = VT.getScalarSizeInBits();
60189+
60190+
bool AddLow = N->getOpcode() == X86ISD::VPMADD52L;
60191+
SDValue Op0 = N->getOperand(0);
60192+
SDValue Op1 = N->getOperand(1);
60193+
SDValue Op2 = N->getOperand(2);
60194+
SDLoc DL(N);
60195+
60196+
APInt C0, C1;
60197+
bool HasC0 = X86::isConstantSplat(Op0, C0),
60198+
HasC1 = X86::isConstantSplat(Op1, C1);
60199+
60200+
// lo/hi(C * X) + Z --> lo/hi(X * C) + Z
60201+
if (HasC0 && !HasC1)
60202+
return DAG.getNode(N->getOpcode(), DL, VT, Op1, Op0, Op2);
60203+
60204+
// Only keep the low 52 bits of C1
60205+
if (HasC1 && C1.countLeadingZeros() < 12) {
60206+
C1.clearBits(52, 64);
60207+
SDValue LowC1 = DAG.getConstant(C1, DL, VT);
60208+
return DAG.getNode(N->getOpcode(), DL, VT, Op0, LowC1, Op2);
60209+
}
60210+
60211+
// lo(X * 1) + Z --> lo(X) + Z iff X == lo(X)
60212+
if (AddLow && HasC1 && C1.isOne()) {
60213+
KnownBits KnownOp0 = DAG.computeKnownBits(Op0);
60214+
if (KnownOp0.countMinLeadingZeros() >= 12)
60215+
return DAG.getNode(ISD::ADD, DL, VT, Op0, Op2);
60216+
}
60217+
6020560218
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
60219+
unsigned NumEltBits = VT.getScalarSizeInBits();
6020660220
if (TLI.SimplifyDemandedBits(SDValue(N, 0), APInt::getAllOnes(NumEltBits),
6020760221
DCI))
6020860222
return SDValue(N, 0);

0 commit comments

Comments
 (0)