Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
Pure,
Terminator,
ReturnLike,
ParentOneOf<["smt::SolverOp", "smt::CheckOp",
"smt::ForallOp", "smt::ExistsOp"]>,
]> {
let summary = "terminator operation for various regions of SMT operations";
let arguments = (ins Variadic<AnyType>:$values);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,28 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
NoTerminator
SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
]> {
let cppNamespace = [{ mlir::transform::smt }];

let summary = "Express contraints on params interpreted as symbolic values";
let description = [{
Allows expressing constraints on params using the SMT dialect.

Each Transform dialect param provided as an operand has a corresponding
Each Transform-dialect param provided as an operand has a corresponding
argument of SMT-type in the region. The SMT-Dialect ops in the region use
these arguments as operands.
these params-as-SMT-vars as operands, thereby expressing relevant
constraints on their allowed values.

Computations w.r.t. passed-in params can also be expressed through the
region's SMT-ops. Namely, the constraints express relationships to other
SMT-variables which can then be yielded from the region (with `smt.yield`).

The semantics of this op is that all the ops in the region together express
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
op succeeds.
op succeeds and any one satisfying assignment is used to map the
SMT-variables yielded in the region to `transform.param`s.

---

Expand All @@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
}];

let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
let results = (outs Variadic<TransformParamTypeInterface>:$results);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"`(` $params `)` attr-dict `:` type(operands) $body";
"`(` $params `)` attr-dict `:` functional-type(operands, results) $body";

let hasVerifier = 1;
}
Expand Down
83 changes: 80 additions & 3 deletions mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"

using namespace mlir;

Expand All @@ -23,6 +23,7 @@ using namespace mlir;
void transform::smt::ConstrainParamsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamsMutable(), effects);
producesHandle(getResults(), effects);
}

DiagnosedSilenceableFailure
Expand All @@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
// and allow for users to attach their own implementation, which would,
// e.g., translate the ops to SMTLIB and hand that over to the user's
// favourite solver. This requires changes to the dialect's verifier.
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
return emitSilenceableFailure(getLoc())
<< "op does not have interpreted semantics yet";
}

LogicalResult transform::smt::ConstrainParamsOp::verify() {
auto yieldTerminator =
dyn_cast<mlir::smt::YieldOp>(getRegion().front().back());
if (!yieldTerminator)
return emitOpError() << "expected '"
<< mlir::smt::YieldOp::getOperationName()
<< "' as terminator";
Comment on lines +48 to +51
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should never happen as the ODS-generated verifier should be verifying this. Try if you can trigger this specific error message and, if not, remove this and turn the dyn_cast above into a direct cast.

Copy link
Contributor Author

@rolfmorel rolfmorel Oct 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changing the dyn_cast to auto yieldTerminator = cast<mlir::smt::YieldOp>(getRegion().front().back()); (and removing the check) does make it possible for me to crash on the cast. Either by having the wrong terminator, e.g. transform.yield or, using the Python API, I can construct the op without its region having a terminator as the last op. As an example:

    compute_with_params = transform_smt.ConstrainParamsOp(
        [transform.ParamType.get(ir.IntegerType.get_signless(32))],
        [symbolic_value_as_param],
        [smt.IntType.get()],
    )
    with ir.InsertionPoint(compute_with_params.body):
        symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
        twice_symb = smt.IntAddOp(
            [symbolic_value_as_smt_var, symbolic_value_as_smt_var]
        )

this then yields the following at runtime:

python: PATH_TO_REPO/llvm/include/llvm/Support/Casting.h:572: decltype(auto) llvm::cast(From &) [To = mlir::smt::YieldOp, From = mlir::Operation]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment the only relevant Trait/Interface on the op is SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">. I haven't yet been able to traceback to how this triggers/is supposed to trigger the right verifier.

Should I be using a different trait?

Copy link
Contributor Author

@rolfmorel rolfmorel Oct 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Thought to note: I can't get the code to crash on getRegion().front().back() when I supply an op with 1) a region with an empty block, or 2) a region with no blocks. There ODS-verifiers properly catch the issue: error: 'transform.smt.constrain_params' op expects a non-empty block and error: 'transform.smt.constrain_params' op region #0 ('body') failed to verify constraint: region with 1 blocks.)

Copy link
Contributor Author

@rolfmorel rolfmorel Oct 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the .cpp.inc, the op's verifyInvariants() checks types on operands and results and the only thing it does for the region is:

static ::llvm::LogicalResult __mlir_ods_local_region_constraint_SMTExtensionOps1(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!((::llvm::hasNItems(region, 1)))) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: region with 1 blocks";
  }
  return ::mlir::success();
}

