Skip to content

Commit 9351ad6

Browse files
authored
[MLIR][Transform][SMT] Allow for declarative computations in schedules (#160895)
By allowing `transform.smt.constrain_params`'s region to yield SMT-vars, op instances can declare relationships, through constraints, on incoming params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it possible to declare that computations on params should be performed. The semantics are that the yielded SMT-vars should be from any valid satisfying assignment/model of the constraints in the region.
1 parent 34ed1dc commit 9351ad6

File tree

8 files changed

+253
-29
lines changed

8 files changed

+253
-29
lines changed

mlir/include/mlir/Dialect/SMT/IR/SMTOps.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
220220
Pure,
221221
Terminator,
222222
ReturnLike,
223-
ParentOneOf<["smt::SolverOp", "smt::CheckOp",
224-
"smt::ForallOp", "smt::ExistsOp"]>,
225223
]> {
226224
let summary = "terminator operation for various regions of SMT operations";
227225
let arguments = (ins Variadic<AnyType>:$values);

mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
1314
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1415
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1516
#include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,28 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1616
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
1717
DeclareOpInterfaceMethods<TransformOpInterface>,
1818
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
19-
NoTerminator
19+
SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
2020
]> {
2121
let cppNamespace = [{ mlir::transform::smt }];
2222

2323
let summary = "Express contraints on params interpreted as symbolic values";
2424
let description = [{
2525
Allows expressing constraints on params using the SMT dialect.
2626

27-
Each Transform dialect param provided as an operand has a corresponding
27+
Each Transform-dialect param provided as an operand has a corresponding
2828
argument of SMT-type in the region. The SMT-Dialect ops in the region use
29-
these arguments as operands.
29+
these params-as-SMT-vars as operands, thereby expressing relevant
30+
constraints on their allowed values.
31+
32+
Computations w.r.t. passed-in params can also be expressed through the
33+
region's SMT-ops. Namely, the constraints express relationships to other
34+
SMT-variables which can then be yielded from the region (with `smt.yield`).
3035

3136
The semantics of this op is that all the ops in the region together express
3237
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
3338
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
34-
op succeeds.
39+
op succeeds and any one satisfying assignment is used to map the
40+
SMT-variables yielded in the region to `transform.param`s.
3541

3642
---
3743

@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
4248
}];
4349

4450
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
51+
let results = (outs Variadic<TransformParamTypeInterface>:$results);
4552
let regions = (region SizedRegion<1>:$body);
4653
let assemblyFormat =
47-
"`(` $params `)` attr-dict `:` type(operands) $body";
54+
"`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
4855

4956
let hasVerifier = 1;
5057
}

mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp

Lines changed: 80 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
1010
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
11-
#include "mlir/Dialect/Transform/IR/TransformOps.h"
12-
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
11+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
12+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
1313

1414
using namespace mlir;
1515

@@ -23,6 +23,7 @@ using namespace mlir;
2323
void transform::smt::ConstrainParamsOp::getEffects(
2424
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2525
onlyReadsHandle(getParamsMutable(), effects);
26+
producesHandle(getResults(), effects);
2627
}
2728

2829
DiagnosedSilenceableFailure
@@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
3738
// and allow for users to attach their own implementation, which would,
3839
// e.g., translate the ops to SMTLIB and hand that over to the user's
3940
// favourite solver. This requires changes to the dialect's verifier.
40-
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
41+
return emitSilenceableFailure(getLoc())
42+
<< "op does not have interpreted semantics yet";
4143
}
4244

