Skip to content

Commit 9531133

Browse files
committed
resolve review comments
1 parent 7d1df2f commit 9531133

File tree

1 file changed

+70
-147
lines changed

1 file changed

+70
-147
lines changed

llvm/lib/CodeGen/CodeGenPrepare.cpp

Lines changed: 70 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,10 @@ class CodeGenPrepare {
338338
/// Keep track of instructions removed during promotion.
339339
SetOfInstrs RemovedInsts;
340340

341+
/// Keep track of seen mul_with_overflow intrinsics to avoid
342+
// reprocessing them.
343+
DenseMap<Instruction *, bool> SeenMulWithOverflowInstrs;
344+
341345
/// Keep track of sext chains based on their initial value.
342346
DenseMap<Value *, Instruction *> SeenChainsForSExt;
343347

@@ -433,6 +437,8 @@ class CodeGenPrepare {
433437
bool optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr);
434438
bool optimizeUMulWithOverflow(Instruction *I);
435439
bool optimizeSMulWithOverflow(Instruction *I);
440+
bool optimizeMulWithOverflow(Instruction *I, bool IsSigned,
441+
ModifyDT &ModifiedDT);
436442
bool optimizeInlineAsmInst(CallInst *CS);
437443
bool optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT);
438444
bool optimizeExt(Instruction *&I);
@@ -774,6 +780,7 @@ bool CodeGenPrepare::_run(Function &F) {
774780
verifyBFIUpdates(F);
775781
#endif
776782

783+
SeenMulWithOverflowInstrs.clear();
777784
return EverMadeChange;
778785
}
779786

@@ -2781,9 +2788,9 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
27812788
}
27822789
return false;
27832790
case Intrinsic::umul_with_overflow:
2784-
return optimizeUMulWithOverflow(II);
2791+
return optimizeMulWithOverflow(II, /*IsSigned=*/false, ModifiedDT);
27852792
case Intrinsic::smul_with_overflow:
2786-
return optimizeSMulWithOverflow(II);
2793+
return optimizeMulWithOverflow(II, /*IsSigned=*/true, ModifiedDT);
27872794
}
27882795

27892796
SmallVector<Value *, 2> PtrOps;
@@ -6395,122 +6402,20 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
63956402
return true;
63966403
}
63976404

