Skip to content

Commit 66940f9

Browse files
authored
[Preprocessing] Add op for matching dimension bounds (#20502)
Currently there is only support for matching static shapes and dimension alignment. This allows matching ops based on some minimum/maximum size for problem shape sensitive strategies.
1 parent a5cafa3 commit 66940f9

File tree

5 files changed

+161
-2
lines changed

5 files changed

+161
-2
lines changed

compiler/src/iree/compiler/Preprocessing/Common/test/preprocessing_match_ops.mlir

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,83 @@ module attributes {transform.with_named_sequence} {
139139

140140
// -----
141141

142+
func.func private @external(%arg0: tensor<?xf32>)
143+
func.func private @external_lb(%arg0: tensor<100xf32>)
144+
func.func private @external_ub(%arg0: tensor<3xf32>)
145+
func.func private @external_2d(%arg0: tensor<?x20xf32>)
146+
147+
// CHECK-LABEL: func.func @call_external
148+
func.func @call_external(%arg0: index,
149+
%input_2d: tensor<?x20xf32>,
150+
%input_lb: tensor<100xf32>,
151+
%input_ub: tensor<3xf32>) {
152+
%0 = util.assume.int %arg0<umin = 12, umax = 16, udiv = 1> : index
153+
%input = tensor.empty(%0) : tensor<?xf32>
154+
// CHECK: call @external
155+
// CHECK-SAME: match_status = "both_matched"
156+
func.call @external(%input) {match_status = "unmatched"} : (tensor<?xf32>) -> ()
157+
// CHECK: call @external_2d
158+
// CHECK-SAME: match_status = "dim1_matched"
159+
func.call @external_2d(%input_2d) {match_status = "unmatched"} : (tensor<?x20xf32>) -> ()
160+
// CHECK: call @external_lb
161+
// CHECK-SAME: match_status = "lb_matched"
162+
func.call @external_lb(%input_lb) {match_status = "unmatched"} : (tensor<100xf32>) -> ()
163+
// CHECK: call @external_ub
164+
// CHECK-SAME: match_status = "ub_matched"
165+
func.call @external_ub(%input_ub) {match_status = "unmatched"} : (tensor<3xf32>) -> ()
166+
return
167+
}
168+
169+
module attributes {transform.with_named_sequence} {
170+
transform.named_sequence @dim1_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
171+
transform.match.operation_name %call ["func.call"] : !transform.any_op
172+
%in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
173+
transform.iree.match.dim_bounds %in0[1], umin = 20, umax = 20 : !transform.any_value
174+
%0 = transform.param.constant "dim1_matched" -> !transform.any_param
175+
transform.yield %call, %0 : !transform.any_op, !transform.any_param
176+
}
177+
transform.named_sequence @both_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
178+
transform.match.operation_name %call ["func.call"] : !transform.any_op
179+
%in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
180+
transform.iree.match.dim_bounds %in0[0], umin = 5, umax = 20 : !transform.any_value
181+
%0 = transform.param.constant "both_matched" -> !transform.any_param
182+
transform.yield %call, %0 : !transform.any_op, !transform.any_param
183+
}
184+
transform.named_sequence @lb_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
185+
transform.match.operation_name %call ["func.call"] : !transform.any_op
186+
%in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
187+
transform.iree.match.dim_bounds %in0[0], umin = 75, none : !transform.any_value
188+
%0 = transform.param.constant "lb_matched" -> !transform.any_param
189+
transform.yield %call, %0 : !transform.any_op, !transform.any_param
190+
}
191+
transform.named_sequence @ub_match(%call: !transform.any_op {transform.readonly}) -> (!transform.any_op, !transform.any_param) {
192+
transform.match.operation_name %call ["func.call"] : !transform.any_op
193+
%in0 = transform.get_operand %call[0] : (!transform.any_op) -> !transform.any_value
194+
transform.iree.match.dim_bounds %in0[0], none, umax = 4 : !transform.any_value
195+
%0 = transform.param.constant "ub_matched" -> !transform.any_param
196+
transform.yield %call, %0 : !transform.any_op, !transform.any_param
197+
}
198+
199+
transform.named_sequence @annotate(%call: !transform.any_op {transform.readonly},
200+
%note: !transform.any_param {transform.readonly}) {
201+
transform.annotate %call "match_status" = %note : !transform.any_op, !transform.any_param
202+
transform.yield
203+
}
204+
205+
transform.named_sequence @__transform_main(%module: !transform.any_op) {
206+
%func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
207+
transform.foreach_match in %module
208+
@dim1_match -> @annotate,
209+
@both_match -> @annotate,
210+
@lb_match -> @annotate,
211+
@ub_match -> @annotate
212+
: (!transform.any_op) -> (!transform.any_op)
213+
transform.yield
214+
}
215+
}
216+
217+
// -----
218+
142219
module attributes {transform.with_named_sequence} {
143220

144221
// CHECK: func.func @matmul_repeated_operand

compiler/src/iree/compiler/Preprocessing/TransformExtensions/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,6 @@ iree_compiler_cc_library(
6666
"@llvm-project//mlir:TransformDialect",
6767
"@llvm-project//mlir:TransformDialectInterfaces",
6868
"@llvm-project//mlir:TransformUtils",
69+
"@llvm-project//mlir:ValueBoundsOpInterface",
6970
],
7071
)

compiler/src/iree/compiler/Preprocessing/TransformExtensions/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ iree_cc_library(
4040
MLIRTransformDialect
4141
MLIRTransformDialectInterfaces
4242
MLIRTransformUtils
43+
MLIRValueBoundsOpInterface
4344
iree::compiler::Utils
4445
PUBLIC
4546
)

compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensions.cpp

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/IRMapping.h"
1515
#include "mlir/IR/OpDefinition.h"
1616
#include "mlir/IR/Value.h"
17+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
1718

1819
namespace mlir::iree_compiler {
1920

@@ -271,6 +272,53 @@ IREE::transform_dialect::MatchCastCompatibleTypesOp::matchValue(
271272
return DiagnosedSilenceableFailure::success();
272273
}
273274

275+
//===----------------------------------------------------------------------===//
276+
// MatchDimBoundsOp
277+
//===----------------------------------------------------------------------===//
278+
279+
DiagnosedSilenceableFailure
280+
IREE::transform_dialect::MatchDimBoundsOp::matchValue(
281+
Value current, transform::TransformResults &results,
282+
transform::TransformState &state) {
283+
auto shapedType = dyn_cast<ShapedType>(current.getType());
284+
if (!shapedType) {
285+
return emitSilenceableError()
286+
<< "type " << current.getType() << " is not a shaped type";
287+
}
288+
int64_t dim = getDim();
289+
if (dim >= shapedType.getRank()) {
290+
return emitSilenceableError()
291+
<< "dim " << dim << " out of range for shaped type " << shapedType;
292+
}
293+
if (std::optional<int64_t> lb = getLowerBound()) {
294+
auto constantLb = ValueBoundsConstraintSet::computeConstantBound(
295+
presburger::BoundType::LB, {current, /*dim=*/dim},
296+
/*stopCondition=*/nullptr, /*closedLB=*/true);
297+
if (failed(constantLb)) {
298+
return emitSilenceableError()
299+
<< "failed to compute constant lower bound for dim " << dim;
300+
}
301+
if (lb.value() > constantLb.value()) {
302+
return emitSilenceableError()
303+
<< "dim " << dim << " is not >= " << lb.value();
304+
}
305+
}
306+
if (std::optional<int64_t> ub = getUpperBound()) {
307+
auto constantUb = ValueBoundsConstraintSet::computeConstantBound(
308+
presburger::BoundType::UB, {current, /*dim=*/dim},
309+
/*stopCondition=*/nullptr, /*closedUB=*/true);
310+
if (failed(constantUb)) {
311+
return emitSilenceableError()
312+
<< "failed to compute constant upper bound for dim " << dim;
313+
}
314+
if (ub.value() < constantUb.value()) {
315+
return emitSilenceableError()
316+
<< "dim " << dim << " is not <= " << ub.value();
317+
}
318+
}
319+
return DiagnosedSilenceableFailure::success();
320+
}
321+
274322
//===----------------------------------------------------------------------===//
275323
// MatchDimIsMultipleOfOp
276324
//===----------------------------------------------------------------------===//
@@ -285,7 +333,7 @@ IREE::transform_dialect::MatchDimIsMultipleOfOp::matchValue(
285333
<< "type " << current.getType() << " is not a shaped type";
286334
}
287335
int64_t dim = getDim();
288-
if (dim > shapedType.getRank()) {
336+
if (dim >= shapedType.getRank()) {
289337
return emitSilenceableError()
290338
<< "dim " << dim << " out of range for shaped type " << shapedType;
291339
}

compiler/src/iree/compiler/Preprocessing/TransformExtensions/PreprocessingExtensionsOps.td

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,45 @@ def MatchCastCompatibleTypesOp : Op<Transform_Dialect, "iree.match.cast_compatib
8686
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
8787
}
8888

89+
def MatchDimBoundsOp : Op<Transform_Dialect, "iree.match.dim_bounds",
90+
[IsolatedFromAbove,
91+
MatchOpInterface,
92+
SingleValueMatcher,
93+
MemoryEffectsOpInterface]> {
94+
let summary =
95+
"Checks whether the size of a dim is within given bounds";
96+
let description = [{
97+
Checks whether a dim is within a specified lower and upper bound.
98+
99+
#### Return modes
100+
101+
Succeeds if the value's type is compatible with the target type, and
102+
produces a silenceable failure otherwise. Produces a definite failure
103+
if the operand is not associated with a single payload value.
104+
}];
105+
106+
let arguments = (ins TransformValueHandleTypeInterface:$operand_handle,
107+
I64Attr:$dim,
108+
OptionalAttr<I64Attr>:$lower_bound,
109+
OptionalAttr<I64Attr>:$upper_bound);
110+
111+
// `<=` isn't a valid literal so use `le` instead.
112+
let assemblyFormat = [{
113+
$operand_handle `[` $dim `]` `,` (`umin` `=` $lower_bound^):(`none`)?
114+
`,` (`umax` `=` $upper_bound^):(`none`)? attr-dict `:` type($operand_handle)
115+
}];
116+
let extraClassDeclaration = SingleValueMatcher.extraDeclaration;
117+
118+
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
119+
}
120+
89121
def MatchDimIsMultipleOfOp : Op<Transform_Dialect, "iree.match.dim_is_multiple_of",
90122
[IsolatedFromAbove,
91123
MatchOpInterface,
92124
SingleValueMatcher,
93125
MemoryEffectsOpInterface]> {
94126
let summary =
95-
"Checks if the body of the target op matches the body of the single contained op";
127+
"Checks the static size of a dim is divisible by a given value";
96128
let description = [{
97129
Checks whether the given dimension given shaped value is a multiple of the
98130
given size.

0 commit comments

Comments
 (0)