Skip to content

Commit 3d5d32c

Browse files
authored
[CGP]: Optimize mul.overflow. (#148343)
- Detect cases where the LHS & RHS values cannot cause overflow (when the high halves are zero).
1 parent 52f4c36 commit 3d5d32c

File tree

8 files changed

+699
-144
lines changed

8 files changed

+699
-144
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3492,6 +3492,13 @@ class LLVM_ABI TargetLoweringBase {
34923492
return MathUsed && (VT.isSimple() || !isOperationExpand(Opcode, VT));
34933493
}
34943494

3495+
// Return true if the target wants to optimize the mul overflow intrinsic
3496+
// for the given \p VT.
3497+
virtual bool shouldOptimizeMulOverflowWithZeroHighBits(LLVMContext &Context,
3498+
EVT VT) const {
3499+
return false;
3500+
}
3501+
34953502
// Return true if it is profitable to use a scalar input to a BUILD_VECTOR
34963503
// even if the vector itself has multiple uses.
34973504
virtual bool aggressivelyPreferBuildVectorSources(EVT VecVT) const {

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,8 @@ class CodeGenPrepare {
431431
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
432432
unsigned AddrSpace);
433433
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
434+
bool optimizeMulWithOverflow(Instruction *I, bool IsSigned,
435+
ModifyDT &ModifiedDT);
434436
bool optimizeInlineAsmInst(CallInst *CS);
435437
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
436438
bool optimizeExt(Instruction *&I);
@@ -2797,6 +2799,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
27972799
}
27982800
}
27992801
return false;
2802+
case Intrinsic::umul_with_overflow:
2803+
return optimizeMulWithOverflow(II, /*IsSigned=*/false, ModifiedDT);
2804+
case Intrinsic::smul_with_overflow:
2805+
return optimizeMulWithOverflow(II, /*IsSigned=*/true, ModifiedDT);
28002806
}
28012807

28022808
SmallVector<Value *, 2> PtrOps;
@@ -6391,6 +6397,182 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
63916397
return true;
63926398
}
63936399