4345
LogicalResult transform::smt::ConstrainParamsOp::verify() {
46+
auto yieldTerminator =
47+
dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
48+
if (!yieldTerminator)
49+
return emitOpError() << "expected '"
50+
<< mlir::smt::YieldOp::getOperationName()
51+
<< "' as terminator";
52+
53+
auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
54+
Type paramType, StringRef paramDesc,
55+
auto *atOp) -> InFlightDiagnostic {
56+
if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
57+
smtType))
58+
return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
59+
<< " is expected to be either a !smt.bool, a "
60+
"!smt.int, or a !smt.bv";
61+
62+
assert(isa<TransformParamTypeInterface>(paramType) &&
63+
"ODS specifies params' type should implement param interface");
64+
if (isa<transform::AnyParamType>(paramType))
65+
return {}; // No further checks can be done.
66+
67+
// NB: This cast must succeed as long as the only implementors of
68+
// TransformParamTypeInterface are AnyParamType and ParamType.
69+
Type typeWrappedByParam = cast<ParamType>(paramType).getType();
70+
71+
if (isa<mlir::smt::IntType>(smtType)) {
72+
if (!isa<IntegerType>(typeWrappedByParam))
73+
return atOp->emitOpError()
74+
<< "the type of " << smtDesc << " #" << idx
75+
<< " is !smt.int though the corresponding " << paramDesc
76+
<< " type (" << paramType << ") is not wrapping an integer type";
77+
} else if (isa<mlir::smt::BoolType>(smtType)) {
78+
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
79+
if (!wrappedIntType || wrappedIntType.getWidth() != 1)
80+
return atOp->emitOpError()
81+
<< "the type of " << smtDesc << " #" << idx
82+
<< " is !smt.bool though the corresponding " << paramDesc
83+
<< " type (" << paramType << ") is not wrapping i1";
84+
} else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
85+
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
86+
if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
87+
return atOp->emitOpError()
88+
<< "the type of " << smtDesc << " #" << idx << " is " << smtType
89+
<< " though the corresponding " << paramDesc << " type ("
90+
<< paramType
91+
<< ") is not wrapping an integer type of the same bitwidth";
92+
}
93+
94+
return {};
95+
};
96+
4497
if (getOperands().size() != getBody().getNumArguments())
4598
return emitOpError(
4699
"must have the same number of block arguments as operands");
47100

101+
for (auto [idx, operandType, blockArgType] :
102+
llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
103+
InFlightDiagnostic typeCheckResult =
104+
checkTypes(idx, blockArgType, "block arg", operandType, "operand",
105+
/*atOp=*/this);
106+
if (LogicalResult(typeCheckResult).failed())
107+
return typeCheckResult;
108+
}
109+
48110
for (auto &op : getBody().getOps()) {
49111
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
50112
return emitOpError(
51113
"ops contained in region should belong to SMT-dialect");
52114
}
53115

116+
if (yieldTerminator->getNumOperands() != getNumResults())
117+
return yieldTerminator.emitOpError()
118+
<< "expected terminator to have as many operands as the parent op "
119+
"has results";
120+
121+
for (auto [idx, termOperandType, resultType] : llvm::enumerate(
122+
yieldTerminator->getOperands().getType(), getResultTypes())) {
123+
InFlightDiagnostic typeCheckResult =
124+
checkTypes(idx, termOperandType, "terminator operand",
125+
cast<transform::ParamType>(resultType), "result",
126+
/*atOp=*/&yieldTerminator);
127+
if (LogicalResult(typeCheckResult).failed())
128+
return typeCheckResult;
129+
}
130+
54131
return success();
55132
}

