-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR][Transform][SMT] Allow for declarative computations in schedules #160895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Rolf Morel (rolfmorel) ChangesBy allowing The semantics are that the yielded SMT-vars should be from any valid satisfying assignment/model of the constraints in the region. Patch is 21.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160895.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index 3143ab7de1b14..99b22e5609c74 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
Pure,
Terminator,
ReturnLike,
- ParentOneOf<["smt::SolverOp", "smt::CheckOp",
- "smt::ForallOp", "smt::ExistsOp"]>,
]> {
let summary = "terminator operation for various regions of SMT operations";
let arguments = (ins Variadic<AnyType>:$values);
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
index fc69b039f24ff..f6353a995d747 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
index b987cb31e54bb..9d9783aa66ed9 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
- NoTerminator
+ SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
]> {
let cppNamespace = [{ mlir::transform::smt }];
@@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
let description = [{
Allows expressing constraints on params using the SMT dialect.
- Each Transform dialect param provided as an operand has a corresponding
+ Each Transform-dialect param provided as an operand has a corresponding
argument of SMT-type in the region. The SMT-Dialect ops in the region use
- these arguments as operands.
+ these params-as-SMT-vars as operands, thereby expressing relevant
+ constraints on their allowed values.
+
+ Computations w.r.t. passed-in params can also be expressed through the
+ region's SMT-ops. Namely, the constraints express relationships to other
+ SMT-variables which can then be yielded from the region (with `smt.yield`).
The semantics of this op is that all the ops in the region together express
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
- op succeeds.
+ op succeeds and any one satisfying assignment is used to map the
+ SMT-variables yielded in the region to `transform.param`s.
---
@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
}];
let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+ let results = (outs Variadic<TransformParamTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
- "`(` $params `)` attr-dict `:` type(operands) $body";
+ "`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
let hasVerifier = 1;
}
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index 8e7af05353de7..d85268da2ad5d 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -8,8 +8,8 @@
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
using namespace mlir;
@@ -23,6 +23,7 @@ using namespace mlir;
void transform::smt::ConstrainParamsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamsMutable(), effects);
+ producesHandle(getResults(), effects);
}
DiagnosedSilenceableFailure
@@ -37,19 +38,111 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
// and allow for users to attach their own implementation, which would,
// e.g., translate the ops to SMTLIB and hand that over to the user's
// favourite solver. This requires changes to the dialect's verifier.
- return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+ return emitSilenceableFailure(getLoc())
+ << "op does not have interpreted semantics yet";
}
LogicalResult transform::smt::ConstrainParamsOp::verify() {
+ auto yieldTerminator =
+ llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
+ if (!yieldTerminator)
+ return emitOpError() << "expected '"
+ << mlir::smt::YieldOp::getOperationName()
+ << "' as terminator";
+
if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");
+ for (auto [i, operandType, blockArgType] :
+ llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
+ getOperandTypes(), getBody().getArgumentTypes())) {
+ if (isa<transform::AnyParamType>(operandType))
+ continue; // No type checking as operand is of !transform.any_param type.
+ auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
+ if (!paramOperandType)
+ return emitOpError() << "operand type #" << i
+ << " is not a !transform.param";
+ Type wrappedOperandType = paramOperandType.getType();
+
+ if (isa<mlir::smt::IntType>(blockArgType)) {
+ if (!isa<IntegerType>(paramOperandType.getType()))
+ return emitOpError()
+ << "the type of block arg #" << i
+ << " is !smt.int though the corresponding operand type ("
+ << operandType << ") is not wrapping an integer type";
+ } else if (isa<mlir::smt::BoolType>(blockArgType)) {
+ auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+ if (!intOperandType || intOperandType.getWidth() != 1)
+ return emitOpError()
+ << "the type of block arg #" << i
+ << " is !smt.bool though the corresponding operand type ("
+ << operandType << ") is not wrapping i1 (i.e. bool)";
+ } else if (auto bvBlockArgType =
+ dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
+ auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+ if (!intOperandType ||
+ intOperandType.getWidth() != bvBlockArgType.getWidth())
+ return emitOpError()
+ << "the type of block arg #" << i << " is " << blockArgType
+ << " though the corresponding operand type (" << operandType
+ << ") is not wrapping an integer type of the same bitwidth";
+ }
+ }
+
for (auto &op : getBody().getOps()) {
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
return emitOpError(
"ops contained in region should belong to SMT-dialect");
}
+ if (getOperands().size() != getBody().getNumArguments())
+ return emitOpError(
+ "must have the same number of block arguments as operands");
+
+ if (yieldTerminator->getNumOperands() != getNumResults())
+ return yieldTerminator.emitOpError()
+ << "expected terminator to have as many operands as the parent op "
+ "has results";
+
+ for (auto [i, termOperandType, resultType] : llvm::zip_equal(
+ llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+ yieldTerminator->getOperands().getType(), getResultTypes())) {
+ if (isa<transform::AnyParamType>(resultType))
+ continue; // No type checking as result is of !transform.any_param type.
+ auto paramResultType = dyn_cast<transform::ParamType>(resultType);
+ if (!paramResultType)
+ return emitOpError() << "result type #" << i
+ << " is not a !transform.param";
+ Type wrappedResultType = paramResultType.getType();
+
+ if (isa<mlir::smt::IntType>(termOperandType)) {
+ if (!isa<IntegerType>(wrappedResultType))
+ return yieldTerminator.emitOpError()
+ << "the type of terminator operand #" << i
+ << " is !smt.int though the corresponding result type ("
+ << resultType
+ << ") of the parent op is not wrapping an integer type";
+ } else if (isa<mlir::smt::BoolType>(termOperandType)) {
+ auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+ if (!intResultType || intResultType.getWidth() != 1)
+ return yieldTerminator.emitOpError()
+ << "the type of terminator operand #" << i
+ << " is !smt.bool though the corresponding result type ("
+ << resultType
+ << ") of the parent op is not wrapping i1 (i.e. bool)";
+ } else if (auto bvOperandType =
+ dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
+ auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+ if (!intResultType ||
+ intResultType.getWidth() != bvOperandType.getWidth())
+ return yieldTerminator.emitOpError()
+ << "the type of terminator operand #" << i << " is "
+ << termOperandType << " though the corresponding result type ("
+ << resultType
+ << ") is not wrapping an integer type of the same bitwidth";
+ }
+ }
+
return success();
}
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
index 1f0b7f066118c..af88fffcd3bba 100644
--- a/mlir/python/mlir/dialects/transform/smt.py
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -19,6 +19,7 @@
class ConstrainParamsOp(ConstrainParamsOp):
def __init__(
self,
+ results: Sequence[Type],
params: Sequence[transform.AnyParamType],
arg_types: Sequence[Type],
loc=None,
@@ -27,6 +28,7 @@ def __init__(
if len(params) != len(arg_types):
raise ValueError(f"{params=} not same length as {arg_types=}")
super().__init__(
+ results,
params,
loc=loc,
ip=ip,
@@ -36,3 +38,13 @@ def __init__(
@property
def body(self) -> Block:
return self.regions[0].blocks[0]
+
+
+def constrain_params(
+ results: Sequence[Type],
+ params: Sequence[transform.AnyParamType],
+ arg_types: Sequence[Type],
+ loc=None,
+ ip=None,
+):
+ return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
index 314b8d493c5d4..4e365fa2dbaf9 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -5,7 +5,7 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
- transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int):
%c4 = arith.constant 4 : i32
// This is the kind of thing one might think works:
@@ -22,9 +22,54 @@ module attributes {transform.with_named_sequence} {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{must have the same number of block arguments as operands}}
- transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
}
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: @results_not_one_to_one_with_vars
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
+ ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
+ // expected-error@below {{expected terminator to have as many operands as the parent op has results}}
+ smt.yield %param_as_smt_var : !smt.int
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bool
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @mismatched_type_bool(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ // expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1 (i.e. bool)}}
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+ ^bb0(%param_as_smt_var: !smt.bool):
+ smt.yield %param_as_smt_var : !smt.bool
+ }
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bitvector
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @mismatched_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
+ %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+ // expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+ ^bb0(%param_as_smt_var: !smt.bv<8>):
+ smt.yield %param_as_smt_var : !smt.bv<8>
+ }
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir
index 29d15175ae4ec..6cc41dd52473e 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir
@@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
- transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
^bb0(%param_as_smt_var: !smt.int):
// CHECK: %[[C0:.*]] = smt.int.constant 0
@@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
+// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value
module attributes {transform.with_named_sequence} {
- transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
+ transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) {
// CHECK: %[[PARAM_A:.*]] = transform.param.constant
%param_a = transform.param.constant 4 -> !transform.param<i64>
// CHECK: %[[PARAM_B:.*]] = transform.param.constant
- %param_b = transform.param.constant 16 -> !transform.param<i64>
+ %param_b = transform.param.constant 32 -> !transform.param<i64>
// CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
- transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
+ %divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param<i64>, !transform.param<i64>) -> (!transform.param<i64>) {
// CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
^bb0(%var_a: !smt.int, %var_b: !smt.int):
+ // CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]]
+ %divisor = smt.int.div %var_b, %var_a
// CHECK: %[[C0:.*]] = smt.int.constant 0
%c0 = smt.int.constant 0
// CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
@@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} {
%eq = smt.eq %remainder, %c0 : !smt.int
// CHECK: smt.assert %[[EQ]]
smt.assert %eq
+ // CHECK: smt.yield %[[DIV]]
+ smt.yield %divisor : !smt.int
}
- // NB: from here can rely on that %param_a is a divisor of %param_b
+ // NB: from here can rely on that %param_a is a divisor of %param_b and
+ // that the relevant factor, 8, got associated to %divisor.
transform.yield
}
}
@@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} {
module attributes {transform.with_named_sequence} {
transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
// CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
- %param_as_param = transform.param.constant true -> !transform.any_param
+ %param_as_param = transform.param.constant true -> !transform.param<i1>
// CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
- transform.smt.constrain_params(%param_as_param) : !transform.any_param {
+ transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> () {
// CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
^bb0(%param_as_smt_var: !smt.bool):
// CHECK: %[[C0:.*]] = smt.int.constant 0
diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py
index 3692fd92344a6..e28c56f277439 100644
--- a/mlir/test/python/dialects/transform_smt_ext.py
+++ b/mlir/test/python/dialects/transform_smt_ext.py
@@ -25,26 +25,44 @@ def run(f):
# CHECK-LABEL: TEST: testConstrainParamsOp
@run
def testConstrainParamsOp(target):
- dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+ c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
# CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
- symbolic_value = transform.ParamConstantOp(
- transform.AnyParamType.get(), dummy_value
+ symbolic_value_as_param = transform.ParamConstantOp(
+ transform.AnyParamType.get(), c42_attr
)
# CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
constrain_params = transform_smt.ConstrainParamsOp(
- [symbolic_value], [smt.IntType.get()]
+ [], [symbolic_value_as_param], [smt.IntType.get()]
)
# CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
with ir.InsertionPoint(constrain_params.body):
+ symbolic_value_as_smt_var = constrain_params.body.arguments[0]
# CHECK: %[[C0:.*]] = smt.int.constant 0
c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
# CHECK: %[[C43:.*]] = smt.int.constant 43
c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
# CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
- lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
+ lb = smt.IntCmpOp(smt.IntPredicate....
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM in general. Please make sure we have tests for every diagnostic, which may also catch cases that check an already-checked property and should be turned into assertions.
if (!yieldTerminator) | ||
return emitOpError() << "expected '" | ||
<< mlir::smt::YieldOp::getOperationName() | ||
<< "' as terminator"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should never happen as the ODS-generated verifier should be verifying this. Try if you can trigger this specific error message and, if not, remove this and turn the dyn_cast
above into a direct cast
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing the dyn_cast
to auto yieldTerminator = cast<mlir::smt::YieldOp>(getRegion().front().back());
(and removing the check) does make it possible for me to crash on the cast. Either by having the wrong terminator, e.g. transform.yield
or, using the Python API, I can construct the op without its region having a terminator as the last op. As an example:
compute_with_params = transform_smt.ConstrainParamsOp(
[transform.ParamType.get(ir.IntegerType.get_signless(32))],
[symbolic_value_as_param],
[smt.IntType.get()],
)
with ir.InsertionPoint(compute_with_params.body):
symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
twice_symb = smt.IntAddOp(
[symbolic_value_as_smt_var, symbolic_value_as_smt_var]
)
this then yields the following at runtime:
python: PATH_TO_REPO/llvm/include/llvm/Support/Casting.h:572: decltype(auto) llvm::cast(From &) [To = mlir::smt::YieldOp, From = mlir::Operation]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At the moment the only relevant Trait/Interface on the op is SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
. I haven't yet been able to traceback to how this triggers/is supposed to trigger the right verifier.
Should I be using a different trait?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Thought to note: I can't get the code to crash on getRegion().front().back()
when I supply an op with 1) a region with an empty block, or 2) a region with no blocks. There ODS-verifiers properly catch the issue: error: 'transform.smt.constrain_params' op expects a non-empty block
and error: 'transform.smt.constrain_params' op region #0 ('body') failed to verify constraint: region with 1 blocks
.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the .cpp.inc
, the op's verifyInvariants()
checks types on operands and results and the only thing it does for the region is:
static ::llvm::LogicalResult __mlir_ods_local_region_constraint_SMTExtensionOps1(
::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName,
unsigned regionIndex) {
if (!((::llvm::hasNItems(region, 1)))) {
return op->emitOpError("region #") << regionIndex
<< (regionName.empty() ? " " : " ('" + regionName + "') ")
<< "failed to verify constraint: region with 1 blocks";
}
return ::mlir::success();
}
As far as I can tell, the line ensureTerminator(*bodyRegion, parser.getBuilder(), result.location);
in the op's parser is due to SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
though there's no verification that the op instance - when it's not constructed by the parser - has this terminator.
Thanks for the review, @ftynse! I addressed your comments. Note that my current understanding is that the terminator does need to be checked in
Will land in a day or so, unless I hear otherwise. |
By allowing `transform.smt.constrain_params`'s region to yield SMT vars, we op instances declare relationships, through constraints, on incoming params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it possible to declare that computations on params should be performed. The semantics are that the yielded SMT-vars should be from any valid satisfying assignment/model of the constraints in the region.
778091d
to
b060217
Compare
By allowing
transform.smt.constrain_params
's region to yield SMT-vars, op instances can declare relationships, through constraints, on incoming params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it possible to declare that computations on params should be performed.The semantics are that the yielded SMT-vars should be from any valid satisfying assignment/model of the constraints in the region.