6400+
// This is a helper for CodeGenPrepare::optimizeMulWithOverflow.
6401+
// Check the pattern we are interested in where there are maximum 2 uses
6402+
// of the intrinsic which are the extract instructions.
6403+
static bool matchOverflowPattern(Instruction *&I, ExtractValueInst *&MulExtract,
6404+
ExtractValueInst *&OverflowExtract) {
6405+
// Bail out if it's more than 2 users:
6406+
if (I->hasNUsesOrMore(3))
6407+
return false;
6408+
6409+
for (User *U : I->users()) {
6410+
auto *Extract = dyn_cast<ExtractValueInst>(U);
6411+
if (!Extract || Extract->getNumIndices() != 1)
6412+
return false;
6413+
6414+
unsigned Index = Extract->getIndices()[0];
6415+
if (Index == 0)
6416+
MulExtract = Extract;
6417+
else if (Index == 1)
6418+
OverflowExtract = Extract;
6419+
else
6420+
return false;
6421+
}
6422+
return true;
6423+
}
6424+
6425+
// Rewrite the mul_with_overflow intrinsic by checking whether both operands
6426+
// fit in the legal half-width type. If so, a plain multiplication that cannot
6427+
// overflow is used instead. Ideally this would be done during type
6428+
// legalization, but it requires restructuring the IR (adding basic blocks),
6429+
// which is not doable there, so we do it here.
6430+
// The IR after the optimization will look like:
6431+
// entry:
6432+
// if signed:
6433+
// ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
6434+
// overflow_no
6435+
// else:
6436+
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
6437+
// overflow_no:
6438+
// overflow:
6439+
// overflow.res:
6440+
// \returns true if optimization was applied
6441+
// TODO: This optimization can be further improved to optimize branching on
6442+
// overflow where the 'overflow_no' BB can branch directly to the false
6443+
// successor of overflow, but that would add additional complexity so we leave
6444+
// it for future work.
6445+
// \param I The llvm.{u,s}mul.with.overflow intrinsic call to rewrite.
// \param IsSigned True for smul.with.overflow, false for umul.with.overflow.
// \param ModifiedDT Set to ModifyDT::ModifyBBDT when the rewrite is applied.
bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
6446+
ModifyDT &ModifiedDT) {
6447+
// Check if target supports this optimization.
6448+
if (!TLI->shouldOptimizeMulOverflowWithZeroHighBits(
6449+
I->getContext(),
6450+
TLI->getValueType(*DL, I->getType()->getContainedType(0))))
6451+
return false;
6452+
6453+
// Only proceed when all users are extractvalue instructions of the result.
ExtractValueInst *MulExtract = nullptr, *OverflowExtract = nullptr;
6454+
if (!matchOverflowPattern(I, MulExtract, OverflowExtract))
6455+
return false;
6456+
6457+
// Keep track of the instruction to stop reoptimizing it again.
6458+
InsertedInsts.insert(I);
6459+
6460+
Value *LHS = I->getOperand(0);
6461+
Value *RHS = I->getOperand(1);
6462+
Type *Ty = LHS->getType();
// The fast path applies when both operands fit in half of the operand width.
6463+
unsigned VTHalfBitWidth = Ty->getScalarSizeInBits() / 2;
6464+
Type *LegalTy = Ty->getWithNewBitWidth(VTHalfBitWidth);
6465+
6466+
// New BBs:
// Split at the intrinsic; the block that still contains the intrinsic
// (I->getParent()) is later renamed to 'overflow.res'.
6467+
BasicBlock *OverflowEntryBB =
6468+
I->getParent()->splitBasicBlock(I, "", /*Before*/ true);
6469+
OverflowEntryBB->takeName(I->getParent());
6470+
// Keep the 'br' instruction that is generated as a result of the split to be
6471+
// erased/replaced later.
6472+
Instruction *OldTerminator = OverflowEntryBB->getTerminator();
6473+
BasicBlock *NoOverflowBB =
6474+
BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction());
6475+
NoOverflowBB->moveAfter(OverflowEntryBB);
6476+
BasicBlock *OverflowBB =
6477+
BasicBlock::Create(I->getContext(), "overflow", I->getFunction());
6478+
OverflowBB->moveAfter(NoOverflowBB);
6479+
6480+
// BB overflow.entry:
6481+
IRBuilder<> Builder(OverflowEntryBB);
6482+
// Extract low and high halves of LHS:
6483+
Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
6484+
Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
6485+
HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
6486+
6487+
// Extract low and high halves of RHS:
6488+
Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
6489+
Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
6490+
HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
6491+
6492+
// IsAnyBitTrue is set when at least one operand does not fit in the low half,
// in which case the original intrinsic is executed.
Value *IsAnyBitTrue;
6493+
if (IsSigned) {
// Signed: an operand fits in the low half iff its high half equals the low
// half's sign bit replicated across all bits, i.e. the XOR below is zero.
6494+
Value *SignLoLHS =
6495+
Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
6496+
Value *SignLoRHS =
6497+
Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
6498+
Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
6499+
Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
6500+
Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
6501+
IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or,
6502+
ConstantInt::getNullValue(Or->getType()));
6503+
} else {
// Unsigned: an operand fits in the low half iff its high half is zero.
6504+
Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
6505+
ConstantInt::getNullValue(LegalTy));
6506+
Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
6507+
ConstantInt::getNullValue(LegalTy));
6508+
IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
6509+
}
6510+
Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);
6511+
6512+
// BB overflow.no:
6513+
Builder.SetInsertPoint(NoOverflowBB);
// Re-extend the low halves to the full type with the extension matching the
// signedness, then multiply at full width.
6514+
Value *ExtLoLHS, *ExtLoRHS;
6515+
if (IsSigned) {
6516+
ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
6517+
ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
6518+
} else {
6519+
ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
6520+
ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
6521+
}
6522+
6523+
// Both factors fit in half the width, so this product fits in the full type;
// the overflow flag for this path is the constant false added to the PHI
// below.
Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no");
6524+
6525+
// Create the 'overflow.res' BB to merge the results of
6526+
// the two paths:
6527+
BasicBlock *OverflowResBB = I->getParent();
6528+
OverflowResBB->setName("overflow.res");
6529+
6530+
// BB overflow.no: jump to overflow.res BB
6531+
Builder.CreateBr(OverflowResBB);
6532+
// Now we don't need the old terminator in overflow.entry BB, erase it:
6533+
OldTerminator->eraseFromParent();
6534+
6535+
// BB overflow.res:
6536+
Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
6537+
// Create PHI nodes to merge results from no.overflow BB and overflow BB to
6538+
// replace the extract instructions.
6539+
PHINode *OverflowResPHI = Builder.CreatePHI(Ty, 2),
6540+
*OverflowFlagPHI =
6541+
Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);
6542+
6543+
// Add the incoming values from no.overflow BB and later from overflow BB.
6544+
OverflowResPHI->addIncoming(Mul, NoOverflowBB);
6545+
OverflowFlagPHI->addIncoming(ConstantInt::getFalse(I->getContext()),
6546+
NoOverflowBB);
6547+
6548+
// Replace all users of MulExtract and OverflowExtract to use the PHI nodes.
// Either extract may be null when the corresponding field is unused.
6549+
if (MulExtract) {
6550+
MulExtract->replaceAllUsesWith(OverflowResPHI);
6551+
MulExtract->eraseFromParent();
6552+
}
6553+
if (OverflowExtract) {
6554+
OverflowExtract->replaceAllUsesWith(OverflowFlagPHI);
6555+
OverflowExtract->eraseFromParent();
6556+
}
6557+
6558+
// Remove the intrinsic from parent (overflow.res BB) as it will be part of
6559+
// overflow BB
6560+
I->removeFromParent();
6561+
// BB overflow:
6562+
I->insertInto(OverflowBB, OverflowBB->end());
6563+
Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
// Fresh extracts of the original intrinsic feed the PHIs on the slow path.
6564+
Value *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
6565+
Value *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
6566+
Builder.CreateBr(OverflowResBB);
6567+
6568+
// Add The Extracted values to the PHINodes in the overflow.res BB.
6569+
OverflowResPHI->addIncoming(MulOverflow, OverflowBB);
6570+
OverflowFlagPHI->addIncoming(OverflowFlag, OverflowBB);
6571+
6572+
// Blocks were added and the CFG changed, so report the modification.
ModifiedDT = ModifyDT::ModifyBBDT;
6573+
return true;
6574+
}
6575+
63946576
/// If there are any memory operands, use OptimizeMemoryInst to sink their
63956577
/// address computing into the block when possible / profitable.
63966578
bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) {

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18851,6 +18851,15 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
1885118851
return (Index == 0 || Index == ResVT.getVectorMinNumElements());
1885218852
}
1885318853

