Skip to content

Commit 9e8a2e2

Browse files
committed
[MLIR][Transform][SMT] Allow for declarative computations in schedules
By allowing `transform.smt.constrain_params`'s region to yield SMT vars, we op instances declare relationships, through constraints, on incoming params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it possible to declare that computations on params should be performed. The semantics are that the yielded SMT-vars should be from any valid satisfying assignment/model of the constraints in the region.
1 parent 34ed1dc commit 9e8a2e2

File tree

8 files changed

+205
-26
lines changed

8 files changed

+205
-26
lines changed

mlir/include/mlir/Dialect/SMT/IR/SMTOps.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
220220
Pure,
221221
Terminator,
222222
ReturnLike,
223-
ParentOneOf<["smt::SolverOp", "smt::CheckOp",
224-
"smt::ForallOp", "smt::ExistsOp"]>,
225223
]> {
226224
let summary = "terminator operation for various regions of SMT operations";
227225
let arguments = (ins Variadic<AnyType>:$values);

mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
13+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
1314
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1415
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1516
#include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,28 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1616
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
1717
DeclareOpInterfaceMethods<TransformOpInterface>,
1818
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
19-
NoTerminator
19+
SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
2020
]> {
2121
let cppNamespace = [{ mlir::transform::smt }];
2222

2323
let summary = "Express contraints on params interpreted as symbolic values";
2424
let description = [{
2525
Allows expressing constraints on params using the SMT dialect.
2626

27-
Each Transform dialect param provided as an operand has a corresponding
27+
Each Transform-dialect param provided as an operand has a corresponding
2828
argument of SMT-type in the region. The SMT-Dialect ops in the region use
29-
these arguments as operands.
29+
these params-as-SMT-vars as operands, thereby expressing relevant
30+
constraints on their allowed values.
31+
32+
Computations w.r.t. passed-in params can also be expressed through the
33+
region's SMT-ops. Namely, the constraints express relationships to other
34+
SMT-variables which can then be yielded from the region (with `smt.yield`).
3035

3136
The semantics of this op is that all the ops in the region together express
3237
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
3338
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
34-
op succeeds.
39+
op succeeds and any one satisfying assignment is used to map the
40+
SMT-variables yielded in the region to `transform.param`s.
3541

3642
---
3743

@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
4248
}];
4349

4450
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
51+
let results = (outs Variadic<TransformParamTypeInterface>:$results);
4552
let regions = (region SizedRegion<1>:$body);
4653
let assemblyFormat =
47-
"`(` $params `)` attr-dict `:` type(operands) $body";
54+
"`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
4855

4956
let hasVerifier = 1;
5057
}

mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
1010
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
11-
#include "mlir/Dialect/Transform/IR/TransformOps.h"
12-
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
11+
#include "mlir/Dialect/SMT/IR/SMTOps.h"
12+
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
1313

1414
using namespace mlir;
1515

@@ -23,6 +23,7 @@ using namespace mlir;
2323
void transform::smt::ConstrainParamsOp::getEffects(
2424
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2525
onlyReadsHandle(getParamsMutable(), effects);
26+
producesHandle(getResults(), effects);
2627
}
2728

2829
DiagnosedSilenceableFailure
@@ -37,19 +38,111 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
3738
// and allow for users to attach their own implementation, which would,
3839
// e.g., translate the ops to SMTLIB and hand that over to the user's
3940
// favourite solver. This requires changes to the dialect's verifier.
40-
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
41+
return emitSilenceableFailure(getLoc())
42+
<< "op does not have interpreted semantics yet";
4143
}
4244

