@@ -20,6 +20,16 @@ namespace {
2020struct ForOpInterface
2121 : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
2222
23+ static AffineExpr getTripCountExpr (scf::ForOp forOp,
24+ ValueBoundsConstraintSet &cstr) {
25+ AffineExpr lbExpr = cstr.getExpr (forOp.getLowerBound ());
26+ AffineExpr ubExpr = cstr.getExpr (forOp.getUpperBound ());
27+ AffineExpr stepExpr = cstr.getExpr (forOp.getStep ());
28+ AffineExpr tripCountExpr =
29+ AffineExpr (ubExpr - lbExpr).ceilDiv (stepExpr); // (ub - lb) / step
30+ return tripCountExpr;
31+ }
32+
2333 // / Populate bounds of values/dimensions for iter_args/OpResults. If the
2434 // / value/dimension size does not change in an iteration, we can deduce that
2535 // / it the same as the initial value/dimension.
@@ -77,11 +87,7 @@ struct ForOpInterface
7787 // `value` is result of `forOp`, we can prove that:
7888 // %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
7989 // Where trip_count is (ub - lb) / step.
80- AffineExpr lbExpr = cstr.getExpr (forOp.getLowerBound ());
81- AffineExpr ubExpr = cstr.getExpr (forOp.getUpperBound ());
82- AffineExpr stepExpr = cstr.getExpr (forOp.getStep ());
83- AffineExpr tripCountExpr =
84- AffineExpr (ubExpr - lbExpr).ceilDiv (stepExpr); // (ub - lb) / step
90+ AffineExpr tripCountExpr = getTripCountExpr (forOp, cstr);
8591 AffineExpr oneIterAdvanceExpr =
8692 cstr.getExpr (yieldedValue) - cstr.getExpr (iterArg);
8793 cstr.bound (value) ==
@@ -93,9 +99,18 @@ struct ForOpInterface
9399 auto forOp = cast<ForOp>(op);
94100
95101 if (value == forOp.getInductionVar ()) {
96- // TODO: Take into account step size.
97102 cstr.bound (value) >= forOp.getLowerBound ();
98103 cstr.bound (value) < forOp.getUpperBound ();
104+ // iv <= lb + ((ub-lb)/step - 1) * step
105+ // This bound does not replace the `iv < ub` constraint mentioned above,
106+ // since constraints involving the multiplication of two constraint set
107+ // dimensions are not supported.
108+ AffineExpr tripCountMinusOne =
109+ getTripCountExpr (forOp, cstr) - cstr.getExpr (1 );
110+ AffineExpr computedUpperBound =
111+ cstr.getExpr (forOp.getLowerBound ()) +
112+ AffineExpr (tripCountMinusOne * cstr.getExpr (forOp.getStep ()));
113+ cstr.bound (value) <= computedUpperBound;
99114 return ;
100115 }
101116
0 commit comments