As far as I can tell, the line ensureTerminator(*bodyRegion, parser.getBuilder(), result.location); in the op's parser is due to SingleBlockImplicitTerminator<"::mlir::smt::YieldOp"> though there's no verification that the op instance - when it's not constructed by the parser - has this terminator.


auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc,
Type paramType, StringRef paramDesc,
auto *atOp) -> InFlightDiagnostic {
if (!isa<mlir::smt::BoolType, mlir::smt::IntType, mlir::smt::BitVectorType>(
smtType))
return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx
<< " is expected to be either a !smt.bool, a "
"!smt.int, or a !smt.bv";

assert(isa<TransformParamTypeInterface>(paramType) &&
"ODS specifies params' type should implement param interface");
if (isa<transform::AnyParamType>(paramType))
return {}; // No further checks can be done.

// NB: This cast must succeed as long as the only implementors of
// TransformParamTypeInterface are AnyParamType and ParamType.
Type typeWrappedByParam = cast<ParamType>(paramType).getType();

if (isa<mlir::smt::IntType>(smtType)) {
if (!isa<IntegerType>(typeWrappedByParam))
return atOp->emitOpError()
<< "the type of " << smtDesc << " #" << idx
<< " is !smt.int though the corresponding " << paramDesc
<< " type (" << paramType << ") is not wrapping an integer type";
} else if (isa<mlir::smt::BoolType>(smtType)) {
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
if (!wrappedIntType || wrappedIntType.getWidth() != 1)
return atOp->emitOpError()
<< "the type of " << smtDesc << " #" << idx
<< " is !smt.bool though the corresponding " << paramDesc
<< " type (" << paramType << ") is not wrapping i1";
} else if (auto bvSmtType = dyn_cast<mlir::smt::BitVectorType>(smtType)) {
auto wrappedIntType = dyn_cast<IntegerType>(typeWrappedByParam);
if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth())
return atOp->emitOpError()
<< "the type of " << smtDesc << " #" << idx << " is " << smtType
<< " though the corresponding " << paramDesc << " type ("
<< paramType
<< ") is not wrapping an integer type of the same bitwidth";
}

return {};
};

if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");

for (auto [idx, operandType, blockArgType] :
llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) {
InFlightDiagnostic typeCheckResult =
checkTypes(idx, blockArgType, "block arg", operandType, "operand",
/*atOp=*/this);
if (LogicalResult(typeCheckResult).failed())
return typeCheckResult;
}

for (auto &op : getBody().getOps()) {
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
return emitOpError(
"ops contained in region should belong to SMT-dialect");
}

if (yieldTerminator->getNumOperands() != getNumResults())
return yieldTerminator.emitOpError()
<< "expected terminator to have as many operands as the parent op "
"has results";

for (auto [idx, termOperandType, resultType] : llvm::enumerate(
yieldTerminator->getOperands().getType(), getResultTypes())) {
InFlightDiagnostic typeCheckResult =
checkTypes(idx, termOperandType, "terminator operand",
cast<transform::ParamType>(resultType), "result",
/*atOp=*/&yieldTerminator);
if (LogicalResult(typeCheckResult).failed())
return typeCheckResult;
}

