Skip to content

Commit 93ec08d

Browse files
committed
[DAG] Move SIGN_EXTEND_INREG constant folding inside FoldConstantArithmetic
Update visitSIGN_EXTEND_INREG to call FoldConstantArithmetic instead of getNode.
1 parent 06fce61 commit 93ec08d

File tree

2 files changed

+41
-37
lines changed

2 files changed

+41
-37
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14819,8 +14819,9 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
1481914819
return DAG.getConstant(0, DL, VT);
1482014820

1482114821
// fold (sext_in_reg c1) -> c1
14822-
if (DAG.isConstantIntBuildVectorOrConstantInt(N0))
14823-
return DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N0, N1);
14822+
if (SDValue C =
14823+
DAG.FoldConstantArithmetic(ISD::SIGN_EXTEND_INREG, DL, VT, {N0, N1}))
14824+
return C;
1482414825

1482514826
// If the input is already sign extended, just drop the extension.
1482614827
if (ExtVTBits >= DAG.ComputeMaxSignificantBits(N0))

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 38 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6659,6 +6659,44 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
66596659
if (TLI->isCommutativeBinOp(Opcode))
66606660
if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ops[1]))
66616661
return FoldSymbolOffset(Opcode, VT, GA, Ops[0].getNode());
6662+
6663+
// fold (sext_in_reg c1) -> c2
6664+
if (Opcode == ISD::SIGN_EXTEND_INREG) {
6665+
EVT EVT = cast<VTSDNode>(Ops[1])->getVT();
6666+
6667+
auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
6668+
unsigned FromBits = EVT.getScalarSizeInBits();
6669+
Val <<= Val.getBitWidth() - FromBits;
6670+
Val.ashrInPlace(Val.getBitWidth() - FromBits);
6671+
return getConstant(Val, DL, ConstantVT);
6672+
};
6673+
6674+
if (auto *C1 = dyn_cast<ConstantSDNode>(Ops[0])) {
6675+
const APInt &Val = C1->getAPIntValue();
6676+
return SignExtendInReg(Val, VT);
6677+
}
6678+
6679+
if (ISD::isBuildVectorOfConstantSDNodes(Ops[0].getNode())) {
6680+
SmallVector<SDValue, 8> ScalarOps;
6681+
llvm::EVT OpVT = Ops[0].getOperand(0).getValueType();
6682+
for (int I = 0, E = VT.getVectorNumElements(); I != E; ++I) {
6683+
SDValue Op = Ops[0].getOperand(I);
6684+
if (Op.isUndef()) {
6685+
ScalarOps.push_back(getUNDEF(OpVT));
6686+
continue;
6687+
}
6688+
APInt Val = cast<ConstantSDNode>(Op)->getAPIntValue();
6689+
ScalarOps.push_back(SignExtendInReg(Val, OpVT));
6690+
}
6691+
return getBuildVector(VT, DL, ScalarOps);
6692+
}
6693+
6694+
if (Ops[0].getOpcode() == ISD::SPLAT_VECTOR &&
6695+
isa<ConstantSDNode>(Ops[0].getOperand(0)))
6696+
return getNode(ISD::SPLAT_VECTOR, DL, VT,
6697+
SignExtendInReg(Ops[0].getConstantOperandAPInt(0),
6698+
Ops[0].getOperand(0).getValueType()));
6699+
}
66626700
}
66636701

66646702
// This is for vector folding only from here on.
@@ -7205,41 +7243,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
72057243
"Vector element counts must match in SIGN_EXTEND_INREG");
72067244
assert(EVT.bitsLE(VT) && "Not extending!");
72077245
if (EVT == VT) return N1; // Not actually extending
7208-
7209-
auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
7210-
unsigned FromBits = EVT.getScalarSizeInBits();
7211-
Val <<= Val.getBitWidth() - FromBits;
7212-
Val.ashrInPlace(Val.getBitWidth() - FromBits);
7213-
return getConstant(Val, DL, ConstantVT);
7214-
};
7215-
7216-
if (N1C) {
7217-
const APInt &Val = N1C->getAPIntValue();
7218-
return SignExtendInReg(Val, VT);
7219-
}
7220-
7221-
if (ISD::isBuildVectorOfConstantSDNodes(N1.getNode())) {
7222-
SmallVector<SDValue, 8> Ops;
7223-
llvm::EVT OpVT = N1.getOperand(0).getValueType();
7224-
for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
7225-
SDValue Op = N1.getOperand(i);
7226-
if (Op.isUndef()) {
7227-
Ops.push_back(getUNDEF(OpVT));
7228-
continue;
7229-
}
7230-
ConstantSDNode *C = cast<ConstantSDNode>(Op);
7231-
APInt Val = C->getAPIntValue();
7232-
Ops.push_back(SignExtendInReg(Val, OpVT));
7233-
}
7234-
return getBuildVector(VT, DL, Ops);
7235-
}
7236-
7237-
if (N1.getOpcode() == ISD::SPLAT_VECTOR &&
7238-
isa<ConstantSDNode>(N1.getOperand(0)))
7239-
return getNode(
7240-
ISD::SPLAT_VECTOR, DL, VT,
7241-
SignExtendInReg(N1.getConstantOperandAPInt(0),
7242-
N1.getOperand(0).getValueType()));
72437246
break;
72447247
}
72457248
case ISD::FP_TO_SINT_SAT:

0 commit comments

Comments
 (0)