4345
LogicalResult transform::smt::ConstrainParamsOp::verify() {
46+
auto yieldTerminator =
47+
llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
48+
if (!yieldTerminator)
49+
return emitOpError() << "expected '"
50+
<< mlir::smt::YieldOp::getOperationName()
51+
<< "' as terminator";
52+
4453
if (getOperands().size() != getBody().getNumArguments())
4554
return emitOpError(
4655
"must have the same number of block arguments as operands");
4756

57+
for (auto [i, operandType, blockArgType] :
58+
llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
59+
getOperandTypes(), getBody().getArgumentTypes())) {
60+
if (isa<transform::AnyParamType>(operandType))
61+
continue; // No type checking as operand is of !transform.any_param type.
62+
auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
63+
if (!paramOperandType)
64+
return emitOpError() << "operand type #" << i
65+
<< " is not a !transform.param";
66+
Type wrappedOperandType = paramOperandType.getType();
67+
68+
if (isa<mlir::smt::IntType>(blockArgType)) {
69+
if (!isa<IntegerType>(paramOperandType.getType()))
70+
return emitOpError()
71+
<< "the type of block arg #" << i
72+
<< " is !smt.int though the corresponding operand type ("
73+
<< operandType << ") is not wrapping an integer type";
74+
} else if (isa<mlir::smt::BoolType>(blockArgType)) {
75+
auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
76+
if (!intOperandType || intOperandType.getWidth() != 1)
77+
return emitOpError()
78+
<< "the type of block arg #" << i
79+
<< " is !smt.bool though the corresponding operand type ("
80+
<< operandType << ") is not wrapping i1 (i.e. bool)";
81+
} else if (auto bvBlockArgType =
82+
dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
83+
auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
84+
if (!intOperandType ||
85+
intOperandType.getWidth() != bvBlockArgType.getWidth())
86+
return emitOpError()
87+
<< "the type of block arg #" << i << " is " << blockArgType
88+
<< " though the corresponding operand type (" << operandType
89+
<< ") is not wrapping an integer type of the same bitwidth";
90+
}
91+
}
92+
4893
for (auto &op : getBody().getOps()) {
4994
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
5095
return emitOpError(
5196
"ops contained in region should belong to SMT-dialect");
5297
}
5398

99+
if (getOperands().size() != getBody().getNumArguments())
100+
return emitOpError(
101+
"must have the same number of block arguments as operands");
102+
103+
if (yieldTerminator->getNumOperands() != getNumResults())
104+
return yieldTerminator.emitOpError()
105+
<< "expected terminator to have as many operands as the parent op "
106+
"has results";
107+
108+
for (auto [i, termOperandType, resultType] : llvm::zip_equal(
109+
llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
110+
yieldTerminator->getOperands().getType(), getResultTypes())) {
111+
if (isa<transform::AnyParamType>(resultType))
112+
continue; // No type checking as result is of !transform.any_param type.
113+
auto paramResultType = dyn_cast<transform::ParamType>(resultType);
114+
if (!paramResultType)
115+
return emitOpError() << "result type #" << i
116+
<< " is not a !transform.param";
117+
Type wrappedResultType = paramResultType.getType();
118+
119+
if (isa<mlir::smt::IntType>(termOperandType)) {
120+
if (!isa<IntegerType>(wrappedResultType))
121+
return yieldTerminator.emitOpError()
122+
<< "the type of terminator operand #" << i
123+
<< " is !smt.int though the corresponding result type ("
124+
<< resultType
125+
<< ") of the parent op is not wrapping an integer type";
126+
} else if (isa<mlir::smt::BoolType>(termOperandType)) {
127+
auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
128+
if (!intResultType || intResultType.getWidth() != 1)
129+
return yieldTerminator.emitOpError()
130+
<< "the type of terminator operand #" << i
131+
<< " is !smt.bool though the corresponding result type ("
132+
<< resultType
133+
<< ") of the parent op is not wrapping i1 (i.e. bool)";
134+
} else if (auto bvOperandType =
135+
dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
136+
auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
137+
if (!intResultType ||
138+
intResultType.getWidth() != bvOperandType.getWidth())
139+
return yieldTerminator.emitOpError()
140+
<< "the type of terminator operand #" << i << " is "
141+
<< termOperandType << " though the corresponding result type ("
142+
<< resultType
143+
<< ") is not wrapping an integer type of the same bitwidth";
144+
}
145+
}
146+
54147
return success();
55148
}

