Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3470,6 +3470,13 @@ class LLVM_ABI TargetLoweringBase {
return MathUsed && (VT.isSimple() || !isOperationExpand(Opcode, VT));
}

// Return true if the target wants to optimize the mul overflow intrinsic
// for the given \p VT.
virtual bool shouldOptimizeMulOverflowWithZeroHighBits(LLVMContext &Context,
EVT VT) const {
return false;
}

// Return true if it is profitable to use a scalar input to a BUILD_VECTOR
// even if the vector itself has multiple uses.
virtual bool aggressivelyPreferBuildVectorSources(EVT VecVT) const {
Expand Down
184 changes: 184 additions & 0 deletions llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
bool optimizeMulWithOverflow(Instruction *I, bool IsSigned,
ModifyDT &ModifiedDT);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
Expand Down Expand Up @@ -2778,6 +2780,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
}
}
return false;
case Intrinsic::umul_with_overflow:
return optimizeMulWithOverflow(II, /*IsSigned=*/false, ModifiedDT);
case Intrinsic::smul_with_overflow:
return optimizeMulWithOverflow(II, /*IsSigned=*/true, ModifiedDT);
}

SmallVector<Value *, 2> PtrOps;
Expand Down Expand Up @@ -6389,6 +6395,184 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}

// This is a helper for CodeGenPrepare::optimizeMulWithOverflow.
// Check the pattern we are interested in where there are maximum 2 uses
// of the intrinsic which are the extract instructions.
static bool matchOverflowPattern(Instruction *&I, ExtractValueInst *&MulExtract,
ExtractValueInst *&OverflowExtract) {
if (I->getNumUses() > 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hasNUsesOrMore

return false;

for (User *U : I->users()) {
auto *Extract = dyn_cast<ExtractValueInst>(U);
if (!Extract || Extract->getNumIndices() != 1)
return false;

unsigned Index = Extract->getIndices()[0];
if (Index == 0)
MulExtract = Extract;
else if (Index == 1)
OverflowExtract = Extract;
else
return false;
}
return true;
}

// Rewrite the mul_with_overflow intrinsic by checking if both of the
// operands' value ranges are within the legal type. If so, we can optimize the
// multiplication algorithm. This code is supposed to be written during the step
// of type legalization, but given that we need to reconstruct the IR which is
// not doable there, we do it here.
// The IR after the optimization will look like:
// entry:
// if signed:
// ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
// overflow_no
// else:
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
// overflow_no:
// overflow:
// overflow.res:
// \returns true if optimization was applied
// TODO: This optimization can be further improved to optimize branching on
// overflow where the 'overflow_no' BB can branch directly to the false
// successor of overflow, but that would add additional complexity so we leave
// it for future work.
bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
ModifyDT &ModifiedDT) {
// Check if target supports this optimization.
if (!TLI->shouldOptimizeMulOverflowWithZeroHighBits(
I->getContext(),
TLI->getValueType(*DL, I->getType()->getContainedType(0))))
return false;

ExtractValueInst *MulExtract = nullptr, *OverflowExtract = nullptr;
if (!matchOverflowPattern(I, MulExtract, OverflowExtract))
return false;

// Keep track of the instruction to stop reoptimizing it again.
InsertedInsts.insert(I);

Value *LHS = I->getOperand(0);
Value *RHS = I->getOperand(1);
Type *Ty = LHS->getType();
unsigned VTHalfBitWidth = Ty->getScalarSizeInBits() / 2;
Type *LegalTy = Ty->getWithNewBitWidth(VTHalfBitWidth);

// New BBs:
std::string OriginalBlockName = I->getParent()->getName().str();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to duplicate the string if it's only used in for OverflowEntryBB, created right below?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to cache the name because if I renamed the new block by the original name, it will be numbered -entry1-, because the name is already used.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

takeName?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't know this. I use it, thanks.

BasicBlock *OverflowEntryBB =
I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
// Keep the 'br' instruction that is generated as a result of the split to be
// erased/replaced later.
Instruction *OldTerminator = OverflowEntryBB->getTerminator();
BasicBlock *NoOverflowBB =
BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction());
NoOverflowBB->moveAfter(OverflowEntryBB);
BasicBlock *OverflowBB =
BasicBlock::Create(I->getContext(), "overflow", I->getFunction());
OverflowBB->moveAfter(NoOverflowBB);

