@@ -6659,6 +6659,44 @@ SDValue SelectionDAG::FoldConstantArithmetic(unsigned Opcode, const SDLoc &DL,
6659
6659
if (TLI->isCommutativeBinOp (Opcode))
6660
6660
if (GlobalAddressSDNode *GA = dyn_cast<GlobalAddressSDNode>(Ops[1 ]))
6661
6661
return FoldSymbolOffset (Opcode, VT, GA, Ops[0 ].getNode ());
6662
+
6663
+ // fold (sext_in_reg c1) -> c2
6664
+ if (Opcode == ISD::SIGN_EXTEND_INREG) {
6665
+ EVT EVT = cast<VTSDNode>(Ops[1 ])->getVT ();
6666
+
6667
+ auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
6668
+ unsigned FromBits = EVT.getScalarSizeInBits ();
6669
+ Val <<= Val.getBitWidth () - FromBits;
6670
+ Val.ashrInPlace (Val.getBitWidth () - FromBits);
6671
+ return getConstant (Val, DL, ConstantVT);
6672
+ };
6673
+
6674
+ if (auto *C1 = dyn_cast<ConstantSDNode>(Ops[0 ])) {
6675
+ const APInt &Val = C1->getAPIntValue ();
6676
+ return SignExtendInReg (Val, VT);
6677
+ }
6678
+
6679
+ if (ISD::isBuildVectorOfConstantSDNodes (Ops[0 ].getNode ())) {
6680
+ SmallVector<SDValue, 8 > ScalarOps;
6681
+ llvm::EVT OpVT = Ops[0 ].getOperand (0 ).getValueType ();
6682
+ for (int I = 0 , E = VT.getVectorNumElements (); I != E; ++I) {
6683
+ SDValue Op = Ops[0 ].getOperand (I);
6684
+ if (Op.isUndef ()) {
6685
+ ScalarOps.push_back (getUNDEF (OpVT));
6686
+ continue ;
6687
+ }
6688
+ APInt Val = cast<ConstantSDNode>(Op)->getAPIntValue ();
6689
+ ScalarOps.push_back (SignExtendInReg (Val, OpVT));
6690
+ }
6691
+ return getBuildVector (VT, DL, ScalarOps);
6692
+ }
6693
+
6694
+ if (Ops[0 ].getOpcode () == ISD::SPLAT_VECTOR &&
6695
+ isa<ConstantSDNode>(Ops[0 ].getOperand (0 )))
6696
+ return getNode (ISD::SPLAT_VECTOR, DL, VT,
6697
+ SignExtendInReg (Ops[0 ].getConstantOperandAPInt (0 ),
6698
+ Ops[0 ].getOperand (0 ).getValueType ()));
6699
+ }
6662
6700
}
6663
6701
6664
6702
// This is for vector folding only from here on.
@@ -7205,41 +7243,6 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
7205
7243
" Vector element counts must match in SIGN_EXTEND_INREG" );
7206
7244
assert (EVT.bitsLE (VT) && " Not extending!" );
7207
7245
if (EVT == VT) return N1; // Not actually extending
7208
-
7209
- auto SignExtendInReg = [&](APInt Val, llvm::EVT ConstantVT) {
7210
- unsigned FromBits = EVT.getScalarSizeInBits ();
7211
- Val <<= Val.getBitWidth () - FromBits;
7212
- Val.ashrInPlace (Val.getBitWidth () - FromBits);
7213
- return getConstant (Val, DL, ConstantVT);
7214
- };
7215
-
7216
- if (N1C) {
7217
- const APInt &Val = N1C->getAPIntValue ();
7218
- return SignExtendInReg (Val, VT);
7219
- }
7220
-
7221
- if (ISD::isBuildVectorOfConstantSDNodes (N1.getNode ())) {
7222
- SmallVector<SDValue, 8 > Ops;
7223
- llvm::EVT OpVT = N1.getOperand (0 ).getValueType ();
7224
- for (int i = 0 , e = VT.getVectorNumElements (); i != e; ++i) {
7225
- SDValue Op = N1.getOperand (i);
7226
- if (Op.isUndef ()) {
7227
- Ops.push_back (getUNDEF (OpVT));
7228
- continue ;
7229
- }
7230
- ConstantSDNode *C = cast<ConstantSDNode>(Op);
7231
- APInt Val = C->getAPIntValue ();
7232
- Ops.push_back (SignExtendInReg (Val, OpVT));
7233
- }
7234
- return getBuildVector (VT, DL, Ops);
7235
- }
7236
-
7237
- if (N1.getOpcode () == ISD::SPLAT_VECTOR &&
7238
- isa<ConstantSDNode>(N1.getOperand (0 )))
7239
- return getNode (
7240
- ISD::SPLAT_VECTOR, DL, VT,
7241
- SignExtendInReg (N1.getConstantOperandAPInt (0 ),
7242
- N1.getOperand (0 ).getValueType ()));
7243
7246
break ;
7244
7247
}
7245
7248
case ISD::FP_TO_SINT_SAT:
0 commit comments