Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -3470,6 +3470,13 @@ class LLVM_ABI TargetLoweringBase {
return MathUsed && (VT.isSimple() || !isOperationExpand(Opcode, VT));
}

// Return true if the target wants to optimize the mul overflow intrinsic
// for the given \p VT.
virtual bool shouldOptimizeMulOverflowIntrinsic(LLVMContext &Context,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we need a little more descriptive of a name - something like shouldOptimizeMulOverflowIntrinsicWithHighHalf maybe.

EVT VT) const {
return false;
}

// Return true if it is profitable to use a scalar input to a BUILD_VECTOR
// even if the vector itself has multiple uses.
virtual bool aggressivelyPreferBuildVectorSources(EVT VecVT) const {
Expand Down
303 changes: 303 additions & 0 deletions llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,8 @@ class CodeGenPrepare {
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy,
unsigned AddrSpace);
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
bool optimizeMulWithOverflow(Instruction *I, bool IsSigned,
ModifyDT &ModifiedDT);
bool optimizeInlineAsmInst(CallInst *CS);
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
bool optimizeExt(Instruction *&I);
Expand Down Expand Up @@ -2778,6 +2780,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
}
}
return false;
case Intrinsic::umul_with_overflow:
return optimizeMulWithOverflow(II, /*IsSigned=*/false, ModifiedDT);
case Intrinsic::smul_with_overflow:
return optimizeMulWithOverflow(II, /*IsSigned=*/true, ModifiedDT);
}

SmallVector<Value *, 2> PtrOps;
Expand Down Expand Up @@ -6389,6 +6395,303 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
return true;
}

// Rewrite the mul_with_overflow intrinsic by checking if both of the
// operands' value range is within the legal type. If so, we can optimize the
// multiplication algorithm. This code is supposed to be written during the step
// of type legalization, but given that we need to reconstruct the IR which is
// not doable there, we do it here.
bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
ModifyDT &ModifiedDT) {
if (!TLI->shouldOptimizeMulOverflowIntrinsic(
I->getContext(),
TLI->getValueType(*DL, I->getType()->getContainedType(0))))
return false;

Value *LHS = I->getOperand(0);
Value *RHS = I->getOperand(1);
Type *Ty = LHS->getType();
unsigned VTBitWidth = Ty->getScalarSizeInBits();
unsigned VTHalfBitWidth = VTBitWidth / 2;
IntegerType *LegalTy =
IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps Type *LegalTy = Ty->getWithNewBitWidth(VTHalfBitWidth);


// Skip the optimization if the type with HalfBitWidth is not legal for the
// target.
if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) !=
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be removed now?

TargetLowering::TypeLegal)
return false;

// Check the pattern we are interested in where there are maximum 2 uses
// of the intrinsic which are the extracts instructions.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extract instructions

if (I->getNumUses() > 2)
return false;
ExtractValueInst *MulExtract = nullptr;
ExtractValueInst *OverflowExtract = nullptr;
for (User *U : I->users()) {
auto *Extract = dyn_cast<ExtractValueInst>(U);
if (!Extract)
return false;

unsigned Index = Extract->getIndices()[0];
if (Index == 0)
MulExtract = Extract;
else if (Index == 1)
OverflowExtract = Extract;
}

// Keep track of the instruction to stop reoptimizing it again.
InsertedInsts.insert(I);
// ----------------------------
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this ---- line be removed?


// For the simple case where IR just checks the overflow flag, new blocks
// should be:
// entry:
// if signed:
// ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
// overflow_no
// else:
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
// overflow_no:
// overflow:

// otherwise, new blocks should be:
// entry:
// if signed:
// ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
// overflow_no
// else:
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
// overflow_no:
// overflow:
// overflow.res:

// New BBs:
std::string KeepBBName = I->getParent()->getName().str();
BasicBlock *OverflowEntryBB =
I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
// Remove the 'br' instruction that is generated as a result of the split
// as we are going to append new instructions.
OverflowEntryBB->getTerminator()->eraseFromParent();
BasicBlock *NoOverflowBB =
BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction());
NoOverflowBB->moveAfter(OverflowEntryBB);
BasicBlock *OverflowBB =
BasicBlock::Create(I->getContext(), "overflow", I->getFunction());
OverflowBB->moveAfter(NoOverflowBB);