6398-
// Rewrite the umul_with_overflow intrinsic by checking if both of the
6405+
// Rewrite the mul_with_overflow intrinsic by checking if both of the
63996406
// operands' value range is within the legal type. If so, we can optimize the
64006407
// multiplication algorithm. This code is supposed to be written during the step
64016408
// of type legalization, but given that we need to reconstruct the IR which is
64026409
// not doable there, we do it here.
6403-
bool CodeGenPrepare::optimizeUMulWithOverflow(Instruction *I) {
6410+
bool CodeGenPrepare::optimizeMulWithOverflow(Instruction *I, bool IsSigned,
6411+
ModifyDT &ModifiedDT) {
64046412
// Enable this optimization only for aarch64.
64056413
if (!TLI->getTargetMachine().getTargetTriple().isAArch64())
64066414
return false;
6407-
if (TLI->getTypeAction(
6408-
I->getContext(),
6409-
TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
6410-
TargetLowering::TypeExpandInteger)
6411-
return false;
6412-
6413-
Value *LHS = I->getOperand(0);
6414-
Value *RHS = I->getOperand(1);
6415-
auto *Ty = LHS->getType();
6416-
unsigned VTBitWidth = Ty->getScalarSizeInBits();
6417-
unsigned VTHalfBitWidth = VTBitWidth / 2;
6418-
auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
6419-
6420-
// Skip the optimization if the type with HalfBitWidth is not legal for the
6421-
// target.
6422-
if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) !=
6423-
TargetLowering::TypeLegal)
6415+
// If we have already seen this instruction, don't process it again.
6416+
if (!SeenMulWithOverflowInstrs.insert(std::make_pair(I, true)).second)
64246417
return false;
64256418

6426-
I->getParent()->setName("overflow.res");
6427-
auto *OverflowResBB = I->getParent();
6428-
auto *OverflowoEntryBB =
6429-
I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
6430-
BasicBlock *NoOverflowBB = BasicBlock::Create(
6431-
I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
6432-
BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
6433-
I->getFunction(), OverflowResBB);
6434-
// new blocks should be:
6435-
// entry:
6436-
// (lhs_lo ne lhs_hi) || (rhs_lo ne rhs_hi) ? overflow, overflow_no
6437-
6438-
// overflow_no:
6439-
// overflow:
6440-
// overflow.res:
6441-
//------------------------------------------------------------------------------
6442-
// BB overflow.entry:
6443-
// get Lo and Hi of RHS & LHS:
6444-
IRBuilder<> Builder(OverflowoEntryBB->getTerminator());
6445-
auto *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs.trunc");
6446-
auto *ShrHiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
6447-
auto *HiRHS = Builder.CreateTrunc(ShrHiRHS, LegalTy, "hi.rhs.trunc");
6448-
6449-
auto *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs.trunc");
6450-
auto *ShrHiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
6451-
auto *HiLHS = Builder.CreateTrunc(ShrHiLHS, LegalTy, "hi.lhs.trunc");
6452-
6453-
auto *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
6454-
ConstantInt::getNullValue(LegalTy));
6455-
auto *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
6456-
ConstantInt::getNullValue(LegalTy));
6457-
auto *Or = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
6458-
Builder.CreateCondBr(Or, OverflowBB, NoOverflowBB);
6459-
OverflowoEntryBB->getTerminator()->eraseFromParent();
6460-
6461-
//------------------------------------------------------------------------------
6462-
// BB overflow.no:
6463-
Builder.SetInsertPoint(NoOverflowBB);
6464-
auto *ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
6465-
auto *ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
6466-
auto *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.no.overflow");
6467-
Builder.CreateBr(OverflowResBB);
6468-
6469-
//------------------------------------------------------------------------------
6470-
// BB overflow.res:
6471-
Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
6472-
auto *PHINode1 = Builder.CreatePHI(Ty, 2);
6473-
PHINode1->addIncoming(Mul, NoOverflowBB);
6474-
auto *PHINode2 =
6475-
Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);
6476-
PHINode2->addIncoming(ConstantInt::getFalse(I->getContext()), NoOverflowBB);
6477-
6478-
StructType *STy = StructType::get(
6479-
I->getContext(), {Ty, IntegerType::getInt1Ty(I->getContext())});
6480-
Value *StructValOverflowRes = PoisonValue::get(STy);
6481-
StructValOverflowRes =
6482-
Builder.CreateInsertValue(StructValOverflowRes, PHINode1, {0});
6483-
StructValOverflowRes =
6484-
Builder.CreateInsertValue(StructValOverflowRes, PHINode2, {1});
6485-
// Before moving the mul.overflow intrinsic to the overflowBB, replace all its
6486-
// uses by StructValOverflowRes.
6487-
I->replaceAllUsesWith(StructValOverflowRes);
6488-
I->removeFromParent();
6489-
6490-
// BB overflow:
6491-
I->insertInto(OverflowBB, OverflowBB->end());
6492-
Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
6493-
auto *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
6494-
auto *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
6495-
Builder.CreateBr(OverflowResBB);
6496-
6497-
// Add The Extracted values to the PHINodes in the overflow.res block.
6498-
PHINode1->addIncoming(MulOverflow, OverflowBB);
6499-
PHINode2->addIncoming(OverflowFlag, OverflowBB);
6500-
6501-
// return false to stop reprocessing the function.
6502-
return false;
6503-
}
6504-
6505-
// Rewrite the smul_with_overflow intrinsic by checking if both of the
6506-
// operands' value range is within the legal type. If so, we can optimize the
6507-
// multiplication algorithm. This code is supposed to be written during the step
6508-
// of type legalization, but given that we need to reconstruct the IR which is
6509-
// not doable there, we do it here.
6510-
bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
6511-
// Enable this optimization only for aarch64.
6512-
if (!TLI->getTargetMachine().getTargetTriple().isAArch64())
6513-
return false;
65146419
if (TLI->getTypeAction(
65156420
I->getContext(),
65166421
TLI->getValueType(*DL, I->getType()->getContainedType(0))) !=
@@ -6519,75 +6424,93 @@ bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
65196424

65206425
Value *LHS = I->getOperand(0);
65216426
Value *RHS = I->getOperand(1);
6522-
auto *Ty = LHS->getType();
6427+
Type *Ty = LHS->getType();
65236428
unsigned VTBitWidth = Ty->getScalarSizeInBits();
65246429
unsigned VTHalfBitWidth = VTBitWidth / 2;
6525-
auto *LegalTy = IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
6430+
IntegerType *LegalTy =
6431+
IntegerType::getIntNTy(I->getContext(), VTHalfBitWidth);
65266432

65276433
// Skip the optimization if the type with HalfBitWidth is not legal for the
65286434
// target.
65296435
if (TLI->getTypeAction(I->getContext(), TLI->getValueType(*DL, LegalTy)) !=
65306436
TargetLowering::TypeLegal)
65316437
return false;
65326438

6439+
// Make sure that the I->getType() is a struct type with two elements.
6440+
if (!I->getType()->isStructTy() || I->getType()->getStructNumElements() != 2)
6441+
return false;
6442+
65336443
I->getParent()->setName("overflow.res");
6534-
auto *OverflowResBB = I->getParent();
6535-
auto *OverflowoEntryBB =
6444+
BasicBlock *OverflowResBB = I->getParent();
6445+
BasicBlock *OverflowoEntryBB =
65366446
I->getParent()->splitBasicBlock(I, "overflow.entry", /*Before*/ true);
65376447
BasicBlock *NoOverflowBB = BasicBlock::Create(
65386448
I->getContext(), "overflow.no", I->getFunction(), OverflowResBB);
65396449
BasicBlock *OverflowBB = BasicBlock::Create(I->getContext(), "overflow",
65406450
I->getFunction(), OverflowResBB);
65416451
// new blocks should be:
65426452
// entry:
6543-
// (lhs_lo ne lhs_hi) || (rhs_lo ne rhs_hi) ? overflow, overflow_no
6453+
// if signed:
6454+
// (lhs_lo ^ lhs_hi) || (rhs_lo ^ rhs_hi) ? overflow, overflow_no
6455+
// else:
6456+
// (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
65446457

65456458
// overflow_no:
65466459
// overflow:
65476460
// overflow.res:
65486461

6549-
//------------------------------------------------------------------------------
6462+
// ----------------------------
65506463
// BB overflow.entry:
6551-
// get Lo and Hi of RHS & LHS:
6464+
// get Lo and Hi of LHS & RHS:
65526465
IRBuilder<> Builder(OverflowoEntryBB->getTerminator());
6553-
auto *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
6554-
auto *SignLoRHS =
6555-
Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
6556-
auto *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
6466+
Value *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
6467+
Value *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
6468+
HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
6469+
Value *LoRHS = Builder.CreateTrunc(RHS, LegalTy, "lo.rhs");
6470+
Value *HiRHS = Builder.CreateLShr(RHS, VTHalfBitWidth, "rhs.lsr");
65576471
HiRHS = Builder.CreateTrunc(HiRHS, LegalTy, "hi.rhs");
65586472

6559-
auto *LoLHS = Builder.CreateTrunc(LHS, LegalTy, "lo.lhs");
6560-
auto *SignLoLHS =
6561-
Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
6562-
auto *HiLHS = Builder.CreateLShr(LHS, VTHalfBitWidth, "lhs.lsr");
6563-
HiLHS = Builder.CreateTrunc(HiLHS, LegalTy, "hi.lhs");
6564-
// xor(HiLHS, SignLoLHS) false -> no overflow
6565-
// xor(HiRHS, SignLoRHS) false -> no overflow
6566-
// if either of the above is true, then overflow.
6567-
// auto *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
6568-
auto *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
6569-
auto *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
6570-
// auto *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
6571-
auto *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
6572-
auto *Cmp = Builder.CreateCmp(ICmpInst::ICMP_EQ, Or,
6573-
ConstantInt::get(Or->getType(), 1));
6574-
Builder.CreateCondBr(Cmp, OverflowBB, NoOverflowBB);
6473+
Value *IsAnyBitTrue;
6474+
if (IsSigned) {
6475+
Value *SignLoLHS =
6476+
Builder.CreateAShr(LoLHS, VTHalfBitWidth - 1, "sign.lo.lhs");
6477+
Value *SignLoRHS =
6478+
Builder.CreateAShr(LoRHS, VTHalfBitWidth - 1, "sign.lo.rhs");
6479+
Value *XorLHS = Builder.CreateXor(HiLHS, SignLoLHS);
6480+
Value *XorRHS = Builder.CreateXor(HiRHS, SignLoRHS);
6481+
Value *Or = Builder.CreateOr(XorLHS, XorRHS, "or.lhs.rhs");
6482+
IsAnyBitTrue = Builder.CreateCmp(ICmpInst::ICMP_EQ, Or,
6483+
ConstantInt::get(Or->getType(), 1));
6484+
} else {
6485+
Value *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS,
6486+
ConstantInt::getNullValue(LegalTy));
6487+
Value *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS,
6488+
ConstantInt::getNullValue(LegalTy));
6489+
IsAnyBitTrue = Builder.CreateOr(CmpLHS, CmpRHS, "or.lhs.rhs");
6490+
}
6491+
6492+
Builder.CreateCondBr(IsAnyBitTrue, OverflowBB, NoOverflowBB);
65756493
OverflowoEntryBB->getTerminator()->eraseFromParent();
65766494

