Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3119,6 +3119,127 @@ static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
return nullptr;
}

/// Fold an `or` of equality comparisons of one value against small constants
/// into a bounds-checked bit test of a constant bitmap.
///
/// Two shapes are handled:
///   1. (or (icmp eq X, Imm0), (icmp eq X, Imm1))
///      creates a fresh bitmap with bits Imm0 and Imm1 set.
///   2. (or (BitMapSeq(X)), (icmp eq X, Imm))
///      adds bit Imm to a bitmap sequence previously produced in the form
///      emitted below, widening the bound if necessary.
///
/// The emitted sequence is:
///   %srl   = lshr %bitmap, %X
///   %icmp  = icmp ult %X, %max_value
///   %trunc = trunc %srl to i1
///   %sel   = select %icmp, %trunc, false
///
/// Patch author's note: with only two comparisons the transform is about net
/// neutral (at least when targeting the RISC-V backend); the benefit comes
/// once three or more comparisons are combined. Operating on a single `or` at
/// a time kept the implementation simpler.
///
/// \param Or      The `or` instruction being visited.
/// \param Builder Builder used to emit the replacement instructions.
/// \param DL      Data layout; supplies the largest legal integer width.
/// \returns The replacement value, or nullptr if no fold applies.
static Value *combineOrOfImmCmpToBitExtract(Instruction &Or,
                                            InstCombiner::BuilderTy &Builder,
                                            const DataLayout &DL) {
  // The bitmap should fit into a single native register, otherwise the
  // transformation won't be profitable. A width of zero means the data layout
  // declares no legal integer type, so no immediate could ever be valid.
  const unsigned XLen = DL.getLargestLegalIntTypeSizeInBits();
  if (XLen == 0)
    return nullptr;

  // Match a single-use `icmp eq X, Imm`. When X is already bound, the compare
  // must be against that same value; otherwise X is captured from the compare.
  auto isICmpEqImm = [](Value *N, ConstantInt *&Imm, Value *&X) -> bool {
    if (X)
      return match(N, m_OneUse(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Specific(X),
                                              m_ConstantInt(Imm))));

    return match(N, m_OneUse(m_SpecificICmp(ICmpInst::ICMP_EQ, m_Value(X),
                                            m_ConstantInt(Imm))));
  };

  // Emit the lshr / icmp ult / trunc / select sequence for the given bitmap
  // and exclusive upper bound.
  auto CreateBitExtractSeq = [&](const APInt &BitMap, const APInt &MaxValue,
                                 Value *X) -> Value * {
    LLVMContext &Context = Or.getContext();

    // %srl = lshr %bitmap, %X
    // It is okay for the shift amount to be truncated: if information is
    // lost then the bounds check is guaranteed to fail and the shift result
    // is discarded by the select.
    ConstantInt *BitMapConst = ConstantInt::get(Context, BitMap);
    Value *ShiftAmt =
        Builder.CreateZExtOrTrunc(X, BitMapConst->getIntegerType());
    Value *LShr = Builder.CreateLShr(BitMapConst, ShiftAmt);

    // %icmp = icmp ult %X, %max_value
    // Compare in the wider of X's type and the bound's type so that no
    // information is lost.
    Value *MaxVal = ConstantInt::get(Context, MaxValue);
    if (MaxVal->getType()->getIntegerBitWidth() >
        X->getType()->getIntegerBitWidth())
      X = Builder.CreateZExt(X, MaxVal->getType());
    else
      MaxVal = Builder.CreateZExt(MaxVal, X->getType());
    Value *BoundsCheck = Builder.CreateICmp(ICmpInst::ICMP_ULT, X, MaxVal);

    // %trunc = trunc %srl to i1
    // Only the low bit matters.
    Value *ShrTrunc = Builder.CreateTrunc(LShr, IntegerType::get(Context, 1));

    // %sel = select %icmp, %trunc, false
    return Builder.CreateSelect(BoundsCheck, ShrTrunc,
                                ConstantInt::getFalse(Context));
  };

  // An immediate is usable only if it names a bit position inside the bitmap,
  // i.e. it is strictly less than XLen.
  auto validImm = [&](const APInt &APImm) -> bool {
    auto Imm = APImm.tryZExtValue();
    return Imm && (*Imm < XLen);
  };

  // Match (or (icmp eq X, Imm0), (icmp eq X, Imm1))
  ConstantInt *LHS, *RHS;
  Value *X = nullptr;
  if (isICmpEqImm(Or.getOperand(0), LHS, X) &&
      isICmpEqImm(Or.getOperand(1), RHS, X)) {
    // Per the patch author: the shift would become poison when shifted by
    // undef, so X must be known to be well defined.
    if (!isGuaranteedNotToBeUndefOrPoison(X))
      return nullptr;

    APInt LHSAP = LHS->getValue();
    APInt RHSAP = RHS->getValue();
    if (!validImm(LHSAP) || !validImm(RHSAP))
      return nullptr;
    // The immediates may have differing bit widths. Having verified that each
    // is representable within XLen bits, normalize them all to exactly XLen
    // bits so they can be combined.
    LHSAP = LHSAP.zextOrTrunc(XLen);
    RHSAP = RHSAP.zextOrTrunc(XLen);

    // Create the bitmap and bounds-check immediates.
    // +1 on the bound because the check is strictly-less-than.
    APInt BitMap = (APInt(XLen, 1) << LHSAP) | (APInt(XLen, 1) << RHSAP);
    APInt Bound = RHSAP.ugt(LHSAP) ? RHSAP : LHSAP;
    return CreateBitExtractSeq(BitMap, Bound + 1, X);
  }

  // Expand an already existing bitmap sequence.
  // Match: (or (%BitMapSeq(X)), (icmp eq X, Imm))
  ConstantInt *BitMap, *Bound, *CmpImm;
  Value *Cmp;
  if (match(&Or, m_OneUse(m_c_Or(m_Value(Cmp),
                                 m_OneUse(m_Select(
                                     m_SpecificICmp(ICmpInst::ICMP_ULT,
                                                    m_ZExtOrSelf(m_Value(X)),
                                                    m_ConstantInt(Bound)),
                                     m_OneUse(m_Trunc(m_OneUse(m_Shr(
                                         m_ConstantInt(BitMap),
                                         m_ZExtOrTruncOrSelf(m_Deferred(X)))))),
                                     m_Zero()))))) &&
      isICmpEqImm(Cmp, CmpImm, X)) {
    if (!isGuaranteedNotToBeUndefOrPoison(X))
      return nullptr;

    // The existing bound must not exceed XLen. Check the untruncated value:
    // the bound constant can be wider than XLen (when X is), and narrowing it
    // first could turn an out-of-range bound into a small in-range one,
    // producing a sequence that rejects inputs the original accepted.
    const APInt &BoundVal = Bound->getValue();
    if (BoundVal.ugt(XLen))
      return nullptr;

    APInt NewAP = CmpImm->getValue();
    APInt BitMapAP = BitMap->getValue();
    // The bitmap must fit in a native register.
    if (!validImm(NewAP) || !DL.fitsInLegalInteger(BitMapAP.getActiveBits()))
      return nullptr;

    // Everything is now known to be representable in XLen bits.
    APInt BoundAP = BoundVal.zextOrTrunc(XLen);
    NewAP = NewAP.zextOrTrunc(XLen);
    BitMapAP = BitMapAP.zextOrTrunc(XLen);

    // The bound must also cover the highest set bit of the bitmap; otherwise
    // raising the bound below would expose bits the original check masked.
    if (BoundAP.ult(BitMapAP.getActiveBits()))
      return nullptr;

    // Widen the bound if the new bit lies at or beyond it.
    if (NewAP.uge(BoundAP))
      BoundAP = NewAP + 1;
    BitMapAP |= (APInt(XLen, 1) << NewAP);
    return CreateBitExtractSeq(BitMapAP, BoundAP, X);
  }
  return nullptr;
}

/// Attempt to combine or(zext(x),shl(zext(y),bw/2)) concat packing patterns.
static Value *matchOrConcat(Instruction &Or, InstCombiner::BuilderTy &Builder) {
assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'");
Expand Down Expand Up @@ -4084,6 +4205,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
if (Instruction *Funnel = matchFunnelShift(I, *this))
return Funnel;

if (Value *BitExtract =
combineOrOfImmCmpToBitExtract(I, Builder, getDataLayout()))
return replaceInstUsesWith(I, BitExtract);

if (Value *Concat = matchOrConcat(I, Builder))
return replaceInstUsesWith(I, Concat);

Expand Down
Loading
Loading