@@ -1381,48 +1381,116 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS,
13811381 };
13821382}
13831383
1384+ namespace {
1385+
1386+ // Represents a interesting integer binary operation for
1387+ // getExtendedOperandRecurrence. This may be a shl that is being treated as a
1388+ // multiply or a 'or disjoint' that is being treated as 'add nsw nuw'.
1389+ struct BinaryOp {
1390+ unsigned Opcode;
1391+ std::array<Value *, 2 > Operands;
1392+ bool IsNSW = false ;
1393+ bool IsNUW = false ;
1394+
1395+ explicit BinaryOp (Instruction *Op)
1396+ : Opcode(Op->getOpcode ()),
1397+ Operands({Op->getOperand (0 ), Op->getOperand (1 )}) {
1398+ if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
1399+ IsNSW = OBO->hasNoSignedWrap ();
1400+ IsNUW = OBO->hasNoUnsignedWrap ();
1401+ }
1402+ }
1403+
1404+ explicit BinaryOp (Instruction::BinaryOps Opcode, Value *LHS, Value *RHS,
1405+ bool IsNSW = false , bool IsNUW = false )
1406+ : Opcode(Opcode), Operands({LHS, RHS}), IsNSW(IsNSW), IsNUW(IsNUW) {}
1407+ };
1408+
1409+ } // end anonymous namespace
1410+
1411+ static std::optional<BinaryOp> matchBinaryOp (Instruction *Op) {
1412+ switch (Op->getOpcode ()) {
1413+ case Instruction::Add:
1414+ case Instruction::Sub:
1415+ case Instruction::Mul:
1416+ return BinaryOp (Op);
1417+ case Instruction::Or: {
1418+ // Convert or disjoint into add nuw nsw.
1419+ if (cast<PossiblyDisjointInst>(Op)->isDisjoint ())
1420+ return BinaryOp (Instruction::Add, Op->getOperand (0 ), Op->getOperand (1 ),
1421+ /* IsNSW=*/ true , /* IsNUW=*/ true );
1422+ break ;
1423+ }
1424+ case Instruction::Shl: {
1425+ if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand (1 ))) {
1426+ unsigned BitWidth = cast<IntegerType>(SA->getType ())->getBitWidth ();
1427+
1428+ // If the shift count is not less than the bitwidth, the result of
1429+ // the shift is undefined. Don't try to analyze it, because the
1430+ // resolution chosen here may differ from the resolution chosen in
1431+ // other parts of the compiler.
1432+ if (SA->getValue ().ult (BitWidth)) {
1433+ // We can safely preserve the nuw flag in all cases. It's also safe to
1434+ // turn a nuw nsw shl into a nuw nsw mul. However, nsw in isolation
1435+ // requires special handling. It can be preserved as long as we're not
1436+ // left shifting by bitwidth - 1.
1437+ bool IsNUW = Op->hasNoUnsignedWrap ();
1438+ bool IsNSW = Op->hasNoSignedWrap () &&
1439+ (IsNUW || SA->getValue ().ult (BitWidth - 1 ));
1440+
1441+ ConstantInt *X =
1442+ ConstantInt::get (Op->getContext (),
1443+ APInt::getOneBitSet (BitWidth, SA->getZExtValue ()));
1444+ return BinaryOp (Instruction::Mul, Op->getOperand (0 ), X, IsNSW, IsNUW);
1445+ }
1446+ }
1447+
1448+ break ;
1449+ }
1450+ }
1451+
1452+ return std::nullopt ;
1453+ }
1454+
13841455// / No-wrap operations can transfer sign extension of their result to their
13851456// / operands. Generate the SCEV value for the widened operation without
13861457// / actually modifying the IR yet. If the expression after extending the
13871458// / operands is an AddRec for this loop, return the AddRec and the kind of
13881459// / extension used.
13891460WidenIV::WidenedRecTy
13901461WidenIV::getExtendedOperandRecurrence (WidenIV::NarrowIVDefUse DU) {
1391- // Handle the common case of add<nsw/nuw>
1392- const unsigned OpCode = DU.NarrowUse ->getOpcode ();
1393- // Only Add/Sub/Mul instructions supported yet.
1394- if (OpCode != Instruction::Add && OpCode != Instruction::Sub &&
1395- OpCode != Instruction::Mul)
1462+ auto Op = matchBinaryOp (DU.NarrowUse );
1463+ if (!Op)
13961464 return {nullptr , ExtendKind::Unknown};
13971465
1466+ assert ((Op->Opcode == Instruction::Add || Op->Opcode == Instruction::Sub ||
1467+ Op->Opcode == Instruction::Mul) &&
1468+ " Unexpected opcode" );
1469+
13981470 // One operand (NarrowDef) has already been extended to WideDef. Now determine
13991471 // if extending the other will lead to a recurrence.
1400- const unsigned ExtendOperIdx =
1401- DU.NarrowUse ->getOperand (0 ) == DU.NarrowDef ? 1 : 0 ;
1402- assert (DU.NarrowUse ->getOperand (1 -ExtendOperIdx) == DU.NarrowDef && " bad DU" );
1472+ const unsigned ExtendOperIdx = Op->Operands [0 ] == DU.NarrowDef ? 1 : 0 ;
1473+ assert (Op->Operands [1 - ExtendOperIdx] == DU.NarrowDef && " bad DU" );
14031474
1404- const OverflowingBinaryOperator *OBO =
1405- cast<OverflowingBinaryOperator>(DU.NarrowUse );
14061475 ExtendKind ExtKind = getExtendKind (DU.NarrowDef );
1407- if (!(ExtKind == ExtendKind::Sign && OBO-> hasNoSignedWrap () ) &&
1408- !(ExtKind == ExtendKind::Zero && OBO-> hasNoUnsignedWrap () )) {
1476+ if (!(ExtKind == ExtendKind::Sign && Op-> IsNSW ) &&
1477+ !(ExtKind == ExtendKind::Zero && Op-> IsNUW )) {
14091478 ExtKind = ExtendKind::Unknown;
14101479
14111480 // For a non-negative NarrowDef, we can choose either type of
14121481 // extension. We want to use the current extend kind if legal
14131482 // (see above), and we only hit this code if we need to check
14141483 // the opposite case.
14151484 if (DU.NeverNegative ) {
1416- if (OBO-> hasNoSignedWrap () ) {
1485+ if (Op-> IsNSW ) {
14171486 ExtKind = ExtendKind::Sign;
1418- } else if (OBO-> hasNoUnsignedWrap () ) {
1487+ } else if (Op-> IsNUW ) {
14191488 ExtKind = ExtendKind::Zero;
14201489 }
14211490 }
14221491 }
14231492
1424- const SCEV *ExtendOperExpr =
1425- SE->getSCEV (DU.NarrowUse ->getOperand (ExtendOperIdx));
1493+ const SCEV *ExtendOperExpr = SE->getSCEV (Op->Operands [ExtendOperIdx]);
14261494 if (ExtKind == ExtendKind::Sign)
14271495 ExtendOperExpr = SE->getSignExtendExpr (ExtendOperExpr, WideType);
14281496 else if (ExtKind == ExtendKind::Zero)
@@ -1443,7 +1511,7 @@ WidenIV::getExtendedOperandRecurrence(WidenIV::NarrowIVDefUse DU) {
14431511 if (ExtendOperIdx == 0 )
14441512 std::swap (lhs, rhs);
14451513 const SCEVAddRecExpr *AddRec =
1446- dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode (lhs, rhs, OpCode ));
1514+ dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode (lhs, rhs, Op-> Opcode ));
14471515
14481516 if (!AddRec || AddRec->getLoop () != L)
14491517 return {nullptr , ExtendKind::Unknown};
0 commit comments