// BB overflow.entry:
IRBuilder<> Builder(OverflowEntryBB);
// Get Lo and Hi of LHS & RHS:
Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");

Value *IsAnyBitTrue;
if (IsSigned) {
Value *SignLoLHS =
Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
Value *SignLoRHS =
Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or,
ConstantInt::getNullValue(Or->getType()));
} else {
Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
ConstantInt::getNullValue(LegalTy));
Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
ConstantInt::getNullValue(LegalTy));
IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
}
Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);

// BB overflow.no:
Builder.SetInsertPoint(NoOverflowBB);
Value *ExtLoLHS, *ExtLoRHS;
if (IsSigned) {
ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
} else {
ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
}

Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no");

// In overflow.no BB: we are sure that the overflow flag is false.
// So, when we find this pattern:
// br (extractvalue (%mul, 1)), label %if.then, label %if.end
// then we can jump directly to %if.end as we're sure that there is no
// overflow. This is checking the simple case where the exiting br of I's BB
// is the branch we are interested in.
BasicBlock *NoOverflowBrBB = nullptr;
if (auto *Br = dyn_cast<BranchInst>(I->getParent()->getTerminator())) {
// Check that the Br is testing the overflow bit:
if (Br->isConditional()) {
auto *ExtInstr = dyn_cast<ExtractValueInst>(Br->getOperand(0));
if (ExtInstr && ExtInstr->getIndices()[0] == 1)
NoOverflowBrBB = Br->getSuccessor(1) /*if.end*/;
}
}
if (NoOverflowBrBB) {
// Duplicate instructions from I's BB to the NoOverflowBB:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a limit on the number of instructions we should duplicate here? It could be quite a few. There are certain cannotDuplicate instructions to watch out for too.

ValueToValueMapTy VMap;
for (auto It = std::next(BasicBlock::iterator(I));
&*It != I->getParent()->getTerminator(); ++It) {
Instruction *OrigInst = &*It;
if (isa<DbgInfoIntrinsic>(OrigInst) || OrigInst == MulExtract ||
OrigInst == OverflowExtract)
continue;
Instruction *NewInst = nullptr;
NewInst = OrigInst->clone();
Builder.Insert(NewInst);
VMap[OrigInst] = NewInst;
RemapInstruction(NewInst, VMap, RF_IgnoreMissingLocals);
}
// Replace uses of MulExtract at the 'overflow.no' BB
if (MulExtract)
MulExtract->replaceUsesWithIf(Mul, [&](Use &U) {
return cast<Instruction>(U.getUser())->getParent() == NoOverflowBB;
});
if (OverflowExtract)
// Overflow flag is always false as we are sure it's not overflow.
OverflowExtract->replaceUsesWithIf(
ConstantInt::getFalse(I->getContext()), [&](Use &U) {
return cast<Instruction>(U.getUser())->getParent() == NoOverflowBB;
});
// BB overflow.no: jump directly to if.end BB
Builder.CreateBr(NoOverflowBrBB);

// Remove the original BB as it's divided into 'overflow.entry' and
// another BB where I exists.
BasicBlock *ToBeRemoveBB = I->getParent();
// BB overflow:
// Merge the original BB of I into the 'overflow' BB:
OverflowBB->splice(OverflowBB->end(), ToBeRemoveBB);

// Check if the Br BB has a PHI node and I->getParent() is one of
// its incoming BBs:
PHINode *PN = nullptr;
for (auto It = NoOverflowBrBB->begin(); It != NoOverflowBrBB->end(); ++It) {
if (!isa<PHINode>(&*It))
break;
PN = cast<PHINode>(&*It);
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
if (PN->getIncomingBlock(i) == ToBeRemoveBB) {
// Replace the old block by the new 'overflow' BB:
PN->setIncomingBlock(i, OverflowBB);
Value *IncomingValue = PN->getIncomingValue(i);
// Check if the incoming value is a constant, duplicate it.
if (isa<Constant>(IncomingValue)) {
PN->addIncoming(IncomingValue, NoOverflowBB);
continue;
}
// Check if this instruction was cloned to the 'overflow.no' BB:
Instruction *ClonedInstr =
cast_or_null<Instruction>(VMap.lookup(IncomingValue));
if (ClonedInstr) {
PN->addIncoming(ClonedInstr, NoOverflowBB);
continue;
} else if (isa<Instruction>(IncomingValue)) {
if (cast<Instruction>(IncomingValue) == MulExtract) {
PN->addIncoming(Mul, NoOverflowBB);
continue;
}
if (cast<Instruction>(IncomingValue) == OverflowExtract) {
PN->addIncoming(ConstantInt::getFalse(I->getContext()),
NoOverflowBB);
continue;
}
}
llvm_unreachable("Unexpected incoming value to PHI node");
}
}
}
if (!PN) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check that PN was the MulExtract or something? We could be checking multiple unrelated phis above.
And could OverflowExtract have other uses?