6577-
//------------------------------------------------------------------------------
65786495
// BB overflow.no:
65796496
Builder.SetInsertPoint(NoOverflowBB);
6580-
auto *ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
6581-
auto *ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
6582-
auto *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.no.overflow");
6497+
Value *ExtLoLHS, *ExtLoRHS;
6498+
if (IsSigned) {
6499+
ExtLoLHS = Builder.CreateSExt(LoLHS, Ty, "lo.lhs.ext");
6500+
ExtLoRHS = Builder.CreateSExt(LoRHS, Ty, "lo.rhs.ext");
6501+
} else {
6502+
ExtLoLHS = Builder.CreateZExt(LoLHS, Ty, "lo.lhs.ext");
6503+
ExtLoRHS = Builder.CreateZExt(LoRHS, Ty, "lo.rhs.ext");
6504+
}
6505+
6506+
Value *Mul = Builder.CreateMul(ExtLoLHS, ExtLoRHS, "mul.no.overflow");
65836507
Builder.CreateBr(OverflowResBB);
65846508

6585-
//------------------------------------------------------------------------------
65866509
// BB overflow.res:
65876510
Builder.SetInsertPoint(OverflowResBB, OverflowResBB->getFirstInsertionPt());
6588-
auto *PHINode1 = Builder.CreatePHI(Ty, 2);
6511+
PHINode *PHINode1 = Builder.CreatePHI(Ty, 2);
65896512
PHINode1->addIncoming(Mul, NoOverflowBB);
6590-
auto *PHINode2 =
6513+
PHINode *PHINode2 =
65916514
Builder.CreatePHI(IntegerType::getInt1Ty(I->getContext()), 2);
65926515
PHINode2->addIncoming(ConstantInt::getFalse(I->getContext()), NoOverflowBB);
65936516

@@ -6606,16 +6529,16 @@ bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
66066529
// BB overflow:
66076530
I->insertInto(OverflowBB, OverflowBB->end());
66086531
Builder.SetInsertPoint(OverflowBB, OverflowBB->end());
6609-
auto *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
6610-
auto *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
6532+
Value *MulOverflow = Builder.CreateExtractValue(I, {0}, "mul.overflow");
6533+
Value *OverflowFlag = Builder.CreateExtractValue(I, {1}, "overflow.flag");
66116534
Builder.CreateBr(OverflowResBB);
66126535

66136536
// Add The Extracted values to the PHINodes in the overflow.res block.
66146537
PHINode1->addIncoming(MulOverflow, OverflowBB);
66156538
PHINode2->addIncoming(OverflowFlag, OverflowBB);
66166539

6617-
// return false to stop reprocessing the function.
6618-
return false;
6540+
ModifiedDT = ModifyDT::ModifyBBDT;
6541+
return true;
66196542
}
66206543

66216544
/// If there are any memory operands, use OptimizeMemoryInst to sink their

0 commit comments

Comments
 (0)