-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[InstCombine][VectorCombine][NFC] Unify uses of lossless inverse cast #156597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
| if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { | ||
| if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) { | ||
| if (Constant *TruncC = | ||
| getLosslessInvCast(C, SrcTy, Instruction::ZExt, DL)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also remove the old helper functions, which are presumably unused now?
TBH I find the new API super confusing. This is doing a truncate, but it's phrased in terms of a ZExt.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
> Also remove the old helper functions, which are presumably unused now?
Oh sure. I missed it.
> TBH I find the new API super confusing. This is doing a truncate, but it's phrased in terms of a ZExt.
It's doing the inverse cast of ZExt: it finds a constant InvC such that ZExt(InvC) == C. Alternatively, should we still extract the ZExt/SExt part and name it getLosslessTrunc?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's messy, but could we keep the getLosslessUnsignedTrunc/getLosslessSignedTrunc function names and turn them into wrappers around getLosslessInvCast?
| bool NSW = false; | ||
| }; | ||
|
|
||
| /// Try to cast C to InvC losslessly, satisfying CastOp(InvC) == C. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CastOp(InvC) == C is imprecise (e.g., bitcast <2 x i16> <i16 0, i16 poison> to i32 is refined to 0).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch. What about: `CastOp(InvC) == C`, or `CastOp(InvC)` is one of the possible values of `C` if `C` is undefined?
|
@llvm/pr-subscribers-llvm-analysis @llvm/pr-subscribers-llvm-transforms Author: Hongyu Chen (XChy) Changes: This patch addresses #155216 (comment). Full diff: https://github.com/llvm/llvm-project/pull/156597.diff 11 Files Affected:
diff --git a/llvm/include/llvm/Analysis/ConstantFolding.h b/llvm/include/llvm/Analysis/ConstantFolding.h
index dcbac8a301025..5f91f9747bb97 100644
--- a/llvm/include/llvm/Analysis/ConstantFolding.h
+++ b/llvm/include/llvm/Analysis/ConstantFolding.h
@@ -226,6 +226,27 @@ LLVM_ABI bool isMathLibCallNoop(const CallBase *Call,
LLVM_ABI Constant *ReadByteArrayFromGlobal(const GlobalVariable *GV,
uint64_t Offset);
-}
+
+struct PreservedCastFlags {
+ bool NNeg = false;
+ bool NUW = false;
+ bool NSW = false;
+};
+
+/// Try to cast C to InvC losslessly, satisfying CastOp(InvC) equals C, or
+/// CastOp(InvC) is a refined value of undefined C. Will try best to
+/// preserve the flags.
+LLVM_ABI Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
+ unsigned CastOp, const DataLayout &DL,
+ PreservedCastFlags *Flags = nullptr);
+
+LLVM_ABI Constant *
+getLosslessUnsignedTrunc(Constant *C, Type *DestTy, const DataLayout &DL,
+ PreservedCastFlags *Flags = nullptr);
+
+LLVM_ABI Constant *getLosslessSignedTrunc(Constant *C, Type *DestTy,
+ const DataLayout &DL,
+ PreservedCastFlags *Flags = nullptr);
+} // namespace llvm
#endif
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index 2148431c1acce..40e176c2ab5ce 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -4608,4 +4608,55 @@ bool llvm::isMathLibCallNoop(const CallBase *Call,
return false;
}
+Constant *llvm::getLosslessInvCast(Constant *C, Type *InvCastTo,
+ unsigned CastOp, const DataLayout &DL,
+ PreservedCastFlags *Flags) {
+ switch (CastOp) {
+ case Instruction::BitCast:
+ // Bitcast is always lossless.
+ return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
+ case Instruction::Trunc: {
+ auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
+ if (Flags) {
+ // Truncation back on ZExt value is always NUW.
+ Flags->NUW = true;
+ // Test positivity of C.
+ auto *SExtC =
+ ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
+ Flags->NSW = ZExtC == SExtC;
+ }
+ return ZExtC;
+ }
+ case Instruction::SExt:
+ case Instruction::ZExt: {
+ auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
+ auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
+ // Must satisfy CastOp(InvC) == C.
+ if (!CastInvC || CastInvC != C)
+ return nullptr;
+ if (Flags && CastOp == Instruction::ZExt) {
+ auto *SExtInvC =
+ ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
+ // Test positivity of InvC.
+ Flags->NNeg = CastInvC == SExtInvC;
+ }
+ return InvC;
+ }
+ default:
+ return nullptr;
+ }
+}
+
+Constant *llvm::getLosslessUnsignedTrunc(Constant *C, Type *DestTy,
+ const DataLayout &DL,
+ PreservedCastFlags *Flags) {
+ return getLosslessInvCast(C, DestTy, Instruction::ZExt, DL, Flags);
+}
+
+Constant *llvm::getLosslessSignedTrunc(Constant *C, Type *DestTy,
+ const DataLayout &DL,
+ PreservedCastFlags *Flags) {
+ return getLosslessInvCast(C, DestTy, Instruction::SExt, DL, Flags);
+}
+
void TargetFolder::anchor() {}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index a13d3ceb61320..8b9df62d7c652 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1799,8 +1799,9 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
// type may provide more information to later folds, and the smaller logic
// instruction may be cheaper (particularly in the case of vectors).
Value *X;
+ auto &DL = IC.getDataLayout();
if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) {
- if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) {
+ if (Constant *TruncC = getLosslessUnsignedTrunc(C, SrcTy, DL)) {
// LogicOpc (zext X), C --> zext (LogicOpc X, C)
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new ZExtInst(NewOp, DestTy);
@@ -1808,7 +1809,7 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
}
if (match(Cast, m_OneUse(m_SExtLike(m_Value(X))))) {
- if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) {
+ if (Constant *TruncC = getLosslessSignedTrunc(C, SrcTy, DL)) {
// LogicOpc (sext X), C --> sext (LogicOpc X, C)
Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
return new SExtInst(NewOp, DestTy);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 42b65dde67255..33b66aeaffe60 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1956,7 +1956,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Constant *C;
if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) &&
I0->hasOneUse()) {
- if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType())) {
+ if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType(), DL)) {
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
}
@@ -2006,7 +2006,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Constant *C;
if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) &&
I0->hasOneUse()) {
- if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType())) {
+ if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType(), DL)) {
Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType());
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 90feddf6dcfe1..861630680752f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -6375,7 +6375,7 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) {
// If a lossless truncate is possible...
Type *SrcTy = CastOp0->getSrcTy();
- Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode());
+ Constant *Res = getLosslessInvCast(C, SrcTy, CastOp0->getOpcode(), DL);
if (Res) {
if (ICmp.isEquality())
return new ICmpInst(ICmp.getPredicate(), X, Res);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 2340028ce93dc..d3d23130b6fc4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -222,23 +222,6 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF,
const Instruction *CtxI) const;
- Constant *getLosslessTrunc(Constant *C, Type *TruncTy, unsigned ExtOp) {
- Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy);
- Constant *ExtTruncC =
- ConstantFoldCastOperand(ExtOp, TruncC, C->getType(), DL);
- if (ExtTruncC && ExtTruncC == C)
- return TruncC;
- return nullptr;
- }
-
- Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) {
- return getLosslessTrunc(C, TruncTy, Instruction::ZExt);
- }
-
- Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) {
- return getLosslessTrunc(C, TruncTy, Instruction::SExt);
- }
-
std::optional<std::pair<Intrinsic::ID, SmallVector<Value *, 3>>>
convertOrOfShiftsToFunnelShift(Instruction &Or);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index d7310b1c741c0..a9aacc707cc20 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1642,10 +1642,11 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
}
Constant *C;
+ auto &DL = IC.getDataLayout();
if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) &&
match(D, m_Constant(C))) {
// If the constant is the same in the smaller type, use the narrow version.
- Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+ Constant *TruncC = getLosslessUnsignedTrunc(C, X->getType(), DL);
if (!TruncC)
return nullptr;
@@ -1656,7 +1657,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) &&
match(N, m_Constant(C))) {
// If the constant is the same in the smaller type, use the narrow version.
- Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+ Constant *TruncC = getLosslessUnsignedTrunc(C, X->getType(), DL);
if (!TruncC)
return nullptr;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 6477141ab095f..ed9a0be6981fa 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -841,7 +841,7 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) {
NumZexts++;
} else if (auto *C = dyn_cast<Constant>(V)) {
// Make sure that constants can fit in the new type.
- Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType);
+ Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType, DL);
if (!Trunc)
return nullptr;
NewIncoming.push_back(Trunc);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ba8b4c47e8f88..9467463d39c0e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2375,7 +2375,7 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) {
// If the constant is the same after truncation to the smaller type and
// extension to the original type, we can narrow the select.
Type *SelType = Sel.getType();
- Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode);
+ Constant *TruncC = getLosslessInvCast(C, SmallType, ExtOpcode, DL);
if (TruncC && ExtInst->hasOneUse()) {
Value *TruncCVal = cast<Value>(TruncC);
if (ExtInst == Sel.getFalseValue())
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 1a9b54bc009bc..4960a50bbede8 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2568,7 +2568,7 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) {
Constant *WideC;
if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC)))
return nullptr;
- Constant *NarrowC = getLosslessTrunc(WideC, X->getType(), CastOpc);
+ Constant *NarrowC = getLosslessInvCast(WideC, X->getType(), CastOpc, DL);
if (!NarrowC)
return nullptr;
Y = NarrowC;
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index 6e46547b15b2b..93b3a0eeb0305 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -938,51 +938,6 @@ bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
return true;
}
-struct PreservedCastFlags {
- bool NNeg = false;
- bool NUW = false;
- bool NSW = false;
-};
-
-// Try to cast C to InvC losslessly, satisfying CastOp(InvC) == C.
-// Will try best to preserve the flags.
-static Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
- Instruction::CastOps CastOp,
- const DataLayout &DL,
- PreservedCastFlags &Flags) {
- switch (CastOp) {
- case Instruction::BitCast:
- // Bitcast is always lossless.
- return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
- case Instruction::Trunc: {
- auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
- auto *SExtC = ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
- // Truncation back on ZExt value is always NUW.
- Flags.NUW = true;
- // Test positivity of C.
- Flags.NSW = ZExtC == SExtC;
- return ZExtC;
- }
- case Instruction::SExt:
- case Instruction::ZExt: {
- auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
- auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
- // Must satisfy CastOp(InvC) == C.
- if (!CastInvC || CastInvC != C)
- return nullptr;
- if (CastOp == Instruction::ZExt) {
- auto *SExtInvC =
- ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
- // Test positivity of InvC.
- Flags.NNeg = CastInvC == SExtInvC;
- }
- return InvC;
- }
- default:
- return nullptr;
- }
-}
-
/// Match:
// bitop(castop(x), C) ->
// bitop(castop(x), castop(InvC)) ->
@@ -1025,7 +980,7 @@ bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
// Find the constant InvC, such that castop(InvC) equals to C.
PreservedCastFlags RHSFlags;
- Constant *InvC = getLosslessInvCast(C, SrcVecTy, CastOpcode, *DL, RHSFlags);
+ Constant *InvC = getLosslessInvCast(C, SrcVecTy, CastOpcode, *DL, &RHSFlags);
if (!InvC)
return false;
|
|
Gently ping. |
nikic
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/11/builds/23497 Here is the relevant piece of the build log for reference. |
This patch addresses #155216 (comment).
This patch adds a helper function that computes the inverse cast of a constant, optionally preserving cast flags.
Follow-up patches will add trunc/ext handling in VectorCombine and flag preservation in InstCombine.