Skip to content

Commit a9c112e

Browse files
committed
[InstCombine] Fold integer unpack/repack patterns through ZExt (llvm#153583)
This patch explicitly enables the InstCombiner to fold integer unpack/repack patterns such as ```llvm define i64 @src_combine(i32 %lower, i32 %upper) { %base = zext i32 %lower to i64 %u.0 = and i32 %upper, u0xff %z.0 = zext i32 %u.0 to i64 %s.0 = shl i64 %z.0, 32 %o.0 = or i64 %base, %s.0 %r.1 = lshr i32 %upper, 8 %u.1 = and i32 %r.1, u0xff %z.1 = zext i32 %u.1 to i64 %s.1 = shl i64 %z.1, 40 %o.1 = or i64 %o.0, %s.1 %r.2 = lshr i32 %upper, 16 %u.2 = and i32 %r.2, u0xff %z.2 = zext i32 %u.2 to i64 %s.2 = shl i64 %z.2, 48 %o.2 = or i64 %o.1, %s.2 %r.3 = lshr i32 %upper, 24 %u.3 = and i32 %r.3, u0xff %z.3 = zext i32 %u.3 to i64 %s.3 = shl i64 %z.3, 56 %o.3 = or i64 %o.2, %s.3 ret i64 %o.3 } ; => define i64 @tgt_combine(i32 %lower, i32 %upper) { %base = zext i32 %lower to i64 %upper.zext = zext i32 %upper to i64 %s.0 = shl nuw i64 %upper.zext, 32 %o.3 = or disjoint i64 %s.0, %base ret i64 %o.3 } ``` Alive2 proofs: [YAy7ny](https://alive2.llvm.org/ce/z/YAy7ny)
1 parent a6597e7 commit a9c112e

File tree

1 file changed

+103
-0
lines changed

1 file changed

+103
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3800,6 +3800,109 @@ static Instruction *foldIntegerPackFromVector(Instruction &I,
38003800
return CastInst::Create(Instruction::BitCast, MaskedVec, I.getType());
38013801
}
38023802

3803+
/// Match \p V as "lshr -> mask -> zext -> shl".
///
/// \p Int is the underlying integer being extracted from.
/// \p Mask is a bitmask identifying which bits of the integer are being
/// extracted. \p Offset identifies which bit of the result \p V corresponds to
/// the least significant bit of \p Int
static bool matchZExtedSubInteger(Value *V, Value *&Int, APInt &Mask,
                                  uint64_t &Offset, bool &IsShlNUW,
                                  bool &IsShlNSW) {
  // The outermost operation must be a one-use shl by a constant amount.
  Value *ShlOp0;
  uint64_t ShlAmt = 0;
  if (!match(V, m_OneUse(m_Shl(m_Value(ShlOp0), m_ConstantInt(ShlAmt)))))
    return false;

  // Record the shl's wrap flags so the caller can transfer them to the
  // combined shift (only when both repack halves agree on them).
  IsShlNUW = cast<BinaryOperator>(V)->hasNoUnsignedWrap();
  IsShlNSW = cast<BinaryOperator>(V)->hasNoSignedWrap();

  // The shifted value must be a one-use zext.
  Value *ZExtOp0;
  if (!match(ShlOp0, m_OneUse(m_ZExt(m_Value(ZExtOp0)))))
    return false;

  // The zext operand may optionally be masked by an 'and' with a constant.
  // If the 'and' is absent, ShiftedMaskConst stays null and a full mask is
  // synthesized below.
  Value *MaskedOp0;
  const APInt *ShiftedMaskConst = nullptr;
  if (!match(ZExtOp0, m_CombineOr(m_OneUse(m_And(m_Value(MaskedOp0),
                                                 m_APInt(ShiftedMaskConst))),
                                  m_Value(MaskedOp0))))
    return false;

  // The masked value may optionally be shifted down out of the source
  // integer; otherwise Int is taken unshifted (LShrAmt == 0).
  uint64_t LShrAmt = 0;
  if (!match(MaskedOp0,
             m_CombineOr(m_OneUse(m_LShr(m_Value(Int), m_ConstantInt(LShrAmt))),
                         m_Value(Int))))
    return false;

  // A net right shift would place source bits below bit 0 of the result;
  // only handle extractions that move bits up (or leave them in place).
  if (LShrAmt > ShlAmt)
    return false;
  Offset = ShlAmt - LShrAmt;

  // Express the mask in source-integer bit positions: the 'and' constant was
  // applied after the lshr, so shift it back up. Without an explicit 'and',
  // every bit from LShrAmt upward is extracted.
  Mask = ShiftedMaskConst ? ShiftedMaskConst->shl(LShrAmt)
                          : APInt::getBitsSetFrom(
                                Int->getType()->getScalarSizeInBits(), LShrAmt);

  return true;
}
3847+
3848+
/// Try to fold the join of two scalar integers whose bits are unpacked and
/// zexted from the same source integer.
///
/// Both operands must match the "lshr -> mask -> zext -> shl" shape (see
/// matchZExtedSubInteger), extract from the same source value, and place the
/// source's bit 0 at the same result position; the two extractions are then
/// merged into a single and + zext + shl. Returns the replacement value, or
/// nullptr if the pattern does not apply.
static Value *foldIntegerRepackThroughZExt(Value *Lhs, Value *Rhs,
                                           InstCombiner::BuilderTy &Builder) {

  Value *LhsInt, *RhsInt;
  APInt LhsMask, RhsMask;
  uint64_t LhsOffset, RhsOffset;
  bool IsLhsShlNUW, IsLhsShlNSW, IsRhsShlNUW, IsRhsShlNSW;
  if (!matchZExtedSubInteger(Lhs, LhsInt, LhsMask, LhsOffset, IsLhsShlNUW,
                             IsLhsShlNSW))
    return nullptr;
  if (!matchZExtedSubInteger(Rhs, RhsInt, RhsMask, RhsOffset, IsRhsShlNUW,
                             IsRhsShlNSW))
    return nullptr;
  // Same source and same alignment of source-bit 0 in the result are required
  // for the two masks to be mergeable into one extraction.
  if (LhsInt != RhsInt || LhsOffset != RhsOffset)
    return nullptr;

  APInt Mask = LhsMask | RhsMask;

  // Rebuild as ((Int & (LhsMask | RhsMask)) zext to DestTy) << Offset.
  // Wrap flags on the new shl are set only when both original shifts had
  // them, which keeps the result conservatively correct.
  Type *DestTy = Lhs->getType();
  Value *Res = Builder.CreateShl(
      Builder.CreateZExt(
          Builder.CreateAnd(LhsInt, Mask, LhsInt->getName() + ".mask"), DestTy,
          LhsInt->getName() + ".zext"),
      ConstantInt::get(DestTy, LhsOffset), "", IsLhsShlNUW && IsRhsShlNUW,
      IsLhsShlNSW && IsRhsShlNSW);
  Res->takeName(Lhs);
  return Res;
}
3878+
3879+
Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
  // Dispatch point for folds that are only valid because the 'or' operands
  // have no overlapping bits. The sole fold so far is the repack of zexted
  // sub-integers taken from a common source; forward its result (which is
  // nullptr when the pattern does not apply).
  return foldIntegerRepackThroughZExt(LHS, RHS, Builder);
}
3885+
3886+
Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {
  // If one operand is itself a disjoint 'or', try pairing the other operand
  // with each of its halves; a successful fold is re-joined with the leftover
  // half via a new disjoint 'or'. The RHS decomposition is attempted first,
  // mirroring the symmetric LHS case below.
  Value *A, *B;

  if (match(RHS, m_OneUse(m_DisjointOr(m_Value(A), m_Value(B))))) {
    if (Value *Folded = foldDisjointOr(LHS, A))
      return Builder.CreateOr(Folded, B, "", /*IsDisjoint=*/true);
    if (Value *Folded = foldDisjointOr(LHS, B))
      return Builder.CreateOr(Folded, A, "", /*IsDisjoint=*/true);
  }

  if (match(LHS, m_OneUse(m_DisjointOr(m_Value(A), m_Value(B))))) {
    if (Value *Folded = foldDisjointOr(A, RHS))
      return Builder.CreateOr(Folded, B, "", /*IsDisjoint=*/true);
    if (Value *Folded = foldDisjointOr(B, RHS))
      return Builder.CreateOr(Folded, A, "", /*IsDisjoint=*/true);
  }

  return nullptr;
}
3905+
38033906
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
38043907
// here. We should standardize that construct where it is needed or choose some
38053908
// other way to ensure that commutated variants of patterns are not missed.

0 commit comments

Comments
 (0)