mlir/python/mlir/dialects/transform/smt.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
class ConstrainParamsOp(ConstrainParamsOp):
2020
def __init__(
2121
self,
22+
results: Sequence[Type],
2223
params: Sequence[transform.AnyParamType],
2324
arg_types: Sequence[Type],
2425
loc=None,
@@ -27,6 +28,7 @@ def __init__(
2728
if len(params) != len(arg_types):
2829
raise ValueError(f"{params=} not same length as {arg_types=}")
2930
super().__init__(
31+
results,
3032
params,
3133
loc=loc,
3234
ip=ip,
@@ -36,3 +38,13 @@ def __init__(
3638
@property
3739
def body(self) -> Block:
3840
return self.regions[0].blocks[0]
41+
42+
43+
def constrain_params(
44+
results: Sequence[Type],
45+
params: Sequence[transform.AnyParamType],
46+
arg_types: Sequence[Type],
47+
loc=None,
48+
ip=None,
49+
):
50+
return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)

mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,40 @@
11
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics
22

3+
// CHECK-LABEL: @incorrect terminator
4+
module attributes {transform.with_named_sequence} {
5+
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
6+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
7+
// expected-error@below {{op expected 'smt.yield' as terminator}}
8+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
9+
^bb0(%param_as_smt_var: !smt.int):
10+
transform.yield
11+
}
12+
transform.yield
13+
}
14+
}
15+
16+
// -----
17+
18+
// CHECK-LABEL: @operands_not_one_to_one_with_vars
19+
module attributes {transform.with_named_sequence} {
20+
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
21+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
22+
// expected-error@below {{must have the same number of block arguments as operands}}
23+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
24+
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
25+
}
26+
transform.yield
27+
}
28+
}
29+
30+
// -----
31+
332
// CHECK-LABEL: @constraint_not_using_smt_ops
433
module attributes {transform.with_named_sequence} {
534
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
635
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
736
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
8-
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
37+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
938
^bb0(%param_as_smt_var: !smt.int):
1039
%c4 = arith.constant 4 : i32
1140
// This is the kind of thing one might think works:
@@ -17,13 +46,90 @@ module attributes {transform.with_named_sequence} {
1746

1847
// -----
1948

20-
// CHECK-LABEL: @operands_not_one_to_one_with_vars
49+
// CHECK-LABEL: @results_not_one_to_one_with_vars
2150
module attributes {transform.with_named_sequence} {
22-
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
51+
transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
2352
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
24-
// expected-error@below {{must have the same number of block arguments as operands}}
25-
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
53+
transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
2654
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
55+
// expected-error@below {{expected terminator to have as many operands as the parent op has results}}
56+
smt.yield %param_as_smt_var : !smt.int
57+
}
58+
transform.yield
59+
}
60+
}
61+
62+
// -----
63+
64+
// CHECK-LABEL: @non_smt_type_block_args
65+
module attributes {transform.with_named_sequence} {
66+
transform.named_sequence @non_smt_type_block_args(%arg0: !transform.any_op {transform.readonly}) {
67+
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
68+
// expected-error@below {{the type of block arg #0 is expected to be either a !smt.bool, a !smt.int, or a !smt.bv}}
69+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i8>) {
70+
^bb0(%param_as_smt_var: !transform.param<i8>):
71+
smt.yield %param_as_smt_var : !transform.param<i8>
72+
}
73+
transform.yield
74+
}
75+
}
76+
77+
78+
// -----
79+
80+
// CHECK-LABEL: @mismatched_arg_type_bool
81+
module attributes {transform.with_named_sequence} {
82+
transform.named_sequence @mismatched_arg_type_bool(%arg0: !transform.any_op {transform.readonly}) {
83+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
84+
// expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1}}
85+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
86+
^bb0(%param_as_smt_var: !smt.bool):
87+
smt.yield %param_as_smt_var : !smt.bool
88+
}
89+
transform.yield
90+
}
91+
}
92+
93+
// -----
94+
95+
// CHECK-LABEL: @mismatched_arg_type_bitvector
96+
module attributes {transform.with_named_sequence} {
97+
transform.named_sequence @mismatched_arg_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
98+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
99+
// expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
100+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
101+
^bb0(%param_as_smt_var: !smt.bv<8>):
102+
smt.yield %param_as_smt_var : !smt.bv<8>
103+
}
104+
transform.yield
105+
}
106+
}
107+
108+
// -----
109+
110+
// CHECK-LABEL: @mismatched_result_type_bool
111+
module attributes {transform.with_named_sequence} {
112+
transform.named_sequence @mismatched_result_type_bool(%arg0: !transform.any_op {transform.readonly}) {
113+
%param_as_param = transform.param.constant 1 -> !transform.param<i1>
114+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> (!transform.param<i64>) {
115+
^bb0(%param_as_smt_var: !smt.bool):
116+
// expected-error@below {{the type of terminator operand #0 is !smt.bool though the corresponding result type ('!transform.param<i64>') is not wrapping i1}}
117+
smt.yield %param_as_smt_var : !smt.bool
118+
}
119+
transform.yield
120+
}
121+
}
122+
123+
// -----
124+
125+
// CHECK-LABEL: @mismatched_result_type_bitvector
126+
module attributes {transform.with_named_sequence} {
127+
transform.named_sequence @mismatched_result_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
128+
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
129+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i64>) {
130+
^bb0(%param_as_smt_var: !smt.bv<8>):
131+
// expected-error@below {{the type of terminator operand #0 is '!smt.bv<8>' though the corresponding result type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
132+
smt.yield %param_as_smt_var : !smt.bv<8>
27133
}
28134
transform.yield
29135
}

0 commit comments

Comments
 (0)