// BB overflow.entry:
IRBuilder<> Builder(OverflowEntryBB);
// Extract low and high halves of LHS:
Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");

// Extract low and high halves of RHS:
Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");

Value *IsAnyBitTrue;
if (IsSigned) {
Value *SignLoLHS =
Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
Value *SignLoRHS =
Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or,
ConstantInt::getNullValue(Or->getType()));
} else {
Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
ConstantInt::getNullValue(LegalTy));
Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
ConstantInt::getNullValue(LegalTy));
IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
}
Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);

// BB overflow.no:
Builder.SetInsertPoint(NoOverflowBB);
Value *ExtLoLHS, *ExtLoRHS;
if (IsSigned) {
ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
} else {
ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
}

Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no");

// Create the 'overflow.res' BB to merge the results of
// the two paths:
BasicBlock *OverflowResBB = I->getParent();
OverflowResBB->setName("overflow.res");

// BB overflow.no: jump to overflow.res BB
Builder.CreateBr(OverflowResBB);
// No we don't need the old terminator in overflow.entry BB, erase it:
OldTerminator->eraseFromParent();

// BB overflow.res:
Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
// Create PHI nodes to merge results from no.overflow BB and overflow BB to
// replace the extract instructions.
PHINode *OverflowResPHI = Builder.CreatePHI(Ty, 2),
*OverflowFlagPHI =
Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);

// Add the incoming values from no.overflow BB and later from overflow BB.
OverflowResPHI->addIncoming(Mul, NoOverflowBB);
OverflowFlagPHI->addIncoming(ConstantInt::getFalse(I->getContext()),
NoOverflowBB);

// Replace all users of MulExtract and OverflowExtract to use the PHI nodes.
if (MulExtract) {
MulExtract->replaceAllUsesWith(OverflowResPHI);
MulExtract->eraseFromParent();
}
if (OverflowExtract) {
OverflowExtract->replaceAllUsesWith(OverflowFlagPHI);
OverflowExtract->eraseFromParent();
}

// Remove the intrinsic from parent (overflow.res BB) as it will be part of
// overflow BB
I->removeFromParent();
// BB overflow:
I->insertInto(OverflowBB, OverflowBB->end());
Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
Value *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
Value *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
Builder.CreateBr(OverflowResBB);

// Add The Extracted values to the PHINodes in the overflow.res BB.
OverflowResPHI->addIncoming(MulOverflow, OverflowBB);
OverflowFlagPHI->addIncoming(OverflowFlag, OverflowBB);

// Restore the original name of the overflow.entry BB:
OverflowEntryBB->setName(OriginalBlockName);

ModifiedDT = ModifyDT::ModifyBBDT;
return true;
}

/// If there are any memory operands, use OptimizeMemoryInst to sink their
/// address computing into the block when possible / profitable.
bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) {
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18542,6 +18542,15 @@ bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
return (Index == 0 || Index == ResVT.getVectorMinNumElements());
}

bool AArch64TargetLowering::shouldOptimizeMulOverflowWithZeroHighBits(
LLVMContext &Context, EVT VT) const {
if (getTypeAction(Context, VT) != TypeExpandInteger)
return false;

EVT LegalTy = EVT::getIntegerVT(Context, VT.getSizeInBits() / 2);
return getTypeAction(Context, LegalTy) == TargetLowering::TypeLegal;
}

/// Turn vector tests of the signbit in the form of:
/// xor (sra X, elt_size(X)-1), -1
/// into:
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,11 @@ class AArch64TargetLowering : public TargetLowering {
return TargetLowering::shouldFormOverflowOp(Opcode, VT, true);
}

// Return true if the target wants to optimize the mul overflow intrinsic
// for the given \p VT.
bool shouldOptimizeMulOverflowWithZeroHighBits(LLVMContext &Context,
EVT VT) const override;

Value *emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, Value *Addr,
AtomicOrdering Ord) const override;
Value *emitStoreConditional(IRBuilderBase &Builder, Value *Val, Value *Addr,
Expand Down
Loading
Loading