@@ -6659,6 +6659,44 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
66596659 if (TLI->isCommutativeBinOp (Opcode))
66606660 if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ops[1 ]))
66616661 return FoldSymbolOffset (Opcode, VT, GA, Ops[0 ].getNode ());
6662+
6663+ // fold (sext_in_reg c1) -> c2
6664+ if (Opcode == ISD::SIGN_EXTEND_INREG) {
6665+ EVT EVT = cast<VTSDNode>(Ops[1 ])->getVT ();
6666+
6667+ auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
6668+ unsigned FromBits = EVT.getScalarSizeInBits ();
6669+ Val <<= Val.getBitWidth () - FromBits;
6670+ Val.ashrInPlace (Val.getBitWidth () - FromBits);
6671+ return getConstant (Val, DL, ConstantVT);
6672+ };
6673+
6674+ if (auto *C1 = dyn_cast<ConstantSDNode>(Ops[0 ])) {
6675+ const APInt &Val = C1->getAPIntValue ();
6676+ return SignExtendInReg (Val, VT);
6677+ }
6678+
6679+ if (ISD::isBuildVectorOfConstantSDNodes (Ops[0 ].getNode ())) {
6680+ SmallVector<SDValue, 8 > ScalarOps;
6681+ llvm::EVT OpVT = Ops[0 ].getOperand (0 ).getValueType ();
6682+ for (int I = 0 , E = VT.getVectorNumElements (); I != E; ++I) {
6683+ SDValue Op = Ops[0 ].getOperand (I);
6684+ if (Op.isUndef ()) {
6685+ ScalarOps.push_back (getUNDEF (OpVT));
6686+ continue ;
6687+ }
6688+ APInt Val = cast<ConstantSDNode>(Op)->getAPIntValue ();
6689+ ScalarOps.push_back (SignExtendInReg (Val, OpVT));
6690+ }
6691+ return getBuildVector (VT, DL, ScalarOps);
6692+ }
6693+
6694+ if (Ops[0 ].getOpcode () == ISD::SPLAT_VECTOR &&
6695+ isa<ConstantSDNode>(Ops[0 ].getOperand (0 )))
6696+ return getNode (ISD::SPLAT_VECTOR, DL, VT,
6697+ SignExtendInReg (Ops[0 ].getConstantOperandAPInt (0 ),
6698+ Ops[0 ].getOperand (0 ).getValueType ()));
6699+ }
66626700 }
66636701
66646702 // This is for vector folding only from here on.
@@ -7205,41 +7243,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
72057243 " Vector element counts must match in SIGN_EXTEND_INREG" );
72067244 assert (EVT.bitsLE (VT) && " Not extending!" );
72077245 if (EVT == VT) return N1; // Not actually extending
7208-
7209- auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
7210- unsigned FromBits = EVT.getScalarSizeInBits ();
7211- Val <<= Val.getBitWidth () - FromBits;
7212- Val.ashrInPlace (Val.getBitWidth () - FromBits);
7213- return getConstant (Val, DL, ConstantVT);
7214- };
7215-
7216- if (N1C) {
7217- const APInt &Val = N1C->getAPIntValue ();
7218- return SignExtendInReg (Val, VT);
7219- }
7220-
7221- if (ISD::isBuildVectorOfConstantSDNodes (N1.getNode ())) {
7222- SmallVector<SDValue, 8 > Ops;
7223- llvm::EVT OpVT = N1.getOperand (0 ).getValueType ();
7224- for (int i = 0 , e = VT.getVectorNumElements (); i != e; ++i) {
7225- SDValue Op = N1.getOperand (i);
7226- if (Op.isUndef ()) {
7227- Ops.push_back (getUNDEF (OpVT));
7228- continue ;
7229- }
7230- ConstantSDNode *C = cast<ConstantSDNode>(Op);
7231- APInt Val = C->getAPIntValue ();
7232- Ops.push_back (SignExtendInReg (Val, OpVT));
7233- }
7234- return getBuildVector (VT, DL, Ops);
7235- }
7236-
7237- if (N1.getOpcode () == ISD::SPLAT_VECTOR &&
7238- isa<ConstantSDNode>(N1.getOperand (0 )))
7239- return getNode (
7240- ISD::SPLAT_VECTOR, DL, VT,
7241- SignExtendInReg (N1.getConstantOperandAPInt (0 ),
7242- N1.getOperand (0 ).getValueType ()));
72437246 break ;
72447247 }
72457248 case ISD::FP_TO_SINT_SAT:
0 commit comments