@@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
231231 return nullptr ;
232232 case Instruction::ZExt:
233233 if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
234- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth ();
235- return ConstantInt::get (V->getContext (),
236- CI->getValue ().zext (BitWidth));
234+ uint32_t BitWidth = DestTy->getScalarSizeInBits ();
235+ return ConstantInt::get (DestTy, CI->getValue ().zext (BitWidth));
237236 }
238237 return nullptr ;
239238 case Instruction::SExt:
240239 if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
241- uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth ();
242- return ConstantInt::get (V->getContext (),
243- CI->getValue ().sext (BitWidth));
240+ uint32_t BitWidth = DestTy->getScalarSizeInBits ();
241+ return ConstantInt::get (DestTy, CI->getValue ().sext (BitWidth));
244242 }
245243 return nullptr ;
246244 case Instruction::Trunc: {
247- if (V->getType ()->isVectorTy ())
248- return nullptr ;
249-
250- uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth ();
251245 if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
252- return ConstantInt::get (V-> getContext (),
253- CI->getValue ().trunc (DestBitWidth ));
246+ uint32_t BitWidth = DestTy-> getScalarSizeInBits ();
247+ return ConstantInt::get (DestTy, CI->getValue ().trunc (BitWidth ));
254248 }
255249
256250 return nullptr ;
@@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
807801 default :
808802 break ;
809803 case Instruction::Add:
810- return ConstantInt::get (CI1-> getContext (), C1V + C2V);
804+ return ConstantInt::get (C1-> getType (), C1V + C2V);
811805 case Instruction::Sub:
812- return ConstantInt::get (CI1-> getContext (), C1V - C2V);
806+ return ConstantInt::get (C1-> getType (), C1V - C2V);
813807 case Instruction::Mul:
814- return ConstantInt::get (CI1-> getContext (), C1V * C2V);
808+ return ConstantInt::get (C1-> getType (), C1V * C2V);
815809 case Instruction::UDiv:
816810 assert (!CI2->isZero () && " Div by zero handled above" );
817- return ConstantInt::get (CI1->getContext (), C1V.udiv (C2V));
811+ return ConstantInt::get (CI1->getType (), C1V.udiv (C2V));
818812 case Instruction::SDiv:
819813 assert (!CI2->isZero () && " Div by zero handled above" );
820814 if (C2V.isAllOnes () && C1V.isMinSignedValue ())
821815 return PoisonValue::get (CI1->getType ()); // MIN_INT / -1 -> poison
822- return ConstantInt::get (CI1->getContext (), C1V.sdiv (C2V));
816+ return ConstantInt::get (CI1->getType (), C1V.sdiv (C2V));
823817 case Instruction::URem:
824818 assert (!CI2->isZero () && " Div by zero handled above" );
825- return ConstantInt::get (CI1-> getContext (), C1V.urem (C2V));
819+ return ConstantInt::get (C1-> getType (), C1V.urem (C2V));
826820 case Instruction::SRem:
827821 assert (!CI2->isZero () && " Div by zero handled above" );
828822 if (C2V.isAllOnes () && C1V.isMinSignedValue ())
829- return PoisonValue::get (CI1 ->getType ()); // MIN_INT % -1 -> poison
830- return ConstantInt::get (CI1-> getContext (), C1V.srem (C2V));
823+ return PoisonValue::get (C1 ->getType ()); // MIN_INT % -1 -> poison
824+ return ConstantInt::get (C1-> getType (), C1V.srem (C2V));
831825 case Instruction::And:
832- return ConstantInt::get (CI1-> getContext (), C1V & C2V);
826+ return ConstantInt::get (C1-> getType (), C1V & C2V);
833827 case Instruction::Or:
834- return ConstantInt::get (CI1-> getContext (), C1V | C2V);
828+ return ConstantInt::get (C1-> getType (), C1V | C2V);
835829 case Instruction::Xor:
836- return ConstantInt::get (CI1-> getContext (), C1V ^ C2V);
830+ return ConstantInt::get (C1-> getType (), C1V ^ C2V);
837831 case Instruction::Shl:
838832 if (C2V.ult (C1V.getBitWidth ()))
839- return ConstantInt::get (CI1-> getContext (), C1V.shl (C2V));
833+ return ConstantInt::get (C1-> getType (), C1V.shl (C2V));
840834 return PoisonValue::get (C1->getType ()); // too big shift is poison
841835 case Instruction::LShr:
842836 if (C2V.ult (C1V.getBitWidth ()))
843- return ConstantInt::get (CI1-> getContext (), C1V.lshr (C2V));
837+ return ConstantInt::get (C1-> getType (), C1V.lshr (C2V));
844838 return PoisonValue::get (C1->getType ()); // too big shift is poison
845839 case Instruction::AShr:
846840 if (C2V.ult (C1V.getBitWidth ()))
847- return ConstantInt::get (CI1-> getContext (), C1V.ashr (C2V));
841+ return ConstantInt::get (C1-> getType (), C1V.ashr (C2V));
848842 return PoisonValue::get (C1->getType ()); // too big shift is poison
849843 }
850844 }
@@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
877871 return ConstantFP::get (C1->getContext (), C3V);
878872 }
879873 }
880- } else if (auto *VTy = dyn_cast<VectorType>(C1->getType ())) {
874+ }
875+
876+ if (auto *VTy = dyn_cast<VectorType>(C1->getType ())) {
881877 // Fast path for splatted constants.
882878 if (Constant *C2Splat = C2->getSplatValue ()) {
883879 if (Instruction::isIntDivRem (Opcode) && C2Splat->isNullValue ())
0 commit comments