Skip to content

Commit 34e3e0c

Browse files
committed
[mlir] Extend affine.min/max ValueBoundsOpInterfaceImpls
Signed-off-by: Max Dawkins <[email protected]>
1 parent 2f1bc68 commit 34e3e0c

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,27 @@ struct AffineMinOpInterface
6767
expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
6868
cstr.bound(value) <= bound;
6969
}
70+
// Get all constant lower bounds, choose minimum, and set lower bound to it.
71+
MLIRContext *ctx = op->getContext();
72+
AffineMap map = minOp.getAffineMap();
73+
SmallVector<Value> mapOperands = minOp.getOperands();
74+
std::optional<int64_t> minBound;
75+
for (AffineExpr expr : map.getResults()) {
76+
auto exprMap =
77+
AffineMap::get(map.getNumDims(), map.getNumSymbols(), expr, ctx);
78+
ValueBoundsConstraintSet::Variable exprVar(exprMap, mapOperands);
79+
FailureOr<int64_t> exprBound =
80+
cstr.computeConstantBound(presburger::BoundType::LB, exprVar,
81+
/*stopCondition=*/nullptr);
82+
// If any LB cannot be computed, then the total LB cannot be known.
83+
if (failed(exprBound))
84+
return;
85+
if (!minBound.has_value() || exprBound.value() < minBound.value())
86+
minBound = exprBound.value();
87+
}
88+
if (!minBound.has_value())
89+
return;
90+
cstr.bound(value) >= minBound.value();
7091
};
7192
};
7293

@@ -88,6 +109,27 @@ struct AffineMaxOpInterface
88109
expr.replaceDimsAndSymbols(dimReplacements, symReplacements);
89110
cstr.bound(value) >= bound;
90111
}
112+
// Get all constant upper bounds, choose maximum, and set upper bound to it.
113+
MLIRContext *ctx = op->getContext();
114+
AffineMap map = maxOp.getAffineMap();
115+
SmallVector<Value> mapOperands = maxOp.getOperands();
116+
std::optional<int64_t> maxBound;
117+
for (AffineExpr expr : map.getResults()) {
118+
auto exprMap =
119+
AffineMap::get(map.getNumDims(), map.getNumSymbols(), expr, ctx);
120+
ValueBoundsConstraintSet::Variable exprVar(exprMap, mapOperands);
121+
FailureOr<int64_t> exprBound = cstr.computeConstantBound(
122+
presburger::BoundType::UB, exprVar,
123+
/*stopCondition=*/nullptr, /*closedUB=*/true);
124+
// If any UB cannot be computed, then the total UB cannot be known.
125+
if (failed(exprBound))
126+
return;
127+
if (!maxBound.has_value() || exprBound.value() > maxBound.value())
128+
maxBound = exprBound.value();
129+
}
130+
if (!maxBound.has_value())
131+
return;
132+
cstr.bound(value) <= maxBound.value();
91133
};
92134
};
93135

mlir/test/Dialect/Affine/value-bounds-op-interface-impl.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,19 @@ func.func @affine_max_ub(%a: index) -> (index) {
3838

3939
// -----
4040

41+
// CHECK-LABEL: func @affine_max_const_ub(
42+
// CHECK-SAME: %[[a:.*]]: index
43+
// CHECK: %[[c5:.*]] = arith.constant 5 : index
44+
// CHECK: return %[[c5]]
45+
func.func @affine_max_const_ub(%a: index) -> (index) {
46+
%0 = affine.min affine_map<(d0) -> (d0, 4)>(%a)
47+
%1 = affine.max affine_map<(d0) -> (d0, 2)>(%0)
48+
%2 = "test.reify_bound"(%1) {type = "UB"}: (index) -> (index)
49+
return %2 : index
50+
}
51+
52+
// -----
53+
4154
// CHECK-LABEL: func @affine_min_ub(
4255
// CHECK-SAME: %[[a:.*]]: index
4356
// CHECK: %[[c3:.*]] = arith.constant 3 : index
@@ -61,6 +74,19 @@ func.func @affine_min_lb(%a: index) -> (index) {
6174

6275
// -----
6376

77+
// CHECK-LABEL: func @affine_min_const_lb(
78+
// CHECK-SAME: %[[a:.*]]: index
79+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
80+
// CHECK: return %[[c0]]
81+
func.func @affine_min_const_lb(%a: index) -> (index) {
82+
%0 = affine.max affine_map<(d0) -> (d0, 0)>(%a)
83+
%1 = affine.min affine_map<(d0) -> (d0, 2)>(%0)
84+
%2 = "test.reify_bound"(%1) {type = "LB"}: (index) -> (index)
85+
return %2 : index
86+
}
87+
88+
// -----
89+
6490
// CHECK-LABEL: func @composed_affine_apply(
6591
// CHECK: %[[cst:.*]] = arith.constant -8 : index
6692
// CHECK: return %[[cst]]

0 commit comments

Comments
 (0)