@@ -338,6 +338,10 @@ class CodeGenPrepare {
338
338
// / Keep track of instructions removed during promotion.
339
339
SetOfInstrs RemovedInsts;
340
340
341
+ // / Keep track of seen mul_with_overflow intrinsics to avoid
342
+ // reprocessing them.
343
+ DenseMap<Instruction *, bool > SeenMulWithOverflowInstrs;
344
+
341
345
// / Keep track of sext chains based on their initial value.
342
346
DenseMap<Value *, Instruction *> SeenChainsForSExt;
343
347
@@ -433,6 +437,8 @@ class CodeGenPrepare {
433
437
bool optimizeGatherScatterInst (Instruction *MemoryInst, Value *Ptr);
434
438
bool optimizeUMulWithOverflow (Instruction *I);
435
439
bool optimizeSMulWithOverflow (Instruction *I);
440
+ bool optimizeMulWithOverflow (Instruction *I, bool IsSigned,
441
+ ModifyDT &ModifiedDT);
436
442
bool optimizeInlineAsmInst (CallInst *CS);
437
443
bool optimizeCallInst (CallInst *CI, ModifyDT &ModifiedDT);
438
444
bool optimizeExt (Instruction *&I);
@@ -774,6 +780,7 @@ bool CodeGenPrepare::_run(Function &F) {
774
780
verifyBFIUpdates (F);
775
781
#endif
776
782
783
+ SeenMulWithOverflowInstrs.clear ();
777
784
return EverMadeChange;
778
785
}
779
786
@@ -2781,9 +2788,9 @@ bool CodeGenPrepare::optimizeCallInst(CallInst *CI, ModifyDT &ModifiedDT) {
2781
2788
}
2782
2789
return false ;
2783
2790
case Intrinsic::umul_with_overflow:
2784
- return optimizeUMulWithOverflow (II);
2791
+ return optimizeMulWithOverflow (II, /* IsSigned= */ false , ModifiedDT );
2785
2792
case Intrinsic::smul_with_overflow:
2786
- return optimizeSMulWithOverflow (II);
2793
+ return optimizeMulWithOverflow (II, /* IsSigned= */ true , ModifiedDT );
2787
2794
}
2788
2795
2789
2796
SmallVector<Value *, 2 > PtrOps;
@@ -6395,122 +6402,20 @@ bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
6395
6402
return true ;
6396
6403
}
6397
6404
6398
- // Rewrite the umul_with_overflow intrinsic by checking if both of the
6405
+ // Rewrite the mul_with_overflow intrinsic by checking if both of the
6399
6406
// operands' value range is within the legal type. If so, we can optimize the
6400
6407
// multiplication algorithm. This code is supposed to be written during the step
6401
6408
// of type legalization, but given that we need to reconstruct the IR which is
6402
6409
// not doable there, we do it here.
6403
- bool CodeGenPrepare::optimizeUMulWithOverflow (Instruction *I) {
6410
+ bool CodeGenPrepare::optimizeMulWithOverflow (Instruction *I, bool IsSigned,
6411
+ ModifyDT &ModifiedDT) {
6404
6412
// Enable this optimization only for aarch64.
6405
6413
if (!TLI->getTargetMachine ().getTargetTriple ().isAArch64 ())
6406
6414
return false ;
6407
- if (TLI->getTypeAction (
6408
- I->getContext (),
6409
- TLI->getValueType (*DL, I->getType ()->getContainedType (0 ))) !=
6410
- TargetLowering::TypeExpandInteger)
6411
- return false ;
6412
-
6413
- Value *LHS = I->getOperand (0 );
6414
- Value *RHS = I->getOperand (1 );
6415
- auto *Ty = LHS->getType ();
6416
- unsigned VTBitWidth = Ty->getScalarSizeInBits ();
6417
- unsigned VTHalfBitWidth = VTBitWidth / 2 ;
6418
- auto *LegalTy = IntegerType::getIntNTy (I->getContext (), VTHalfBitWidth);
6419
-
6420
- // Skip the optimization if the type with HalfBitWidth is not legal for the
6421
- // target.
6422
- if (TLI->getTypeAction (I->getContext (), TLI->getValueType (*DL, LegalTy)) !=
6423
- TargetLowering::TypeLegal)
6415
+ // If we have already seen this instruction, don't process it again.
6416
+ if (!SeenMulWithOverflowInstrs.insert (std::make_pair (I, true )).second )
6424
6417
return false ;
6425
6418
6426
- I->getParent ()->setName (" overflow.res" );
6427
- auto *OverflowResBB = I->getParent ();
6428
- auto *OverflowoEntryBB =
6429
- I->getParent ()->splitBasicBlock (I, " overflow.entry" , /* Before*/ true );
6430
- BasicBlock *NoOverflowBB = BasicBlock::Create (
6431
- I->getContext (), " overflow.no" , I->getFunction (), OverflowResBB);
6432
- BasicBlock *OverflowBB = BasicBlock::Create (I->getContext (), " overflow" ,
6433
- I->getFunction (), OverflowResBB);
6434
- // new blocks should be:
6435
- // entry:
6436
- // (lhs_lo ne lhs_hi) || (rhs_lo ne rhs_hi) ? overflow, overflow_no
6437
-
6438
- // overflow_no:
6439
- // overflow:
6440
- // overflow.res:
6441
- // ------------------------------------------------------------------------------
6442
- // BB overflow.entry:
6443
- // get Lo and Hi of RHS & LHS:
6444
- IRBuilder<> Builder (OverflowoEntryBB->getTerminator ());
6445
- auto *LoRHS = Builder.CreateTrunc (RHS, LegalTy, " lo.rhs.trunc" );
6446
- auto *ShrHiRHS = Builder.CreateLShr (RHS, VTHalfBitWidth, " rhs.lsr" );
6447
- auto *HiRHS = Builder.CreateTrunc (ShrHiRHS, LegalTy, " hi.rhs.trunc" );
6448
-
6449
- auto *LoLHS = Builder.CreateTrunc (LHS, LegalTy, " lo.lhs.trunc" );
6450
- auto *ShrHiLHS = Builder.CreateLShr (LHS, VTHalfBitWidth, " lhs.lsr" );
6451
- auto *HiLHS = Builder.CreateTrunc (ShrHiLHS, LegalTy, " hi.lhs.trunc" );
6452
-
6453
- auto *CmpLHS = Builder.CreateCmp (ICmpInst::ICMP_NE, HiLHS,
6454
- ConstantInt::getNullValue (LegalTy));
6455
- auto *CmpRHS = Builder.CreateCmp (ICmpInst::ICMP_NE, HiRHS,
6456
- ConstantInt::getNullValue (LegalTy));
6457
- auto *Or = Builder.CreateOr (CmpLHS, CmpRHS, " or.lhs.rhs" );
6458
- Builder.CreateCondBr (Or, OverflowBB, NoOverflowBB);
6459
- OverflowoEntryBB->getTerminator ()->eraseFromParent ();
6460
-
6461
- // ------------------------------------------------------------------------------
6462
- // BB overflow.no:
6463
- Builder.SetInsertPoint (NoOverflowBB);
6464
- auto *ExtLoLHS = Builder.CreateZExt (LoLHS, Ty, " lo.lhs.ext" );
6465
- auto *ExtLoRHS = Builder.CreateZExt (LoRHS, Ty, " lo.rhs.ext" );
6466
- auto *Mul = Builder.CreateMul (ExtLoLHS, ExtLoRHS, " mul.no.overflow" );
6467
- Builder.CreateBr (OverflowResBB);
6468
-
6469
- // ------------------------------------------------------------------------------
6470
- // BB overflow.res:
6471
- Builder.SetInsertPoint (OverflowResBB, OverflowResBB->getFirstInsertionPt ());
6472
- auto *PHINode1 = Builder.CreatePHI (Ty, 2 );
6473
- PHINode1->addIncoming (Mul, NoOverflowBB);
6474
- auto *PHINode2 =
6475
- Builder.CreatePHI (IntegerType::getInt1Ty (I->getContext ()), 2 );
6476
- PHINode2->addIncoming (ConstantInt::getFalse (I->getContext ()), NoOverflowBB);
6477
-
6478
- StructType *STy = StructType::get (
6479
- I->getContext (), {Ty, IntegerType::getInt1Ty (I->getContext ())});
6480
- Value *StructValOverflowRes = PoisonValue::get (STy);
6481
- StructValOverflowRes =
6482
- Builder.CreateInsertValue (StructValOverflowRes, PHINode1, {0 });
6483
- StructValOverflowRes =
6484
- Builder.CreateInsertValue (StructValOverflowRes, PHINode2, {1 });
6485
- // Before moving the mul.overflow intrinsic to the overflowBB, replace all its
6486
- // uses by StructValOverflowRes.
6487
- I->replaceAllUsesWith (StructValOverflowRes);
6488
- I->removeFromParent ();
6489
-
6490
- // BB overflow:
6491
- I->insertInto (OverflowBB, OverflowBB->end ());
6492
- Builder.SetInsertPoint (OverflowBB, OverflowBB->end ());
6493
- auto *MulOverflow = Builder.CreateExtractValue (I, {0 }, " mul.overflow" );
6494
- auto *OverflowFlag = Builder.CreateExtractValue (I, {1 }, " overflow.flag" );
6495
- Builder.CreateBr (OverflowResBB);
6496
-
6497
- // Add The Extracted values to the PHINodes in the overflow.res block.
6498
- PHINode1->addIncoming (MulOverflow, OverflowBB);
6499
- PHINode2->addIncoming (OverflowFlag, OverflowBB);
6500
-
6501
- // return false to stop reprocessing the function.
6502
- return false ;
6503
- }
6504
-
6505
- // Rewrite the smul_with_overflow intrinsic by checking if both of the
6506
- // operands' value range is within the legal type. If so, we can optimize the
6507
- // multiplication algorithm. This code is supposed to be written during the step
6508
- // of type legalization, but given that we need to reconstruct the IR which is
6509
- // not doable there, we do it here.
6510
- bool CodeGenPrepare::optimizeSMulWithOverflow (Instruction *I) {
6511
- // Enable this optimization only for aarch64.
6512
- if (!TLI->getTargetMachine ().getTargetTriple ().isAArch64 ())
6513
- return false ;
6514
6419
if (TLI->getTypeAction (
6515
6420
I->getContext (),
6516
6421
TLI->getValueType (*DL, I->getType ()->getContainedType (0 ))) !=
@@ -6519,75 +6424,93 @@ bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
6519
6424
6520
6425
Value *LHS = I->getOperand (0 );
6521
6426
Value *RHS = I->getOperand (1 );
6522
- auto *Ty = LHS->getType ();
6427
+ Type *Ty = LHS->getType ();
6523
6428
unsigned VTBitWidth = Ty->getScalarSizeInBits ();
6524
6429
unsigned VTHalfBitWidth = VTBitWidth / 2 ;
6525
- auto *LegalTy = IntegerType::getIntNTy (I->getContext (), VTHalfBitWidth);
6430
+ IntegerType *LegalTy =
6431
+ IntegerType::getIntNTy (I->getContext (), VTHalfBitWidth);
6526
6432
6527
6433
// Skip the optimization if the type with HalfBitWidth is not legal for the
6528
6434
// target.
6529
6435
if (TLI->getTypeAction (I->getContext (), TLI->getValueType (*DL, LegalTy)) !=
6530
6436
TargetLowering::TypeLegal)
6531
6437
return false ;
6532
6438
6439
+ // Make sure that I->getType() is a struct type with two elements.
6440
+ if (!I->getType ()->isStructTy () || I->getType ()->getStructNumElements () != 2 )
6441
+ return false ;
6442
+
6533
6443
I->getParent ()->setName (" overflow.res" );
6534
- auto *OverflowResBB = I->getParent ();
6535
- auto *OverflowoEntryBB =
6444
+ BasicBlock *OverflowResBB = I->getParent ();
6445
+ BasicBlock *OverflowoEntryBB =
6536
6446
I->getParent ()->splitBasicBlock (I, " overflow.entry" , /* Before*/ true );
6537
6447
BasicBlock *NoOverflowBB = BasicBlock::Create (
6538
6448
I->getContext (), " overflow.no" , I->getFunction (), OverflowResBB);
6539
6449
BasicBlock *OverflowBB = BasicBlock::Create (I->getContext (), " overflow" ,
6540
6450
I->getFunction (), OverflowResBB);
6541
6451
// new blocks should be:
6542
6452
// entry:
6543
- // (lhs_lo ne lhs_hi) || (rhs_lo ne rhs_hi) ? overflow, overflow_no
6453
+ // if signed:
6454
+ // (lhs_hi ^ sign(lhs_lo)) || (rhs_hi ^ sign(rhs_lo)) ? overflow, overflow_no
6455
+ // else:
6456
+ // (lhs_hi != 0) || (rhs_hi != 0) ? overflow, overflow_no
6544
6457
6545
6458
// overflow_no:
6546
6459
// overflow:
6547
6460
// overflow.res:
6548
6461
6549
- // -------------------------------------------------- ----------------------------
6462
+ // ----------------------------
6550
6463
// BB overflow.entry:
6551
- // get Lo and Hi of RHS & LHS :
6464
+ // get Lo and Hi of LHS & RHS :
6552
6465
IRBuilder<> Builder (OverflowoEntryBB->getTerminator ());
6553
- auto *LoRHS = Builder.CreateTrunc (RHS, LegalTy, " lo.rhs" );
6554
- auto *SignLoRHS =
6555
- Builder.CreateAShr (LoRHS, VTHalfBitWidth - 1 , " sign.lo.rhs" );
6556
- auto *HiRHS = Builder.CreateLShr (RHS, VTHalfBitWidth, " rhs.lsr" );
6466
+ Value *LoLHS = Builder.CreateTrunc (LHS, LegalTy, " lo.lhs" );
6467
+ Value *HiLHS = Builder.CreateLShr (LHS, VTHalfBitWidth, " lhs.lsr" );
6468
+ HiLHS = Builder.CreateTrunc (HiLHS, LegalTy, " hi.lhs" );
6469
+ Value *LoRHS = Builder.CreateTrunc (RHS, LegalTy, " lo.rhs" );
6470
+ Value *HiRHS = Builder.CreateLShr (RHS, VTHalfBitWidth, " rhs.lsr" );
6557
6471
HiRHS = Builder.CreateTrunc (HiRHS, LegalTy, " hi.rhs" );
6558
6472
6559
- auto *LoLHS = Builder.CreateTrunc (LHS, LegalTy, " lo.lhs" );
6560
- auto *SignLoLHS =
6561
- Builder.CreateAShr (LoLHS, VTHalfBitWidth - 1 , " sign.lo.lhs" );
6562
- auto *HiLHS = Builder.CreateLShr (LHS, VTHalfBitWidth, " lhs.lsr" );
6563
- HiLHS = Builder.CreateTrunc (HiLHS, LegalTy, " hi.lhs" );
6564
- // xor(HiLHS, SignLoLHS) false -> no overflow
6565
- // xor(HiRHS, SignLoRHS) false -> no overflow
6566
- // if either of the above is true, then overflow.
6567
- // auto *CmpLHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiLHS, SignLoLHS);
6568
- auto *XorLHS = Builder.CreateXor (HiLHS, SignLoLHS);
6569
- auto *XorRHS = Builder.CreateXor (HiRHS, SignLoRHS);
6570
- // auto *CmpRHS = Builder.CreateCmp(ICmpInst::ICMP_NE, HiRHS, SignLoRHS);
6571
- auto *Or = Builder.CreateOr (XorLHS, XorRHS, " or.lhs.rhs" );
6572
- auto *Cmp = Builder.CreateCmp (ICmpInst::ICMP_EQ, Or,
6573
- ConstantInt::get (Or->getType (), 1 ));
6574
- Builder.CreateCondBr (Cmp, OverflowBB, NoOverflowBB);
6473
+ Value *IsAnyBitTrue;
6474
+ if (IsSigned) {
6475
+ Value *SignLoLHS =
6476
+ Builder.CreateAShr (LoLHS, VTHalfBitWidth - 1 , " sign.lo.lhs" );
6477
+ Value *SignLoRHS =
6478
+ Builder.CreateAShr (LoRHS, VTHalfBitWidth - 1 , " sign.lo.rhs" );
6479
+ Value *XorLHS = Builder.CreateXor (HiLHS, SignLoLHS);
6480
+ Value *XorRHS = Builder.CreateXor (HiRHS, SignLoRHS);
6481
+ Value *Or = Builder.CreateOr (XorLHS, XorRHS, " or.lhs.rhs" );
6482
+ // Any nonzero bit in Or means an operand does not fit in the signed
+ // half-width type, so compare against zero rather than against 1.
+ IsAnyBitTrue = Builder.CreateCmp (ICmpInst::ICMP_NE, Or,
6483
+ ConstantInt::getNullValue (Or->getType ()));
6484
+ } else {
6485
+ Value *CmpLHS = Builder.CreateCmp (ICmpInst::ICMP_NE, HiLHS,
6486
+ ConstantInt::getNullValue (LegalTy));
6487
+ Value *CmpRHS = Builder.CreateCmp (ICmpInst::ICMP_NE, HiRHS,
6488
+ ConstantInt::getNullValue (LegalTy));
6489
+ IsAnyBitTrue = Builder.CreateOr (CmpLHS, CmpRHS, " or.lhs.rhs" );
6490
+ }
6491
+
6492
+ Builder.CreateCondBr (IsAnyBitTrue, OverflowBB, NoOverflowBB);
6575
6493
OverflowoEntryBB->getTerminator ()->eraseFromParent ();
6576
6494
6577
- // ------------------------------------------------------------------------------
6578
6495
// BB overflow.no:
6579
6496
Builder.SetInsertPoint (NoOverflowBB);
6580
- auto *ExtLoLHS = Builder.CreateSExt (LoLHS, Ty, " lo.lhs.ext" );
6581
- auto *ExtLoRHS = Builder.CreateSExt (LoRHS, Ty, " lo.rhs.ext" );
6582
- auto *Mul = Builder.CreateMul (ExtLoLHS, ExtLoRHS, " mul.no.overflow" );
6497
+ Value *ExtLoLHS, *ExtLoRHS;
6498
+ if (IsSigned) {
6499
+ ExtLoLHS = Builder.CreateSExt (LoLHS, Ty, " lo.lhs.ext" );
6500
+ ExtLoRHS = Builder.CreateSExt (LoRHS, Ty, " lo.rhs.ext" );
6501
+ } else {
6502
+ ExtLoLHS = Builder.CreateZExt (LoLHS, Ty, " lo.lhs.ext" );
6503
+ ExtLoRHS = Builder.CreateZExt (LoRHS, Ty, " lo.rhs.ext" );
6504
+ }
6505
+
6506
+ Value *Mul = Builder.CreateMul (ExtLoLHS, ExtLoRHS, " mul.no.overflow" );
6583
6507
Builder.CreateBr (OverflowResBB);
6584
6508
6585
- // ------------------------------------------------------------------------------
6586
6509
// BB overflow.res:
6587
6510
Builder.SetInsertPoint (OverflowResBB, OverflowResBB->getFirstInsertionPt ());
6588
- auto *PHINode1 = Builder.CreatePHI (Ty, 2 );
6511
+ PHINode *PHINode1 = Builder.CreatePHI (Ty, 2 );
6589
6512
PHINode1->addIncoming (Mul, NoOverflowBB);
6590
- auto *PHINode2 =
6513
+ PHINode *PHINode2 =
6591
6514
Builder.CreatePHI (IntegerType::getInt1Ty (I->getContext ()), 2 );
6592
6515
PHINode2->addIncoming (ConstantInt::getFalse (I->getContext ()), NoOverflowBB);
6593
6516
@@ -6606,16 +6529,16 @@ bool CodeGenPrepare::optimizeSMulWithOverflow(Instruction *I) {
6606
6529
// BB overflow:
6607
6530
I->insertInto (OverflowBB, OverflowBB->end ());
6608
6531
Builder.SetInsertPoint (OverflowBB, OverflowBB->end ());
6609
- auto *MulOverflow = Builder.CreateExtractValue (I, {0 }, " mul.overflow" );
6610
- auto *OverflowFlag = Builder.CreateExtractValue (I, {1 }, " overflow.flag" );
6532
+ Value *MulOverflow = Builder.CreateExtractValue (I, {0 }, " mul.overflow" );
6533
+ Value *OverflowFlag = Builder.CreateExtractValue (I, {1 }, " overflow.flag" );
6611
6534
Builder.CreateBr (OverflowResBB);
6612
6535
6613
6536
// Add the extracted values to the PHI nodes in the overflow.res block.
6614
6537
PHINode1->addIncoming (MulOverflow, OverflowBB);
6615
6538
PHINode2->addIncoming (OverflowFlag, OverflowBB);
6616
6539
6617
- // return false to stop reprocessing the function.
6618
- return false ;
6540
+ ModifiedDT = ModifyDT::ModifyBBDT;
6541
+ return true ;
6619
6542
}
6620
6543
6621
6544
// / If there are any memory operands, use OptimizeMemoryInst to sink their
0 commit comments