Skip to content

Commit 9e494bc

Browse files
authored
[Preprocessing] Use value bounds for dim multiple checking (#20583)
This also adds support for udiv to util.assume value bounds constraints. In the future we can add more constraints to represent udiv, but for now this implements at least one working option.
1 parent 89da9fc commit 9e494bc

File tree

4 files changed

+36
-2
lines changed

4 files changed

+36
-2
lines changed

compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,30 @@ struct UtilAssumeIntValueBoundsOpInterface
126126
auto [min, max] =
127127
assumeOp.getUnionedUnsignedRange(result.getResultNumber());
128128

129+
std::optional<int64_t> udiv =
130+
assumeOp.getUnionedUnsignedDivisor(result.getResultNumber());
131+
129132
if (min) {
130133
cstr.bound(result) >= *min;
131134
}
132135
if (max) {
133136
cstr.bound(result) <= *max;
134137
}
138+
if (udiv) {
139+
// To represent the divisibility guarantee, emit a bound clamping the
140+
// value to the udiv value. i.e.
141+
//
142+
// v == floordiv(v, udiv) * udiv
143+
//
144+
// Mod/divide folders can cleanup such terms with the appropriate bounds
145+
// query.
146+
AffineExpr expr =
147+
cstr.getExpr(assumeOp.getOperand(result.getResultNumber()));
148+
AffineExpr udivCst =
149+
getAffineConstantExpr(udiv.value(), op->getContext());
150+
AffineExpr clampExpr = expr.floorDiv(udivCst) * udivCst;
151+
cstr.bound(result) == clampExpr;
152+
}
135153
}
136154
};
137155

compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func.func @call_external(%arg0: index,
149149
%input_2d: tensor<?x20xf32>,
150150
%input_lb: tensor<100xf32>,
151151
%input_ub: tensor<3xf32>) {
152-
%0 = util.assume.int %arg0<umin = 12, umax = 16, udiv = 1> : index
152+
%0 = util.assume.int %arg0<umin = 12, umax = 16, udiv = 4> : index
153153
%input = tensor.empty(%0) : tensor<?xf32>
154154
// CHECK: call @external
155155
// CHECK-SAME: match_status = "both_matched"
@@ -178,6 +178,7 @@ module attributes {transform.with_named_sequence} {
178178
transform.match.operation_name %call ["func.call"] : !transform.any_op
179179
%in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
180180
transform.iree.match.dim_bounds %in0[0], umin = 5, umax = 20 : !transform.any_value
181+
transform.iree.match.dim_is_multiple_of %in0[0], 2 : !transform.any_value
181182
%0 = transform.param.constant "both_matched" -> !transform.any_param
182183
transform.yield %call, %0 : !transform.any_op, !transform.any_param
183184
}

compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ DiagnosedSilenceableFailure
327327
IREE::transform_dialect::MatchDimIsMultipleOfOp::matchValue(
328328
Value current, transform::TransformResults &results,
329329
transform::TransformState &state) {
330+
MLIRContext *ctx = current.getContext();
330331
auto shapedType = dyn_cast<ShapedType>(current.getType());
331332
if (!shapedType) {
332333
return emitSilenceableError()
@@ -338,7 +339,20 @@ IREE::transform_dialect::MatchDimIsMultipleOfOp::matchValue(
338339
<< "dim " << dim << " out of range for shaped type " << shapedType;
339340
}
340341
int64_t size = getSize();
341-
if (shapedType.getShape()[dim] % size != 0) {
342+
ValueBoundsConstraintSet::Variable dimVar(current, dim);
343+
344+
// Check if current[dim] % size == 0. There are a couple of options for how
345+
// to do this (e.g. mul(floordiv)). Affine map canonicalizations are good
346+
// at dropping terms that statically divide the mod RHS so we go with this
347+
// one.
348+
AffineMap modMap = AffineMap::get(/*dimCount=*/0, /*symbolCount=*/1,
349+
getAffineSymbolExpr(0, ctx) %
350+
getAffineConstantExpr(size, ctx));
351+
ValueBoundsConstraintSet::Variable modVar(modMap, {dimVar});
352+
Builder b(ctx);
353+
FailureOr<bool> maybeFailed = ValueBoundsConstraintSet::areEqual(
354+
modVar, OpFoldResult{b.getIndexAttr(0)});
355+
if (failed(maybeFailed) || !maybeFailed.value()) {
342356
return emitSilenceableError()
343357
<< "dim " << dim << " of shaped type " << shapedType
344358
<< " is not a multiple of " << size;

compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def MatchDimBoundsOp : Op<Transform_Dialect, "iree.match.dim_bounds",
118118
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
119119
}
120120

121+
// TODO: Combine with MatchDimBounds.
121122
def MatchDimIsMultipleOfOp : Op<Transform_Dialect, "iree.match.dim_is_multiple_of",
122123
[IsolatedFromAbove,
123124
MatchOpInterface,

0 commit comments

Comments
 (0)