@@ -3217,7 +3217,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
32173217 }
32183218
32193219 // Try to fold (C1 * D /u C2) -> C1/C2 * D, if C1 and C2 are powers-of-2,
3220- // D is a multiple of C2, and C1 is a multiple of C2.
3220+ // D is a multiple of C2, and C1 is a multiple of C2. If C2 is a multiple
3221+ // of C1, fold to (D /u (C2 /u C1)).
32213222 const SCEV *D;
32223223 APInt C1V = LHSC->getAPInt();
32233224 // (C1 * D /u C2) == -1 * -C1 * D /u C2 when C1 != INT_MIN.
@@ -3226,10 +3227,15 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
32263227 const SCEVConstant *C2;
32273228 if (C1V.isPowerOf2() &&
32283229 match(Ops[1], m_scev_UDiv(m_SCEV(D), m_SCEVConstant(C2))) &&
3229- C2->getAPInt().isPowerOf2() && C1V.uge(C2->getAPInt()) &&
3230+ C2->getAPInt().isPowerOf2() &&
32303231 C1V.logBase2() <= getMinTrailingZeros(D)) {
3231- const SCEV *NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3232- return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
3232+ const SCEV *NewMul = nullptr;
3233+ if (C1V.uge(C2->getAPInt()))
3234+ NewMul = getMulExpr(getUDivExpr(getConstant(C1V), C2), D);
3235+ else if (C1V.ugt(1))
3236+ NewMul = getUDivExpr(D, getUDivExpr(C2, getConstant(C1V)));
3237+ if (NewMul)
3238+ return C1V == LHSC->getAPInt() ? NewMul : getNegativeSCEV(NewMul);
32333239 }
32343240 }
32353241 }
0 commit comments