return success();
}
12 changes: 12 additions & 0 deletions mlir/python/mlir/dialects/transform/smt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class ConstrainParamsOp(ConstrainParamsOp):
def __init__(
self,
results: Sequence[Type],
params: Sequence[transform.AnyParamType],
arg_types: Sequence[Type],
loc=None,
Expand All @@ -27,6 +28,7 @@ def __init__(
if len(params) != len(arg_types):
raise ValueError(f"{params=} not same length as {arg_types=}")
super().__init__(
results,
params,
loc=loc,
ip=ip,
Expand All @@ -36,3 +38,13 @@ def __init__(
@property
def body(self) -> Block:
return self.regions[0].blocks[0]


def constrain_params(
results: Sequence[Type],
params: Sequence[transform.AnyParamType],
arg_types: Sequence[Type],
loc=None,
ip=None,
):
return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)
116 changes: 111 additions & 5 deletions mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
Original file line number Diff line number Diff line change
@@ -1,11 +1,40 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics

// CHECK-LABEL: @incorrect terminator
module attributes {transform.with_named_sequence} {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{op expected 'smt.yield' as terminator}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int):
transform.yield
}
transform.yield
}
}

// -----

// CHECK-LABEL: @operands_not_one_to_one_with_vars
module attributes {transform.with_named_sequence} {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{must have the same number of block arguments as operands}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
}
transform.yield
}
}

// -----

// CHECK-LABEL: @constraint_not_using_smt_ops
module attributes {transform.with_named_sequence} {
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int):
%c4 = arith.constant 4 : i32
// This is the kind of thing one might think works:
Expand All @@ -17,13 +46,90 @@ module attributes {transform.with_named_sequence} {

// -----

// CHECK-LABEL: @operands_not_one_to_one_with_vars
// CHECK-LABEL: @results_not_one_to_one_with_vars
module attributes {transform.with_named_sequence} {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{must have the same number of block arguments as operands}}
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
// expected-error@below {{expected terminator to have as many operands as the parent op has results}}
smt.yield %param_as_smt_var : !smt.int
}
transform.yield
}
}

// -----

// CHECK-LABEL: @non_smt_type_block_args
module attributes {transform.with_named_sequence} {
transform.named_sequence @non_smt_type_block_args(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
// expected-error@below {{the type of block arg #0 is expected to be either a !smt.bool, a !smt.int, or a !smt.bv}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i8>) {
^bb0(%param_as_smt_var: !transform.param<i8>):
smt.yield %param_as_smt_var : !transform.param<i8>
}
transform.yield
}
}


// -----

// CHECK-LABEL: @mismatched_arg_type_bool
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_arg_type_bool(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bool):
smt.yield %param_as_smt_var : !smt.bool
}
transform.yield
}
}

// -----

// CHECK-LABEL: @mismatched_arg_type_bitvector
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_arg_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bv<8>):
smt.yield %param_as_smt_var : !smt.bv<8>
}
transform.yield
}
}

// -----

// CHECK-LABEL: @mismatched_result_type_bool
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_result_type_bool(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 1 -> !transform.param<i1>
transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bool):
// expected-error@below {{the type of terminator operand #0 is !smt.bool though the corresponding result type ('!transform.param<i64>') is not wrapping i1}}
smt.yield %param_as_smt_var : !smt.bool
}
transform.yield
}
}

// -----

// CHECK-LABEL: @mismatched_result_type_bitvector
module attributes {transform.with_named_sequence} {
transform.named_sequence @mismatched_result_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i8>
transform.smt.constrain_params(%param_as_param) : (!transform.param<i8>) -> (!transform.param<i64>) {
^bb0(%param_as_smt_var: !smt.bv<8>):
// expected-error@below {{the type of terminator operand #0 is '!smt.bv<8>' though the corresponding result type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
smt.yield %param_as_smt_var : !smt.bv<8>
}
transform.yield
}
Expand Down
Loading