// Create PHI node to get the results of multiplication from 'overflow.no'
// and 'overflow' BBs:
if (MulExtract) {
Builder.SetInsertPoint(NoOverflowBrBB,
NoOverflowBrBB->getFirstInsertionPt());
PN = Builder.CreatePHI(Ty, 2);
PN->addIncoming(Mul, NoOverflowBB);
if (MulExtract->getParent() == OverflowBB) {
// Replace all uses of MulExtract out of 'Overflow' BB
MulExtract->replaceUsesWithIf(PN, [&](Use &U) {
return cast<Instruction>(U.getUser())->getParent() != OverflowBB;
});
PN->addIncoming(MulExtract, OverflowBB);
} else {
Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
Value *IntrinsicMulRes =
Builder.CreateExtractValue(I, {0}, "mul.extract");
cast<Instruction>(IntrinsicMulRes)->moveAfter(I);
PN->addIncoming(IntrinsicMulRes, OverflowBB);
MulExtract->replaceAllUsesWith(PN);
MulExtract->eraseFromParent();
}
}
}

ToBeRemoveBB->eraseFromParent();
// Restore the original name of the overflow.entry BB:
OverflowEntryBB->setName(KeepBBName);

ModifiedDT = ModifyDT::ModifyBBDT;
return true;
}

// Otherwise, we need to create the 'overflow.res' BB to merge the results of
// the two paths:
I->getParent()->setName("overflow.res");
BasicBlock *OverflowResBB = I->getParent();

// BB overflow.no: jump to overflow.res BB
Builder.CreateBr(OverflowResBB);

// BB overflow.res:
Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
PHINode *OverflowResPHI = Builder.CreatePHI(Ty, 2),
*OverflowFlagPHI =
Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);

OverflowResPHI->addIncoming(Mul, NoOverflowBB);
OverflowFlagPHI->addIncoming(ConstantInt::getFalse(I->getContext()),
NoOverflowBB);

// Replace all users of MulExtract and OverflowExtract to use the PHI nodes.
if (MulExtract) {
MulExtract->replaceAllUsesWith(OverflowResPHI);
MulExtract->eraseFromParent();
}
if (OverflowExtract) {
OverflowExtract->replaceAllUsesWith(OverflowFlagPHI);
OverflowExtract->eraseFromParent();
}
I->removeFromParent();

// BB overflow:
I->insertInto(OverflowBB, OverflowBB->end());
Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
Value *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
Value *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
Builder.CreateBr(OverflowResBB);

// Add The Extracted values to the PHINodes in the overflow.res block.
OverflowResPHI->addIncoming(MulOverflow, OverflowBB);
OverflowFlagPHI->addIncoming(OverflowFlag, OverflowBB);

// Restore the original name of the overflow.entry BB:
OverflowEntryBB->setName(KeepBBName);

ModifiedDT = ModifyDT::ModifyBBDT;
return true;
}

/// If there are any memory operands, use OptimizeMemoryInst to sink their
/// address computing into the block when possible / profitable.
bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) {
Expand Down
7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,13 @@ class AArch64TargetLowering : public TargetLowering {
return TargetLowering::shouldFormOverflowOp(Opcode, VT, true);
}

// Return true if the target wants to optimize the mul overflow intrinsic
// for the given \p VT.
bool shouldOptimizeMulOverflowIntrinsic(LLVMContext &Context,
EVT VT) const override {
return getTypeAction(Context, VT) == TypeExpandInteger;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this just mean VT == MVT::i64? Does it apply to vector types? It might be simpler to be explicit.

}

Value *emitLoadLinked(IRBuilderBase &Builder, Type *ValueTy, Value *Addr,
AtomicOrdering Ord) const override;
Value *emitStoreConditional(IRBuilderBase &Builder, Value *Val, Value *Addr,
Expand Down
Loading