18854+
/// Opt in to the mul-with-overflow rewrite only when \p VT has to be expanded
/// during integer type legalization and the integer type of half its size is
/// legal.
bool AArch64TargetLowering::shouldOptimizeMulOverflowWithZeroHighBits(
    LLVMContext &Context, EVT VT) const {
  const bool NeedsExpansion = getTypeAction(Context, VT) == TypeExpandInteger;
  if (NeedsExpansion) {
    // The half-width type is only built for types that are being expanded.
    const EVT HalfVT = EVT::getIntegerVT(Context, VT.getSizeInBits() / 2);
    return getTypeAction(Context, HalfVT) == TargetLowering::TypeLegal;
  }
  return false;
}
18862+
1885418863
/// Turn vector tests of the signbit in the form of:
1885518864
/// xor (sra X, elt_size(X)-1), -1
1885618865
/// into:

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,11 @@ class AArch64TargetLowering : public TargetLowering {
333333
return TargetLowering::shouldFormOverflowOp(Opcode, VT, true);
334334
}
335335

336+
// Return true if the target wants to optimize the mul overflow intrinsic
337+
// for the given \p VT. The AArch64 implementation returns true when \p VT is
// expanded by integer type legalization and the half-width integer type is
// legal.
338+
bool shouldOptimizeMulOverflowWithZeroHighBits(LLVMContext &Context,
339+
EVT VT) const override;
340+
336341
Value *emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, Value *Addr,
337342
AtomicOrdering Ord) const override;
338343
Value *emitStoreConditional(IRBuilderBase &Builder, Value *Val, Value *Addr,

0 commit comments

Comments
 (0)