mlir/python/mlir/dialects/transform/smt.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
class ConstrainParamsOp(ConstrainParamsOp):
2020
def __init__(
2121
self,
22+
results: Sequence[Type],
2223
params: Sequence[transform.AnyParamType],
2324
arg_types: Sequence[Type],
2425
loc=None,
@@ -27,6 +28,7 @@ def __init__(
2728
if len(params) != len(arg_types):
2829
raise ValueError(f"{params=} not same length as {arg_types=}")
2930
super().__init__(
31+
results,
3032
params,
3133
loc=loc,
3234
ip=ip,
@@ -36,3 +38,13 @@ def __init__(
3638
@property
3739
def body(self) -> Block:
3840
return self.regions[0].blocks[0]
41+
42+
43+
def constrain_params(
44+
results: Sequence[Type],
45+
params: Sequence[transform.AnyParamType],
46+
arg_types: Sequence[Type],
47+
loc=None,
48+
ip=None,
49+
):
50+
return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)

mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ module attributes {transform.with_named_sequence} {
55
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
66
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
77
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
8-
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
8+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
99
^bb0(%param_as_smt_var: !smt.int):
1010
%c4 = arith.constant 4 : i32
1111
// This is the kind of thing one might think works:
@@ -22,9 +22,54 @@ module attributes {transform.with_named_sequence} {
2222
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
2323
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
2424
// expected-error@below {{must have the same number of block arguments as operands}}
25-
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
25+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
2626
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
2727
}
2828
transform.yield
2929
}
3030
}
31+
32+
// -----
33+
34+
// CHECK-LABEL: @results_not_one_to_one_with_vars
35+
module attributes {transform.with_named_sequence} {
36+
transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
37+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
38+
transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
39+
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
40+
// expected-error@below {{expected terminator to have as many operands as the parent op has results}}
41+
smt.yield %param_as_smt_var : !smt.int
42+
}
43+
transform.yield
44+
}
45+
}
46+
47+
// -----
48+
49+
// CHECK-LABEL: @mismatched_type_bool
50+
module attributes {transform.with_named_sequence} {
51+
transform.named_sequence @mismatched_type_bool(%arg0: !transform.any_op {transform.readonly}) {
52+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
53+
// expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1 (i.e. bool)}}
54+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
55+
^bb0(%param_as_smt_var: !smt.bool):
56+
smt.yield %param_as_smt_var : !smt.bool
57+
}
58+
transform.yield
59+
}
60+
}
61+
62+
// -----
63+
64+
// CHECK-LABEL: @mismatched_type_bitvector
65+
module attributes {transform.with_named_sequence} {
66+
transform.named_sequence @mismatched_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
67+
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
68+
// expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
69+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
70+
^bb0(%param_as_smt_var: !smt.bv<8>):
71+
smt.yield %param_as_smt_var : !smt.bv<8>
72+
}
73+
transform.yield
74+
}
75+
}

