@@ -431,6 +431,8 @@ class CodeGenPrepare {
431431 bool optimizeMemoryInst (Instruction *MemoryInst, Value *Addr, Type *AccessTy,
432432 unsigned AddrSpace);
433433 bool optimizeGatherScatterInst (Instruction *MemoryInst, Value *Ptr);
434+ bool optimizeMulWithOverflow (Instruction *I, bool IsSigned,
435+ ModifyDT &ModifiedDT);
434436 bool optimizeInlineAsmInst (CallInst *CS);
435437 bool optimizeCallInst (CallInst *CI, ModifyDT &ModifiedDT);
436438 bool optimizeExt (Instruction *&I);
@@ -2797,6 +2799,10 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
27972799 }
27982800 }
27992801 return false ;
2802+ case Intrinsic::umul_with_overflow:
2803+ return optimizeMulWithOverflow (II, /* IsSigned=*/ false , ModifiedDT);
2804+ case Intrinsic::smul_with_overflow:
2805+ return optimizeMulWithOverflow (II, /* IsSigned=*/ true , ModifiedDT);
28002806 }
28012807
28022808 SmallVector<Value *, 2 > PtrOps;
@@ -6391,6 +6397,182 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
63916397 return true ;
63926398}
63936399
6400+ // This is a helper for CodeGenPrepare::optimizeMulWithOverflow.
6401+ // Check the pattern we are interested in where there are maximum 2 uses
6402+ // of the intrinsic which are the extract instructions.
6403+ static bool matchOverflowPattern (Instruction *&I, ExtractValueInst *&MulExtract,
6404+ ExtractValueInst *&OverflowExtract) {
6405+ // Bail out if it's more than 2 users:
6406+ if (I->hasNUsesOrMore (3 ))
6407+ return false ;
6408+
6409+ for (User *U : I->users ()) {
6410+ auto *Extract = dyn_cast<ExtractValueInst>(U);
6411+ if (!Extract || Extract->getNumIndices () != 1 )
6412+ return false ;
6413+
6414+ unsigned Index = Extract->getIndices ()[0 ];
6415+ if (Index == 0 )
6416+ MulExtract = Extract;
6417+ else if (Index == 1 )
6418+ OverflowExtract = Extract;
6419+ else
6420+ return false ;
6421+ }
6422+ return true ;
6423+ }
6424+
6425+ // Rewrite the mul_with_overflow intrinsic by checking if both of the
6426+ // operands' value ranges are within the legal type. If so, we can optimize the
6427+ // multiplication algorithm. This code is supposed to be written during the step
6428+ // of type legalization, but given that we need to reconstruct the IR which is
6429+ // not doable there, we do it here.
6430+ // The IR after the optimization will look like:
6431+ // entry:
6432+ // if signed:
6433+ // ( (lhs_lo>>BW-1) ^ lhs_hi) || ( (rhs_lo>>BW-1) ^ rhs_hi) ? overflow,
6434+ // overflow_no
6435+ // else:
6436+ // (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
6437+ // overflow_no:
6438+ // overflow:
6439+ // overflow.res:
6440+ // \returns true if optimization was applied
6441+ // TODO: This optimization can be further improved to optimize branching on
6442+ // overflow where the 'overflow_no' BB can branch directly to the false
6443+ // successor of overflow, but that would add additional complexity so we leave
6444+ // it for future work.
6445+ bool CodeGenPrepare::optimizeMulWithOverflow (Instruction *I, bool IsSigned,
6446+ ModifyDT &ModifiedDT) {
6447+ // Check if target supports this optimization.
6448+ if (!TLI->shouldOptimizeMulOverflowWithZeroHighBits (
6449+ I->getContext (),
6450+ TLI->getValueType (*DL, I->getType ()->getContainedType (0 ))))
6451+ return false ;
6452+
6453+ ExtractValueInst *MulExtract = nullptr , *OverflowExtract = nullptr ;
6454+ if (!matchOverflowPattern (I, MulExtract, OverflowExtract))
6455+ return false ;
6456+
6457+ // Keep track of the instruction to stop reoptimizing it again.
6458+ InsertedInsts.insert (I);
6459+
6460+ Value *LHS = I->getOperand (0 );
6461+ Value *RHS = I->getOperand (1 );
6462+ Type *Ty = LHS->getType ();
6463+ unsigned VTHalfBitWidth = Ty->getScalarSizeInBits () / 2 ;
6464+ Type *LegalTy = Ty->getWithNewBitWidth (VTHalfBitWidth);
6465+
6466+ // New BBs:
6467+ BasicBlock *OverflowEntryBB =
6468+ I->getParent ()->splitBasicBlock (I, " " , /* Before*/ true );
6469+ OverflowEntryBB->takeName (I->getParent ());
6470+ // Keep the 'br' instruction that is generated as a result of the split to be
6471+ // erased/replaced later.
6472+ Instruction *OldTerminator = OverflowEntryBB->getTerminator ();
6473+ BasicBlock *NoOverflowBB =
6474+ BasicBlock::Create (I->getContext (), " overflow.no" , I->getFunction ());
6475+ NoOverflowBB->moveAfter (OverflowEntryBB);
6476+ BasicBlock *OverflowBB =
6477+ BasicBlock::Create (I->getContext (), " overflow" , I->getFunction ());
6478+ OverflowBB->moveAfter (NoOverflowBB);
6479+
6480+ // BB overflow.entry:
6481+ IRBuilder<> Builder (OverflowEntryBB);
6482+ // Extract low and high halves of LHS:
6483+ Value *LoLHS = Builder.CreateTrunc (LHS, LegalTy, " lo.lhs" );
6484+ Value *HiLHS = Builder.CreateLShr (LHS, VTHalfBitWidth, " lhs.lsr" );
6485+ HiLHS = Builder.CreateTrunc (HiLHS, LegalTy, " hi.lhs" );
6486+
6487+ // Extract low and high halves of RHS:
6488+ Value *LoRHS = Builder.CreateTrunc (RHS, LegalTy, " lo.rhs" );
6489+ Value *HiRHS = Builder.CreateLShr (RHS, VTHalfBitWidth, " rhs.lsr" );
6490+ HiRHS = Builder.CreateTrunc (HiRHS, LegalTy, " hi.rhs" );
6491+
6492+ Value *IsAnyBitTrue;
6493+ if (IsSigned) {
6494+ Value *SignLoLHS =
6495+ Builder.CreateAShr (LoLHS, VTHalfBitWidth - 1 , " sign.lo.lhs" );
6496+ Value *SignLoRHS =
6497+ Builder.CreateAShr (LoRHS, VTHalfBitWidth - 1 , " sign.lo.rhs" );
6498+ Value *XorLHS = Builder.CreateXor (HiLHS, SignLoLHS);
6499+ Value *XorRHS = Builder.CreateXor (HiRHS, SignLoRHS);
6500+ Value *Or = Builder.CreateOr (XorLHS, XorRHS, " or.lhs.rhs" );
6501+ IsAnyBitTrue = Builder.CreateCmp (ICmpInst::ICMP_NE, Or,
6502+ ConstantInt::getNullValue (Or->getType ()));
6503+ } else {
6504+ Value *CmpLHS = Builder.CreateCmp (ICmpInst::ICMP_NE, HiLHS,
6505+ ConstantInt::getNullValue (LegalTy));
6506+ Value *CmpRHS = Builder.CreateCmp (ICmpInst::ICMP_NE, HiRHS,
6507+ ConstantInt::getNullValue (LegalTy));
6508+ IsAnyBitTrue = Builder.CreateOr (CmpLHS, CmpRHS, " or.lhs.rhs" );
6509+ }
6510+ Builder.CreateCondBr (IsAnyBitTrue, OverflowBB, NoOverflowBB);
6511+
6512+ // BB overflow.no:
6513+ Builder.SetInsertPoint (NoOverflowBB);
6514+ Value *ExtLoLHS, *ExtLoRHS;
6515+ if (IsSigned) {
6516+ ExtLoLHS = Builder.CreateSExt (LoLHS, Ty, " lo.lhs.ext" );
6517+ ExtLoRHS = Builder.CreateSExt (LoRHS, Ty, " lo.rhs.ext" );
6518+ } else {
6519+ ExtLoLHS = Builder.CreateZExt (LoLHS, Ty, " lo.lhs.ext" );
6520+ ExtLoRHS = Builder.CreateZExt (LoRHS, Ty, " lo.rhs.ext" );
6521+ }
6522+
6523+ Value *Mul = Builder.CreateMul (ExtLoLHS, ExtLoRHS, " mul.overflow.no" );
6524+
6525+ // Create the 'overflow.res' BB to merge the results of
6526+ // the two paths:
6527+ BasicBlock *OverflowResBB = I->getParent ();
6528+ OverflowResBB->setName (" overflow.res" );
6529+
6530+ // BB overflow.no: jump to overflow.res BB
6531+ Builder.CreateBr (OverflowResBB);
6532+ // No we don't need the old terminator in overflow.entry BB, erase it:
6533+ OldTerminator->eraseFromParent ();
6534+
6535+ // BB overflow.res:
6536+ Builder.SetInsertPoint (OverflowResBB, OverflowResBB->getFirstInsertionPt ());
6537+ // Create PHI nodes to merge results from no.overflow BB and overflow BB to
6538+ // replace the extract instructions.
6539+ PHINode *OverflowResPHI = Builder.CreatePHI (Ty, 2 ),
6540+ *OverflowFlagPHI =
6541+ Builder.CreatePHI (IntegerType::getInt1Ty (I->getContext ()), 2 );
6542+
6543+ // Add the incoming values from no.overflow BB and later from overflow BB.
6544+ OverflowResPHI->addIncoming (Mul, NoOverflowBB);
6545+ OverflowFlagPHI->addIncoming (ConstantInt::getFalse (I->getContext ()),
6546+ NoOverflowBB);
6547+
6548+ // Replace all users of MulExtract and OverflowExtract to use the PHI nodes.
6549+ if (MulExtract) {
6550+ MulExtract->replaceAllUsesWith (OverflowResPHI);
6551+ MulExtract->eraseFromParent ();
6552+ }
6553+ if (OverflowExtract) {
6554+ OverflowExtract->replaceAllUsesWith (OverflowFlagPHI);
6555+ OverflowExtract->eraseFromParent ();
6556+ }
6557+
6558+ // Remove the intrinsic from parent (overflow.res BB) as it will be part of
6559+ // overflow BB
6560+ I->removeFromParent ();
6561+ // BB overflow:
6562+ I->insertInto (OverflowBB, OverflowBB->end ());
6563+ Builder.SetInsertPoint (OverflowBB, OverflowBB->end ());
6564+ Value *MulOverflow = Builder.CreateExtractValue (I, {0 }, " mul.overflow" );
6565+ Value *OverflowFlag = Builder.CreateExtractValue (I, {1 }, " overflow.flag" );
6566+ Builder.CreateBr (OverflowResBB);
6567+
6568+ // Add The Extracted values to the PHINodes in the overflow.res BB.
6569+ OverflowResPHI->addIncoming (MulOverflow, OverflowBB);
6570+ OverflowFlagPHI->addIncoming (OverflowFlag, OverflowBB);
6571+
6572+ ModifiedDT = ModifyDT::ModifyBBDT;
6573+ return true ;
6574+ }
6575+
63946576// / If there are any memory operands, use OptimizeMemoryInst to sink their
63956577// / address computing into the block when possible / profitable.
63966578bool CodeGenPrepare::optimizeInlineAsmInst (CallInst *CS) {
0 commit comments