Skip to content

Commit 9729885

Browse files
[LLVM][IR] Teach constant integer binop folds about vector ConstantInts. (#115739)
The existing logic mostly works with the main changes being: * Use getScalarSizeInBits instead of IntegerType::getBitWidth * Use ConstantInt::get(Type* instead of ConstantInt::get(LLVMContext
1 parent 8ae2a18 commit 9729885

File tree

10 files changed

+49
-34
lines changed

10 files changed

+49
-34
lines changed

llvm/lib/IR/ConstantFold.cpp

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -231,26 +231,20 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
231231
return nullptr;
232232
case Instruction::ZExt:
233233
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
234-
uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
235-
return ConstantInt::get(V->getContext(),
236-
CI->getValue().zext(BitWidth));
234+
uint32_t BitWidth = DestTy->getScalarSizeInBits();
235+
return ConstantInt::get(DestTy, CI->getValue().zext(BitWidth));
237236
}
238237
return nullptr;
239238
case Instruction::SExt:
240239
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
241-
uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
242-
return ConstantInt::get(V->getContext(),
243-
CI->getValue().sext(BitWidth));
240+
uint32_t BitWidth = DestTy->getScalarSizeInBits();
241+
return ConstantInt::get(DestTy, CI->getValue().sext(BitWidth));
244242
}
245243
return nullptr;
246244
case Instruction::Trunc: {
247-
if (V->getType()->isVectorTy())
248-
return nullptr;
249-
250-
uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth();
251245
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
252-
return ConstantInt::get(V->getContext(),
253-
CI->getValue().trunc(DestBitWidth));
246+
uint32_t BitWidth = DestTy->getScalarSizeInBits();
247+
return ConstantInt::get(DestTy, CI->getValue().trunc(BitWidth));
254248
}
255249

256250
return nullptr;
@@ -807,44 +801,44 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
807801
default:
808802
break;
809803
case Instruction::Add:
810-
return ConstantInt::get(CI1->getContext(), C1V + C2V);
804+
return ConstantInt::get(C1->getType(), C1V + C2V);
811805
case Instruction::Sub:
812-
return ConstantInt::get(CI1->getContext(), C1V - C2V);
806+
return ConstantInt::get(C1->getType(), C1V - C2V);
813807
case Instruction::Mul:
814-
return ConstantInt::get(CI1->getContext(), C1V * C2V);
808+
return ConstantInt::get(C1->getType(), C1V * C2V);
815809
case Instruction::UDiv:
816810
assert(!CI2->isZero() && "Div by zero handled above");
817-
return ConstantInt::get(CI1->getContext(), C1V.udiv(C2V));
811+
return ConstantInt::get(CI1->getType(), C1V.udiv(C2V));
818812
case Instruction::SDiv:
819813
assert(!CI2->isZero() && "Div by zero handled above");
820814
if (C2V.isAllOnes() && C1V.isMinSignedValue())
821815
return PoisonValue::get(CI1->getType()); // MIN_INT / -1 -> poison
822-
return ConstantInt::get(CI1->getContext(), C1V.sdiv(C2V));
816+
return ConstantInt::get(CI1->getType(), C1V.sdiv(C2V));
823817
case Instruction::URem:
824818
assert(!CI2->isZero() && "Div by zero handled above");
825-
return ConstantInt::get(CI1->getContext(), C1V.urem(C2V));
819+
return ConstantInt::get(C1->getType(), C1V.urem(C2V));
826820
case Instruction::SRem:
827821
assert(!CI2->isZero() && "Div by zero handled above");
828822
if (C2V.isAllOnes() && C1V.isMinSignedValue())
829-
return PoisonValue::get(CI1->getType()); // MIN_INT % -1 -> poison
830-
return ConstantInt::get(CI1->getContext(), C1V.srem(C2V));
823+
return PoisonValue::get(C1->getType()); // MIN_INT % -1 -> poison
824+
return ConstantInt::get(C1->getType(), C1V.srem(C2V));
831825
case Instruction::And:
832-
return ConstantInt::get(CI1->getContext(), C1V & C2V);
826+
return ConstantInt::get(C1->getType(), C1V & C2V);
833827
case Instruction::Or:
834-
return ConstantInt::get(CI1->getContext(), C1V | C2V);
828+
return ConstantInt::get(C1->getType(), C1V | C2V);
835829
case Instruction::Xor:
836-
return ConstantInt::get(CI1->getContext(), C1V ^ C2V);
830+
return ConstantInt::get(C1->getType(), C1V ^ C2V);
837831
case Instruction::Shl:
838832
if (C2V.ult(C1V.getBitWidth()))
839-
return ConstantInt::get(CI1->getContext(), C1V.shl(C2V));
833+
return ConstantInt::get(C1->getType(), C1V.shl(C2V));
840834
return PoisonValue::get(C1->getType()); // too big shift is poison
841835
case Instruction::LShr:
842836
if (C2V.ult(C1V.getBitWidth()))
843-
return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V));
837+
return ConstantInt::get(C1->getType(), C1V.lshr(C2V));
844838
return PoisonValue::get(C1->getType()); // too big shift is poison
845839
case Instruction::AShr:
846840
if (C2V.ult(C1V.getBitWidth()))
847-
return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V));
841+
return ConstantInt::get(C1->getType(), C1V.ashr(C2V));
848842
return PoisonValue::get(C1->getType()); // too big shift is poison
849843
}
850844
}
@@ -877,7 +871,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode, Constant *C1,
877871
return ConstantFP::get(C1->getContext(), C3V);
878872
}
879873
}
880-
} else if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
874+
}
875+
876+
if (auto *VTy = dyn_cast<VectorType>(C1->getType())) {
881877
// Fast path for splatted constants.
882878
if (Constant *C2Splat = C2->getSplatValue()) {
883879
if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue())

llvm/lib/IR/Constants.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,13 @@ Constant *Constant::getAggregateElement(unsigned Elt) const {
441441
? CAZ->getElementValue(Elt)
442442
: nullptr;
443443

444+
if (const auto *CI = dyn_cast<ConstantInt>(this))
445+
return Elt < cast<VectorType>(getType())
446+
->getElementCount()
447+
.getKnownMinValue()
448+
? ConstantInt::get(getContext(), CI->getValue())
449+
: nullptr;
450+
444451
// FIXME: getNumElements() will fail for non-fixed vector types.
445452
if (isa<ScalableVectorType>(getType()))
446453
return nullptr;

llvm/test/Transforms/InstCombine/add.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
34

45
declare void @use(i8)
56
declare void @use_i1(i1)

llvm/test/Transforms/InstCombine/div.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
34

45
declare void @use(i32)
56

llvm/test/Transforms/InstCombine/mul.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
34

45
declare i32 @llvm.abs.i32(i32, i1)
56

llvm/test/Transforms/InstCombine/or.ll

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
2-
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
2+
; RUN: opt < %s -passes=instcombine -S | FileCheck %s --check-prefixes=CHECK,CONSTVEC
3+
; RUN: opt < %s -passes=instcombine -S -use-constant-int-for-fixed-length-splat | FileCheck %s --check-prefixes=CHECK,CONSTSPLAT
34

45
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128-n32:64"
56
declare void @use(i32)
@@ -399,10 +400,15 @@ define i32 @test30(i32 %A) {
399400
}
400401

401402
define <2 x i32> @test30vec(<2 x i32> %A) {
402-
; CHECK-LABEL: @test30vec(
403-
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
404-
; CHECK-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
405-
; CHECK-NEXT: ret <2 x i32> [[E]]
403+
; CONSTVEC-LABEL: @test30vec(
404+
; CONSTVEC-NEXT: [[TMP1:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
405+
; CONSTVEC-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[TMP1]], splat (i32 32962)
406+
; CONSTVEC-NEXT: ret <2 x i32> [[E]]
407+
;
408+
; CONSTSPLAT-LABEL: @test30vec(
409+
; CONSTSPLAT-NEXT: [[D:%.*]] = and <2 x i32> [[A:%.*]], splat (i32 -58312)
410+
; CONSTSPLAT-NEXT: [[E:%.*]] = or disjoint <2 x i32> [[D]], splat (i32 32962)
411+
; CONSTSPLAT-NEXT: ret <2 x i32> [[E]]
406412
;
407413
%B = or <2 x i32> %A, <i32 32962, i32 32962>
408414
%C = and <2 x i32> %A, <i32 -65536, i32 -65536>

llvm/test/Transforms/InstCombine/rotate.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
34

45
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f32:32:32-f64:32:64-v64:64:64-v128:128:128-a0:0:64-f80:128:128"
56

llvm/test/Transforms/InstCombine/shift.ll

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
34

45
declare void @use(i64)
56
declare void @use_i32(i32)

llvm/test/Transforms/InstCombine/xor-ashr.ll

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
22
; RUN: opt < %s -passes=instcombine -S | FileCheck %s
3+
; RUN: opt < %s -passes=instcombine -use-constant-int-for-fixed-length-splat -S | FileCheck %s
4+
35
target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-a0:0:64-s0:64:64-f80:128:128-n8:16:32:64"
46

57
declare void @use16(i16)

llvm/test/Transforms/InstSimplify/bitcast-vector-fold.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ define <1 x i1> @test10() {
8181
; CONSTVEC-NEXT: ret <1 x i1> [[RET]]
8282
;
8383
; CONSTSPLAT-LABEL: @test10(
84-
; CONSTSPLAT-NEXT: [[RET:%.*]] = icmp eq <1 x i64> splat (i64 -1), zeroinitializer
85-
; CONSTSPLAT-NEXT: ret <1 x i1> [[RET]]
84+
; CONSTSPLAT-NEXT: ret <1 x i1> zeroinitializer
8685
;
8786
%ret = icmp eq <1 x i64> <i64 bitcast (<1 x double> <double 0xFFFFFFFFFFFFFFFF> to i64)>, zeroinitializer
8887
ret <1 x i1> %ret

0 commit comments

Comments
 (0)