mlir/test/Dialect/Transform/test-smt-extension.mlir

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} {
77
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
88

99
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
10-
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
10+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
1111
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
1212
^bb0(%param_as_smt_var: !smt.int):
1313
// CHECK: %[[C0:.*]] = smt.int.constant 0
@@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} {
3131

3232
// -----
3333

34-
// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
34+
// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value
3535
module attributes {transform.with_named_sequence} {
36-
transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
36+
transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) {
3737
// CHECK: %[[PARAM_A:.*]] = transform.param.constant
3838
%param_a = transform.param.constant 4 -> !transform.param<i64>
3939
// CHECK: %[[PARAM_B:.*]] = transform.param.constant
40-
%param_b = transform.param.constant 16 -> !transform.param<i64>
40+
%param_b = transform.param.constant 32 -> !transform.param<i64>
4141

4242
// CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
43-
transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
43+
%divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param<i64>, !transform.param<i64>) -> (!transform.param<i64>) {
4444
// CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
4545
^bb0(%var_a: !smt.int, %var_b: !smt.int):
46+
// CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]]
47+
%divisor = smt.int.div %var_b, %var_a
4648
// CHECK: %[[C0:.*]] = smt.int.constant 0
4749
%c0 = smt.int.constant 0
4850
// CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
@@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} {
5153
%eq = smt.eq %remainder, %c0 : !smt.int
5254
// CHECK: smt.assert %[[EQ]]
5355
smt.assert %eq
56+
// CHECK: smt.yield %[[DIV]]
57+
smt.yield %divisor : !smt.int
5458
}
55-
// NB: from here can rely on that %param_a is a divisor of %param_b
59+
// NB: from here can rely on that %param_a is a divisor of %param_b and
60+
// that the relevant factor, 8, got associated to %divisor.
5661
transform.yield
5762
}
5863
}
@@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} {
6368
module attributes {transform.with_named_sequence} {
6469
transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
6570
// CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
66-
%param_as_param = transform.param.constant true -> !transform.any_param
71+
%param_as_param = transform.param.constant true -> !transform.param<i1>
6772

6873
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
69-
transform.smt.constrain_params(%param_as_param) : !transform.any_param {
74+
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> () {
7075
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
7176
^bb0(%param_as_smt_var: !smt.bool):
7277
// CHECK: %[[C0:.*]] = smt.int.constant 0

mlir/test/python/dialects/transform_smt_ext.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,26 +25,44 @@ def run(f):
2525
# CHECK-LABEL: TEST: testConstrainParamsOp
2626
@run
2727
def testConstrainParamsOp(target):
28-
dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
28+
c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
2929
# CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
30-
symbolic_value = transform.ParamConstantOp(
31-
transform.AnyParamType.get(), dummy_value
30+
symbolic_value_as_param = transform.ParamConstantOp(
31+
transform.AnyParamType.get(), c42_attr
3232
)
3333
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
3434
constrain_params = transform_smt.ConstrainParamsOp(
35-
[symbolic_value], [smt.IntType.get()]
35+
[], [symbolic_value_as_param], [smt.IntType.get()]
3636
)
3737
# CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
3838
with ir.InsertionPoint(constrain_params.body):
39+
symbolic_value_as_smt_var = constrain_params.body.arguments[0]
3940
# CHECK: %[[C0:.*]] = smt.int.constant 0
4041
c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
4142
# CHECK: %[[C43:.*]] = smt.int.constant 43
4243
c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
4344
# CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
44-
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
45+
lb = smt.IntCmpOp(smt.IntPredicate.le, c0, symbolic_value_as_smt_var)
4546
# CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]]
46-
ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43)
47+
ub = smt.IntCmpOp(smt.IntPredicate.le, symbolic_value_as_smt_var, c43)
4748
# CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]]
4849
bounded = smt.AndOp([lb, ub])
4950
# CHECK: smt.assert %[[BOUNDED:.*]]
5051
smt.AssertOp(bounded)
52+
smt.YieldOp([])
53+
54+
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
55+
compute_with_params = transform_smt.ConstrainParamsOp(
56+
[transform.ParamType.get(ir.IntegerType.get_signless(32))],
57+
[symbolic_value_as_param],
58+
[smt.IntType.get()],
59+
)
60+
# CHECK-NEXT: ^bb{{.*}}(%[[SMT_SYMB:.*]]: !smt.int):
61+
with ir.InsertionPoint(compute_with_params.body):
62+
symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
63+
# CHECK: %[[TWICE:.*]] = smt.int.add %[[SMT_SYMB]], %[[SMT_SYMB]]
64+
twice_symb = smt.IntAddOp(
65+
[symbolic_value_as_smt_var, symbolic_value_as_smt_var]
66+
)
67+
# CHECK: smt.yield %[[TWICE]]
68+
smt.YieldOp([twice_symb])

0 commit comments

Comments
 (0)