diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp index cfe87937c372c..2dbc6785c08b9 100644 --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V, return nullptr; case Instruction::ZExt: if (ConstantInt *CI = dyn_cast(V)) { - uint32_t BitWidth = cast(DestTy)->getBitWidth(); - return ConstantInt::get(V->getContext(), - CI->getValue().zext(BitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth)); } return nullptr; case Instruction::SExt: if (ConstantInt *CI = dyn_cast(V)) { - uint32_t BitWidth = cast(DestTy)->getBitWidth(); - return ConstantInt::get(V->getContext(), - CI->getValue().sext(BitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth)); } return nullptr; case Instruction::Trunc: { - if (V->getType()->isVectorTy()) - return nullptr; - - uint32_t DestBitWidth = cast(DestTy)->getBitWidth(); if (ConstantInt *CI = dyn_cast(V)) { - return ConstantInt::get(V->getContext(), - CI->getValue().trunc(DestBitWidth)); + uint32_t BitWidth = DestTy->getScalarSizeInBits(); + return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth)); } return nullptr; @@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, default: break; case Instruction::Add: - return ConstantInt::get(CI1->getContext(), C1V + C2V); + return ConstantInt::get(C1->getType(), C1V + C2V); case Instruction::Sub: - return ConstantInt::get(CI1->getContext(), C1V - C2V); + return ConstantInt::get(C1->getType(), C1V - C2V); case Instruction::Mul: - return ConstantInt::get(CI1->getContext(), C1V * C2V); + return ConstantInt::get(C1->getType(), C1V * C2V); case Instruction::UDiv: assert(!CI2->isZero() && "Div by zero handled above"); - return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V)); + return ConstantInt::get(CI1->getType(), C1V.udiv(C2V)); case Instruction::SDiv: assert(!CI2->isZero() && "Div by zero handled above"); if (C2V.isAllOnes() && C1V.isMinSignedValue()) return PoisonValue::get(CI1->getType()); // MIN_INT / -1 -> poison - return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V)); + return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V)); case Instruction::URem: assert(!CI2->isZero() && "Div by zero handled above"); - return ConstantInt::get(CI1->getContext(), C1V.urem(C2V)); + return ConstantInt::get(C1->getType(), C1V.urem(C2V)); case Instruction::SRem: assert(!CI2->isZero() && "Div by zero handled above"); if (C2V.isAllOnes() && C1V.isMinSignedValue()) - return PoisonValue::get(CI1->getType()); // MIN_INT % -1 -> poison - return ConstantInt::get(CI1->getContext(), C1V.srem(C2V)); + return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison + return ConstantInt::get(C1->getType(), C1V.srem(C2V)); case Instruction::And: - return ConstantInt::get(CI1->getContext(), C1V & C2V); + return ConstantInt::get(C1->getType(), C1V & C2V); case Instruction::Or: - return ConstantInt::get(CI1->getContext(), C1V | C2V); + return ConstantInt::get(C1->getType(), C1V | C2V); case Instruction::Xor: - return ConstantInt::get(CI1->getContext(), C1V ^ C2V); + return ConstantInt::get(C1->getType(), C1V ^ C2V); case Instruction::Shl: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.shl(C2V)); + return ConstantInt::get(C1->getType(), C1V.shl(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison case Instruction::LShr: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V)); + return ConstantInt::get(C1->getType(), C1V.lshr(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison case Instruction::AShr: if (C2V.ult(C1V.getBitWidth())) - return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V)); + return ConstantInt::get(C1->getType(), C1V.ashr(C2V)); return PoisonValue::get(C1->getType()); // too big shift is poison } } @@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1, return ConstantFP::get(C1->getContext(), C3V); } } - } else if (auto *VTy = dyn_cast(C1->getType())) { + } + + if (auto *VTy = dyn_cast(C1->getType())) { // Fast path for splatted constants. if (Constant *C2Splat = C2->getSplatValue()) { if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp index 7ae397871bdea..3d6c4ad780dc2 100644 --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -441,6 +441,13 @@ Constant *Constant::getAggregateElement(unsigned Elt) const { ? CAZ->getElementValue(Elt) : nullptr; + if (const auto *CI = dyn_cast(this)) + return Elt < cast(getType()) + ->getElementCount() + .getKnownMinValue() + ? ConstantInt::get(getContext(), CI->getValue()) + : nullptr; + // FIXME: getNumElements() will fail for non-fixed vector types. if (isa(getType())) return nullptr; diff --git a/llvm/test/Transforms/InstCombine/add.ll b/llvm/test/Transforms/InstCombine/add.ll index 4b1159cf07e71..4825e588aa085 100644 --- a/llvm/test/Transforms/InstCombine/add.ll +++ b/llvm/test/Transforms/InstCombine/add.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s declare void @use(i8) declare void @use_i1(i1) diff --git a/llvm/test/Transforms/InstCombine/div.ll b/llvm/test/Transforms/InstCombine/div.ll index 33a8e12dfa1a6..6344966d6cac3 100644 --- a/llvm/test/Transforms/InstCombine/div.ll +++ b/llvm/test/Transforms/InstCombine/div.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s declare void @use(i32) diff --git a/llvm/test/Transforms/InstCombine/mul.ll b/llvm/test/Transforms/InstCombine/mul.ll index e38ab1b9622b2..e3108fc54c4f4 100644 --- a/llvm/test/Transforms/InstCombine/mul.ll +++ b/llvm/test/Transforms/InstCombine/mul.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s declare i32 @llvm.abs.i32(i32, i1) diff --git a/llvm/test/Transforms/InstCombine/or.ll b/llvm/test/Transforms/InstCombine/or.ll index 4a886afd78a5f..95f89e4ce11cd 100644 --- a/llvm/test/Transforms/InstCombine/or.ll +++ b/llvm/test/Transforms/InstCombine/or.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py -; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,CONSTVEC +; RUN: opt < %s -passes=instcombine -S -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=CHECK,CONSTSPLAT target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128-n32:64" declare void @use(i32) @@ -399,10 +400,15 @@ define i32 @test30(i32 %A) { } define <2 x i32> @test30vec(<2 x i32> %A) { -; CHECK-LABEL: @test30vec( -; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312) -; CHECK-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962) -; CHECK-NEXT: ret <2 x i32> [[E]] +; CONSTVEC-LABEL: @test30vec( +; CONSTVEC-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312) +; CONSTVEC-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962) +; CONSTVEC-NEXT: ret <2 x i32> [[E]] +; +; CONSTSPLAT-LABEL: @test30vec( +; CONSTSPLAT-NEXT: [[D:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312) +; CONSTSPLAT-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[D]], splat (i32 32962) +; CONSTSPLAT-NEXT: ret <2 x i32> [[E]] ; %B = or <2 x i32> %A, %C = and <2 x i32> %A, diff --git a/llvm/test/Transforms/InstCombine/rotate.ll b/llvm/test/Transforms/InstCombine/rotate.ll index ea7c471594da0..bae50736de0c3 100644 --- a/llvm/test/Transforms/InstCombine/rotate.ll +++ b/llvm/test/Transforms/InstCombine/rotate.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128" diff --git a/llvm/test/Transforms/InstCombine/shift.ll b/llvm/test/Transforms/InstCombine/shift.ll index d2ee97f39123b..d72a1849c7dfd 100644 --- a/llvm/test/Transforms/InstCombine/shift.ll +++ b/llvm/test/Transforms/InstCombine/shift.ll @@ -1,5 +1,6 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s declare void @use(i64) declare void @use_i32(i32) diff --git a/llvm/test/Transforms/InstCombine/xor-ashr.ll b/llvm/test/Transforms/InstCombine/xor-ashr.ll index 0c0554adcf123..f5ccdeef2f382 100644 --- a/llvm/test/Transforms/InstCombine/xor-ashr.ll +++ b/llvm/test/Transforms/InstCombine/xor-ashr.ll @@ -1,5 +1,7 @@ ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py ; RUN: opt < %s -passes=instcombine -S | FileCheck %s +; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s + target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64" declare void @use16(i16) diff --git a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll index 3f1672d66abf0..b475b8199541d 100644 --- a/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll +++ b/llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll @@ -81,8 +81,7 @@ define <1 x i1> @test10() { ; CONSTVEC-NEXT: ret <1 x i1> [[RET]] ; ; CONSTSPLAT-LABEL: @test10( -; CONSTSPLAT-NEXT: [[RET:%.*]] = icmp eq <1 x i64> splat (i64 -1), zeroinitializer -; CONSTSPLAT-NEXT: ret <1 x i1> [[RET]] +; CONSTSPLAT-NEXT: ret <1 x i1> zeroinitializer ; %ret = icmp eq <1 x i64> to i64)>, zeroinitializer ret <1 x i1> %ret