@@ -67,6 +67,27 @@ struct AffineMinOpInterface
6767 expr.replaceDimsAndSymbols (dimReplacements, symReplacements);
6868 cstr.bound (value) <= bound;
6969 }
70+ // Get all constant lower bounds, choose minimum, and set lower bound to it.
71+ MLIRContext *ctx = op->getContext ();
72+ AffineMap map = minOp.getAffineMap ();
73+ SmallVector<Value> mapOperands = minOp.getOperands ();
74+ std::optional<int64_t > minBound;
75+ for (AffineExpr expr : map.getResults ()) {
76+ auto exprMap =
77+ AffineMap::get (map.getNumDims (), map.getNumSymbols (), expr, ctx);
78+ ValueBoundsConstraintSet::Variable exprVar (exprMap, mapOperands);
79+ FailureOr<int64_t > exprBound =
80+ cstr.computeConstantBound (presburger::BoundType::LB, exprVar,
81+ /* stopCondition=*/ nullptr );
82+ // If any LB cannot be computed, then the total LB cannot be known.
83+ if (failed (exprBound))
84+ return ;
85+ if (!minBound.has_value () || exprBound.value () < minBound.value ())
86+ minBound = exprBound.value ();
87+ }
88+ if (!minBound.has_value ())
89+ return ;
90+ cstr.bound (value) >= minBound.value ();
7091 };
7192};
7293
@@ -88,6 +109,27 @@ struct AffineMaxOpInterface
88109 expr.replaceDimsAndSymbols (dimReplacements, symReplacements);
89110 cstr.bound (value) >= bound;
90111 }
112+ // Get all constant upper bounds, choose maximum, and set upper bound to it.
113+ MLIRContext *ctx = op->getContext ();
114+ AffineMap map = maxOp.getAffineMap ();
115+ SmallVector<Value> mapOperands = maxOp.getOperands ();
116+ std::optional<int64_t > maxBound;
117+ for (AffineExpr expr : map.getResults ()) {
118+ auto exprMap =
119+ AffineMap::get (map.getNumDims (), map.getNumSymbols (), expr, ctx);
120+ ValueBoundsConstraintSet::Variable exprVar (exprMap, mapOperands);
121+ FailureOr<int64_t > exprBound = cstr.computeConstantBound (
122+ presburger::BoundType::UB, exprVar,
123+ /* stopCondition=*/ nullptr , /* closedUB=*/ true );
124+ // If any UB cannot be computed, then the total UB cannot be known.
125+ if (failed (exprBound))
126+ return ;
127+ if (!maxBound.has_value () || exprBound.value () > maxBound.value ())
128+ maxBound = exprBound.value ();
129+ }
130+ if (!maxBound.has_value ())
131+ return ;
132+ cstr.bound (value) <= maxBound.value ();
91133 };
92134};
93135
0 commit comments