-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[CGP]: Optimize mul.overflow. #148343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[CGP]: Optimize mul.overflow. #148343
Changes from all commits
b4ac552
f8b60eb
cd2815f
cd23298
7d1df2f
9531133
6ecfd1f
d878724
0610edf
9fa4927
08c498c
8607d5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -431,6 +431,8 @@ class CodeGenPrepare { | |
bool optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, Type *AccessTy, | ||
unsigned AddrSpace); | ||
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr); | ||
bool optimizeMulWithOverflow(Instruction *I, bool IsSigned, | ||
ModifyDT &ModifiedDT); | ||
bool optimizeInlineAsmInst(CallInst *CS); | ||
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT); | ||
bool optimizeExt(Instruction *&I); | ||
|
@@ -2778,6 +2780,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) { | |
} | ||
} | ||
return false; | ||
case Intrinsic::umul_with_overflow: | ||
return optimizeMulWithOverflow(II, /*IsSigned=*/false, ModifiedDT); | ||
case Intrinsic::smul_with_overflow: | ||
return optimizeMulWithOverflow(II, /*IsSigned=*/true, ModifiedDT); | ||
} | ||
|
||
SmallVector<Value *, 2> PtrOps; | ||
|
@@ -6389,6 +6395,235 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst, | |
return true; | ||
} | ||
|
||
// Rewrite the mul_with_overflow intrinsic by checking if both of the | ||
// operands' value range is within the legal type. If so, we can optimize the | ||
// multiplication algorithm. This code is supposed to be written during the step | ||
// of type legalization, but given that we need to reconstruct the IR which is | ||
// not doable there, we do it here. | ||
bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned, | ||
ModifyDT &ModifiedDT) { | ||
if (!TLI->shouldOptimizeMulOverflowIntrinsic()) | ||
return false; | ||
|
||
if (TLI->getTypeAction( | ||
I->getContext(), | ||
TLI->getValueType(*DL, I->getType()->getContainedType(0))) != | ||
TargetLowering::TypeExpandInteger) | ||
return false; | ||
|
||
Value *LHS = I->getOperand(0); | ||
Value *RHS = I->getOperand(1); | ||
Type *Ty = LHS->getType(); | ||
unsigned VTBitWidth = Ty->getScalarSizeInBits(); | ||
unsigned VTHalfBitWidth = VTBitWidth / 2; | ||
IntegerType *LegalTy = | ||
IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth); | ||
|
||
// Skip the optimization if the type with HalfBitWidth is not legal for the | ||
// target. | ||
if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) != | ||
TargetLowering::TypeLegal) | ||
return false; | ||
|
||
// Make sure that the I->getType() is a struct type with two elements. | ||
if (!I->getType()->isStructTy() || I->getType()->getStructNumElements() != 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should always be true? |
||
return false; | ||
|
||
// Keep track of the instruction to stop reoptimizing it again. | ||
InsertedInsts.insert(I); | ||
// ---------------------------- | ||
|
||
// For the simple case where IR just checks the overflow flag, new blocks | ||
// should be: | ||
// entry: | ||
// if signed: | ||
// (lhs_lo ^ lhs_hi) || (rhs_lo ^ rhs_hi) ? overflow, overflow_no | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does there needs to be a |
||
// else: | ||
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no | ||
// overflow_no: | ||
// overflow: | ||
|
||
// otherwise, new blocks should be: | ||
// entry: | ||
// if signed: | ||
// (lhs_lo ^ lhs_hi) || (rhs_lo ^ rhs_hi) ? overflow, overflow_no | ||
// else: | ||
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no | ||
// overflow_no: | ||
// overflow: | ||
// overflow.res: | ||
|
||
// New BBs: | ||
std::string KeepBBName = I->getParent()->getName().str(); | ||
BasicBlock *OverflowEntryBB = | ||
I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true); | ||
// Remove the 'br' instruction that is generated as a result of the split: | ||
OverflowEntryBB->getTerminator()->eraseFromParent(); | ||
BasicBlock *NoOverflowBB = | ||
BasicBlock::Create(I->getContext(), "overflow.no", I->getFunction()); | ||
NoOverflowBB->moveAfter(OverflowEntryBB); | ||
BasicBlock *OverflowBB = | ||
BasicBlock::Create(I->getContext(), "overflow", I->getFunction()); | ||
OverflowBB->moveAfter(NoOverflowBB); | ||
|
||
// BB overflow.entry: | ||
IRBuilder<> Builder(OverflowEntryBB); | ||
// Get Lo and Hi of LHS & RHS: | ||
Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs"); | ||
Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr"); | ||
HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs"); | ||
Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs"); | ||
Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr"); | ||
HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs"); | ||
|
||
Value *IsAnyBitTrue; | ||
if (IsSigned) { | ||
Value *SignLoLHS = | ||
Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs"); | ||
Value *SignLoRHS = | ||
Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs"); | ||
Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS); | ||
Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS); | ||
Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs"); | ||
IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_NE, Or, | ||
ConstantInt::getNullValue(Or->getType())); | ||
} else { | ||
Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS, | ||
ConstantInt::getNullValue(LegalTy)); | ||
Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS, | ||
ConstantInt::getNullValue(LegalTy)); | ||
IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs"); | ||
} | ||
Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB); | ||
|
||
// BB overflow.no: | ||
Builder.SetInsertPoint(NoOverflowBB); | ||
Value *ExtLoLHS, *ExtLoRHS; | ||
if (IsSigned) { | ||
ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext"); | ||
ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext"); | ||
} else { | ||
ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext"); | ||
ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext"); | ||
} | ||
|
||
Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.overflow.no"); | ||
|
||
// In overflow.no BB: we are sure that the overflow flag is false. | ||
// So, if we found this pattern: | ||
// br (extractvalue (%mul, 1)), label %if.then, label %if.end | ||
// then we can jump directly to %if.end as we're sure that there is no | ||
// overflow. | ||
BasicBlock *DetectNoOverflowBrBB = nullptr; | ||
StructType *STy = StructType::get( | ||
I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())}); | ||
// Look for the pattern in the users of I, and make sure that all the users | ||
// are either part of the pattern or NOT in the same BB as I. | ||
for (User *U : I->users()) { | ||
if (auto *Instr = dyn_cast<Instruction>(U); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be worth grabbing the 2 users of the MulWithOverflow (they should be the two extracts, if there are more bail out). Then we can replace the first with the result of the mul (through the new phi), and the overflow with the second part. That way we do not need to create the InsertValue. The case where the overflow bit is used in a branch needs to make sure there are no other instruction between the mul and the branch (other than debug instructions). Otherwise they might need to be duplicated. A User of an instruction will always be an Instruction. |
||
Instr && Instr->getParent() != I->getParent()) | ||
continue; | ||
|
||
if (auto *ExtUser = dyn_cast<ExtractValueInst>(U)) { | ||
if (ExtUser->hasOneUse() && ExtUser->getNumIndices() == 1 && | ||
ExtUser->getIndices()[0] == 1) { | ||
if (auto *Br = dyn_cast<BranchInst>(*ExtUser->user_begin())) { | ||
DetectNoOverflowBrBB = Br->getSuccessor(1) /*if.end*/; | ||
continue; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. continue->break, if we found what we are interested in? |
||
} | ||
} | ||
} | ||
// If we come here, it means that either the pattern doesn't exist or | ||
// there are multiple users in the same BB | ||
DetectNoOverflowBrBB = nullptr; | ||
break; | ||
Comment on lines
+6536
to
+6539
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think this is needed, and we can look through all the users. There should only be 2 but we might need to look at the second. |
||
} | ||
if (DetectNoOverflowBrBB) { | ||
// BB overflow.no: jump directly to if.end BB | ||
Builder.CreateBr(DetectNoOverflowBrBB); | ||
|
||
// BB if.end: | ||
Builder.SetInsertPoint(DetectNoOverflowBrBB, | ||
DetectNoOverflowBrBB->getFirstInsertionPt()); | ||
// Create PHI node to get the results of multiplication from 'overflow.no' | ||
// and 'overflow' BBs | ||
PHINode *NoOverflowPHI = Builder.CreatePHI(Ty, 2); | ||
NoOverflowPHI->addIncoming(Mul, NoOverflowBB); | ||
// Create struct value to replace all uses of I | ||
Value *StructValNoOverflow = PoisonValue::get(STy); | ||
StructValNoOverflow = | ||
Builder.CreateInsertValue(StructValNoOverflow, NoOverflowPHI, {0}); | ||
// Overflow flag is always false as we are sure it's not overflow. | ||
StructValNoOverflow = Builder.CreateInsertValue( | ||
StructValNoOverflow, ConstantInt::getFalse(I->getContext()), {1}); | ||
// Replace all uses of I, only uses dominated by the if.end BB | ||
I->replaceUsesOutsideBlock(StructValNoOverflow, I->getParent()); | ||
|
||
// Remove the original BB as it's divided into 'overflow.entry' and | ||
// 'overflow' BBs. | ||
BasicBlock *ToBeRemoveBB = I->getParent(); | ||
// BB overflow: | ||
OverflowBB->splice(OverflowBB->end(), ToBeRemoveBB); | ||
// Extract the multiplication result to add it to the PHI node in the if.end | ||
// BB | ||
Builder.SetInsertPoint(OverflowBB, OverflowBB->end()); | ||
Value *IntrinsicMulRes = Builder.CreateExtractValue(I, {0}, "mul.extract"); | ||
cast<Instruction>(IntrinsicMulRes)->moveAfter(I); | ||
NoOverflowPHI->addIncoming(IntrinsicMulRes, OverflowBB); | ||
|
||
ToBeRemoveBB->eraseFromParent(); | ||
// Restore the original name of the overflow.entry BB: | ||
OverflowEntryBB->setName(KeepBBName); | ||
ModifiedDT = ModifyDT::ModifyBBDT; | ||
return true; | ||
} | ||
|
||
// Otherwise, we need to create the 'overflow.res' BB to merge the results of | ||
// the two paths: | ||
I->getParent()->setName("overflow.res"); | ||
BasicBlock *OverflowResBB = I->getParent(); | ||
|
||
// BB overflow.no: jump to overflow.res BB | ||
Builder.CreateBr(OverflowResBB); | ||
|
||
// BB overflow.res: | ||
Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt()); | ||
PHINode *OverflowResPHI = Builder.CreatePHI(Ty, 2), | ||
*OverflowFlagPHI = | ||
Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2); | ||
|
||
Value *StructValOverflowRes = PoisonValue::get(STy); | ||
StructValOverflowRes = | ||
Builder.CreateInsertValue(StructValOverflowRes, OverflowResPHI, {0}); | ||
StructValOverflowRes = | ||
Builder.CreateInsertValue(StructValOverflowRes, OverflowFlagPHI, {1}); | ||
OverflowResPHI->addIncoming(Mul, NoOverflowBB); | ||
OverflowFlagPHI->addIncoming(ConstantInt::getFalse(I->getContext()), | ||
NoOverflowBB); | ||
|
||
// Before moving the mul.overflow intrinsic to the overflowBB, replace all its | ||
// uses by StructValOverflowRes. | ||
I->replaceAllUsesWith(StructValOverflowRes); | ||
I->removeFromParent(); | ||
|
||
// BB overflow: | ||
I->insertInto(OverflowBB, OverflowBB->end()); | ||
Builder.SetInsertPoint(OverflowBB, OverflowBB->end()); | ||
Value *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow"); | ||
Value *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag"); | ||
Builder.CreateBr(OverflowResBB); | ||
|
||
// Add The Extracted values to the PHINodes in the overflow.res block. | ||
OverflowResPHI->addIncoming(MulOverflow, OverflowBB); | ||
OverflowFlagPHI->addIncoming(OverflowFlag, OverflowBB); | ||
|
||
// Restore the original name of the overflow.entry BB: | ||
OverflowEntryBB->setName(KeepBBName); | ||
|
||
ModifiedDT = ModifyDT::ModifyBBDT; | ||
return true; | ||
} | ||
|
||
/// If there are any memory operands, use OptimizeMemoryInst to sink their | ||
/// address computing into the block when possible / profitable. | ||
bool CodeGenPrepare::optimizeInlineAsmInst(CallInst *CS) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would pass the type into the shouldOptimizeMulOverflowIntrinsic, and have the target return whether it should expand for the given type. Some of the legality checks bellow could then be removed from here.