rolfmorel
Contributor

By allowing transform.smt.constrain_params's region to yield SMT-vars, op instances can declare relationships, through constraints, between incoming params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it possible to express computations on params, i.e., to derive new params from the incoming ones.

The semantics are that the values of the yielded SMT-vars are taken from any valid satisfying assignment (i.e., model) of the constraints in the region.
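To make these semantics concrete, here is a toy model in plain Python. It is not the MLIR API: `constrain_params_model` and its parameters are hypothetical, and a brute-force search over a small integer domain stands in for a real SMT solver. Incoming params are fixed, free SMT variables are searched, and the yielded variables' values are read off the first satisfying assignment found, mirroring the divisor example from this PR's tests (a = 4, b = 32).

```python
from itertools import product
from typing import Callable, Dict, List, Optional, Sequence

# Toy stand-in for an SMT model: variable name -> integer value.
Model = Dict[str, int]


def constrain_params_model(
    params: Dict[str, int],
    free_vars: Sequence[str],
    assertions: Sequence[Callable[[Model], bool]],
    yielded: Sequence[str],
    domain: range = range(-64, 65),
) -> Optional[List[int]]:
    """Hypothetical model of transform.smt.constrain_params semantics.

    Incoming params are fixed; free SMT variables are searched over a small
    finite domain (a real implementation would ask an SMT solver). Returns the
    values of the yielded variables in the first satisfying assignment found,
    or None when the constraints are unsatisfiable (the op would fail).
    """
    for values in product(domain, repeat=len(free_vars)):
        model: Model = dict(params)
        model.update(zip(free_vars, values))
        if all(check(model) for check in assertions):
            return [model[name] for name in yielded]
    return None


# Mirrors the test: assert (b mod a) == 0 and yield q = b div a.
# For the positive values used here, Python's // and % agree with SMT-LIB div/mod.
divides = [
    lambda m: m["b"] % m["a"] == 0,
    lambda m: m["q"] == m["b"] // m["a"],
]
print(constrain_params_model({"a": 4, "b": 32}, ["q"], divides, ["q"]))  # [8]
print(constrain_params_model({"a": 5, "b": 32}, ["q"], divides, ["q"]))  # None
```

Because the model simply returns whichever assignment it reaches first, a yielded variable that the constraints do not pin down (say, any `x` with `0 <= x < a`) may come back as any allowed value; that freedom is what "any satisfying assignment" permits.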

@llvmbot
Member

llvmbot commented Sep 26, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Patch is 21.18 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160895.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SMT/IR/SMTOps.td (-2)
  • (modified) mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h (+1)
  • (modified) mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td (+12-5)
  • (modified) mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp (+96-3)
  • (modified) mlir/python/mlir/dialects/transform/smt.py (+12)
  • (modified) mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir (+47-2)
  • (modified) mlir/test/Dialect/Transform/test-smt-extension.mlir (+13-8)
  • (modified) mlir/test/python/dialects/transform_smt_ext.py (+24-6)
diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
index 3143ab7de1b14..99b22e5609c74 100644
--- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
+++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td
@@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [
   Pure,
   Terminator,
   ReturnLike,
-  ParentOneOf<["smt::SolverOp", "smt::CheckOp",
-               "smt::ForallOp", "smt::ExistsOp"]>,
 ]> {
   let summary = "terminator operation for various regions of SMT operations";
   let arguments = (ins Variadic<AnyType>:$values);
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
index fc69b039f24ff..f6353a995d747 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
 #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
index b987cb31e54bb..9d9783aa66ed9 100644
--- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
+++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td
@@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
   DeclareOpInterfaceMethods<TransformOpInterface>,
   DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
-  NoTerminator
+  SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">
 ]> {
   let cppNamespace = [{ mlir::transform::smt }];
 
@@ -24,14 +24,20 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
   let description = [{
     Allows expressing constraints on params using the SMT dialect.
 
-    Each Transform dialect param provided as an operand has a corresponding
+    Each Transform-dialect param provided as an operand has a corresponding
     argument of SMT-type in the region. The SMT-Dialect ops in the region use
-    these arguments as operands.
+    these params-as-SMT-vars as operands, thereby expressing relevant
+    constraints on their allowed values.
+
+    Computations w.r.t. passed-in params can also be expressed through the
+    region's SMT-ops. Namely, the constraints express relationships to other
+    SMT-variables which can then be yielded from the region (with `smt.yield`).
 
     The semantics of this op is that all the ops in the region together express
     a constraint on the params-interpreted-as-smt-vars. The op fails in case the
     expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
-    op succeeds.
+    op succeeds and any one satisfying assignment is used to map the
+    SMT-variables yielded in the region to `transform.param`s.
 
     ---
 
@@ -42,9 +48,10 @@ def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
   }];
 
   let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
+  let results = (outs Variadic<TransformParamTypeInterface>:$results);
   let regions = (region SizedRegion<1>:$body);
   let assemblyFormat =
-      "`(` $params `)` attr-dict `:` type(operands) $body";
+      "`(` $params `)` attr-dict `:` functional-type(operands, results) $body";
 
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
index 8e7af05353de7..d85268da2ad5d 100644
--- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
@@ -8,8 +8,8 @@
 
 #include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
 #include "mlir/Dialect/SMT/IR/SMTDialect.h"
-#include "mlir/Dialect/Transform/IR/TransformOps.h"
-#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
+#include "mlir/Dialect/SMT/IR/SMTOps.h"
+#include "mlir/Dialect/Transform/IR/TransformTypes.h"
 
 using namespace mlir;
 
@@ -23,6 +23,7 @@ using namespace mlir;
 void transform::smt::ConstrainParamsOp::getEffects(
     SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
   onlyReadsHandle(getParamsMutable(), effects);
+  producesHandle(getResults(), effects);
 }
 
 DiagnosedSilenceableFailure
@@ -37,19 +38,111 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
   //       and allow for users to attach their own implementation, which would,
   //       e.g., translate the ops to SMTLIB and hand that over to the user's
   //       favourite solver. This requires changes to the dialect's verifier.
-  return emitDefiniteFailure() << "op does not have interpreted semantics yet";
+  return emitSilenceableFailure(getLoc())
+         << "op does not have interpreted semantics yet";
 }
 
 LogicalResult transform::smt::ConstrainParamsOp::verify() {
+  auto yieldTerminator =
+      llvm::dyn_cast_if_present<mlir::smt::YieldOp>(getRegion().front().back());
+  if (!yieldTerminator)
+    return emitOpError() << "expected '"
+                         << mlir::smt::YieldOp::getOperationName()
+                         << "' as terminator";
+
   if (getOperands().size() != getBody().getNumArguments())
     return emitOpError(
         "must have the same number of block arguments as operands");
 
+  for (auto [i, operandType, blockArgType] :
+       llvm::zip_equal(llvm::seq<unsigned>(0, getBody().getNumArguments()),
+                       getOperandTypes(), getBody().getArgumentTypes())) {
+    if (isa<transform::AnyParamType>(operandType))
+      continue; // No type checking as operand is of !transform.any_param type.
+    auto paramOperandType = dyn_cast<transform::ParamType>(operandType);
+    if (!paramOperandType)
+      return emitOpError() << "operand type #" << i
+                           << " is not a !transform.param";
+    Type wrappedOperandType = paramOperandType.getType();
+
+    if (isa<mlir::smt::IntType>(blockArgType)) {
+      if (!isa<IntegerType>(paramOperandType.getType()))
+        return emitOpError()
+               << "the type of block arg #" << i
+               << " is !smt.int though the corresponding operand type ("
+               << operandType << ") is not wrapping an integer type";
+    } else if (isa<mlir::smt::BoolType>(blockArgType)) {
+      auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+      if (!intOperandType || intOperandType.getWidth() != 1)
+        return emitOpError()
+               << "the type of block arg #" << i
+               << " is !smt.bool though the corresponding operand type ("
+               << operandType << ") is not wrapping i1 (i.e. bool)";
+    } else if (auto bvBlockArgType =
+                   dyn_cast<mlir::smt::BitVectorType>(blockArgType)) {
+      auto intOperandType = dyn_cast<IntegerType>(wrappedOperandType);
+      if (!intOperandType ||
+          intOperandType.getWidth() != bvBlockArgType.getWidth())
+        return emitOpError()
+               << "the type of block arg #" << i << " is " << blockArgType
+               << " though the corresponding operand type (" << operandType
+               << ") is not wrapping an integer type of the same bitwidth";
+    }
+  }
+
   for (auto &op : getBody().getOps()) {
     if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
       return emitOpError(
           "ops contained in region should belong to SMT-dialect");
   }
 
+  if (getOperands().size() != getBody().getNumArguments())
+    return emitOpError(
+        "must have the same number of block arguments as operands");
+
+  if (yieldTerminator->getNumOperands() != getNumResults())
+    return yieldTerminator.emitOpError()
+           << "expected terminator to have as many operands as the parent op "
+              "has results";
+
+  for (auto [i, termOperandType, resultType] : llvm::zip_equal(
+           llvm::seq<unsigned>(0, yieldTerminator->getNumOperands()),
+           yieldTerminator->getOperands().getType(), getResultTypes())) {
+    if (isa<transform::AnyParamType>(resultType))
+      continue; // No type checking as result is of !transform.any_param type.
+    auto paramResultType = dyn_cast<transform::ParamType>(resultType);
+    if (!paramResultType)
+      return emitOpError() << "result type #" << i
+                           << " is not a !transform.param";
+    Type wrappedResultType = paramResultType.getType();
+
+    if (isa<mlir::smt::IntType>(termOperandType)) {
+      if (!isa<IntegerType>(wrappedResultType))
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i
+               << " is !smt.int though the corresponding result type ("
+               << resultType
+               << ") of the parent op is not wrapping an integer type";
+    } else if (isa<mlir::smt::BoolType>(termOperandType)) {
+      auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+      if (!intResultType || intResultType.getWidth() != 1)
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i
+               << " is !smt.bool though the corresponding result type ("
+               << resultType
+               << ") of the parent op is not wrapping i1 (i.e. bool)";
+    } else if (auto bvOperandType =
+                   dyn_cast<mlir::smt::BitVectorType>(termOperandType)) {
+      auto intResultType = dyn_cast<IntegerType>(wrappedResultType);
+      if (!intResultType ||
+          intResultType.getWidth() != bvOperandType.getWidth())
+        return yieldTerminator.emitOpError()
+               << "the type of terminator operand #" << i << " is "
+               << termOperandType << " though the corresponding result type ("
+               << resultType
+               << ") is not wrapping an integer type of the same bitwidth";
+    }
+  }
+
   return success();
 }
diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py
index 1f0b7f066118c..af88fffcd3bba 100644
--- a/mlir/python/mlir/dialects/transform/smt.py
+++ b/mlir/python/mlir/dialects/transform/smt.py
@@ -19,6 +19,7 @@
 class ConstrainParamsOp(ConstrainParamsOp):
     def __init__(
         self,
+        results: Sequence[Type],
         params: Sequence[transform.AnyParamType],
         arg_types: Sequence[Type],
         loc=None,
@@ -27,6 +28,7 @@ def __init__(
         if len(params) != len(arg_types):
             raise ValueError(f"{params=} not same length as {arg_types=}")
         super().__init__(
+            results,
             params,
             loc=loc,
             ip=ip,
@@ -36,3 +38,13 @@ def __init__(
     @property
     def body(self) -> Block:
         return self.regions[0].blocks[0]
+
+
+def constrain_params(
+    results: Sequence[Type],
+    params: Sequence[transform.AnyParamType],
+    arg_types: Sequence[Type],
+    loc=None,
+    ip=None,
+):
+    return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip)
diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
index 314b8d493c5d4..4e365fa2dbaf9 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
@@ -5,7 +5,7 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
     %param_as_param = transform.param.constant 42 -> !transform.param<i64>
     // expected-error@below {{ops contained in region should belong to SMT-dialect}}
-    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
       ^bb0(%param_as_smt_var: !smt.int):
       %c4 = arith.constant 4 : i32
       // This is the kind of thing one might think works:
@@ -22,9 +22,54 @@ module attributes {transform.with_named_sequence} {
   transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
     %param_as_param = transform.param.constant 42 -> !transform.param<i64>
     // expected-error@below {{must have the same number of block arguments as operands}}
-    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
       ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
     }
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @results_not_one_to_one_with_vars
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param<i64>, !transform.param<i64>) -> () {
+      ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
+      // expected-error@below {{expected terminator to have as many operands as the parent op has results}}
+      smt.yield %param_as_smt_var : !smt.int
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bool
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @mismatched_type_bool(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param<i64>') is not wrapping i1 (i.e. bool)}}
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+      ^bb0(%param_as_smt_var: !smt.bool):
+      smt.yield %param_as_smt_var : !smt.bool
+    }
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @mismatched_type_bitvector
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @mismatched_type_bitvector(%arg0: !transform.any_op {transform.readonly}) {
+    %param_as_param = transform.param.constant 42 -> !transform.param<i64>
+    // expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param<i64>') is not wrapping an integer type of the same bitwidth}}
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> (!transform.param<i64>) {
+      ^bb0(%param_as_smt_var: !smt.bv<8>):
+      smt.yield %param_as_smt_var : !smt.bv<8>
+    }
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir
index 29d15175ae4ec..6cc41dd52473e 100644
--- a/mlir/test/Dialect/Transform/test-smt-extension.mlir
+++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir
@@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} {
     %param_as_param = transform.param.constant 42 -> !transform.param<i64>
 
     // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
-    transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i64>) -> () {
       // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
       ^bb0(%param_as_smt_var: !smt.int):
       // CHECK: %[[C0:.*]] = smt.int.constant 0
@@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-// CHECK-LABEL: @schedule_with_constraint_on_multiple_params
+// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value
 module attributes {transform.with_named_sequence} {
-  transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) {
+  transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) {
     // CHECK: %[[PARAM_A:.*]] = transform.param.constant
     %param_a = transform.param.constant 4 -> !transform.param<i64>
     // CHECK: %[[PARAM_B:.*]] = transform.param.constant
-    %param_b = transform.param.constant 16 -> !transform.param<i64>
+    %param_b = transform.param.constant 32 -> !transform.param<i64>
 
     // CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]])
-    transform.smt.constrain_params(%param_a, %param_b) : !transform.param<i64>, !transform.param<i64> {
+    %divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param<i64>, !transform.param<i64>) -> (!transform.param<i64>) {
       // CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int):
       ^bb0(%var_a: !smt.int, %var_b: !smt.int):
+      // CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]]
+      %divisor = smt.int.div %var_b, %var_a
       // CHECK: %[[C0:.*]] = smt.int.constant 0
       %c0 = smt.int.constant 0
       // CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]]
@@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} {
       %eq = smt.eq %remainder, %c0 : !smt.int
       // CHECK: smt.assert %[[EQ]]
       smt.assert %eq
+      // CHECK: smt.yield %[[DIV]]
+      smt.yield %divisor : !smt.int
     }
-    // NB: from here can rely on that %param_a is a divisor of %param_b
+    // NB: from here can rely on that %param_a is a divisor of %param_b and
+    //     that the relevant factor, 8, got associated to %divisor.
     transform.yield
   }
 }
@@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) {
     // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
-    %param_as_param = transform.param.constant true -> !transform.any_param
+    %param_as_param = transform.param.constant true -> !transform.param<i1>
 
     // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
-    transform.smt.constrain_params(%param_as_param) : !transform.any_param {
+    transform.smt.constrain_params(%param_as_param) : (!transform.param<i1>) -> () {
       // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool):
       ^bb0(%param_as_smt_var: !smt.bool):
       // CHECK: %[[C0:.*]] = smt.int.constant 0
diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py
index 3692fd92344a6..e28c56f277439 100644
--- a/mlir/test/python/dialects/transform_smt_ext.py
+++ b/mlir/test/python/dialects/transform_smt_ext.py
@@ -25,26 +25,44 @@ def run(f):
 # CHECK-LABEL: TEST: testConstrainParamsOp
 @run
 def testConstrainParamsOp(target):
-    dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
+    c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42)
     # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant
-    symbolic_value = transform.ParamConstantOp(
-        transform.AnyParamType.get(), dummy_value
+    symbolic_value_as_param = transform.ParamConstantOp(
+        transform.AnyParamType.get(), c42_attr
     )
     # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]])
     constrain_params = transform_smt.ConstrainParamsOp(
-        [symbolic_value], [smt.IntType.get()]
+        [], [symbolic_value_as_param], [smt.IntType.get()]
     )
     # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int):
     with ir.InsertionPoint(constrain_params.body):
+        symbolic_value_as_smt_var = constrain_params.body.arguments[0]
         # CHECK: %[[C0:.*]] = smt.int.constant 0
         c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0))
         # CHECK: %[[C43:.*]] = smt.int.constant 43
         c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43))
         # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]]
-        lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0])
+        lb = smt.IntCmpOp(smt.IntPredicate....
[truncated]

Member

@ftynse ftynse left a comment

LGTM in general. Please make sure we have tests for every diagnostic, which may also catch cases that check an already-checked property and should be turned into assertions.

Comment on lines +48 to +51

    if (!yieldTerminator)
      return emitOpError() << "expected '"
                           << mlir::smt::YieldOp::getOperationName()
                           << "' as terminator";
Member

This should never happen, as the ODS-generated verifier should already be verifying this. See whether you can trigger this specific error message and, if not, remove this check and turn the dyn_cast above into a direct cast.

Contributor Author

@rolfmorel rolfmorel Oct 9, 2025

Changing the dyn_cast to auto yieldTerminator = cast<mlir::smt::YieldOp>(getRegion().front().back()); (and removing the check) does make it possible for me to crash on the cast: either by using the wrong terminator, e.g. transform.yield, or, via the Python API, by constructing the op without a terminator as the last op of its region. For example:

    compute_with_params = transform_smt.ConstrainParamsOp(
        [transform.ParamType.get(ir.IntegerType.get_signless(32))],
        [symbolic_value_as_param],
        [smt.IntType.get()],
    )
    with ir.InsertionPoint(compute_with_params.body):
        symbolic_value_as_smt_var = compute_with_params.body.arguments[0]
        twice_symb = smt.IntAddOp(
            [symbolic_value_as_smt_var, symbolic_value_as_smt_var]
        )

This then fails at runtime with:

python: PATH_TO_REPO/llvm/include/llvm/Support/Casting.h:572: decltype(auto) llvm::cast(From &) [To = mlir::smt::YieldOp, From = mlir::Operation]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

Contributor Author

At the moment the only relevant trait/interface on the op is SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">. I haven't yet been able to trace back how this triggers (or is supposed to trigger) the right verifier.

Should I be using a different trait?

Contributor Author

@rolfmorel rolfmorel Oct 9, 2025

(Note: I can't get the code to crash on getRegion().front().back() when I supply an op with 1) a region with an empty block, or 2) a region with no blocks. In those cases the ODS-generated verifiers properly catch the issue: error: 'transform.smt.constrain_params' op expects a non-empty block and error: 'transform.smt.constrain_params' op region #0 ('body') failed to verify constraint: region with 1 blocks.)

Contributor Author

@rolfmorel rolfmorel Oct 9, 2025

In the .cpp.inc, the op's verifyInvariants() checks the types of operands and results, and the only thing it does for the region is:

static ::llvm::LogicalResult __mlir_ods_local_region_constraint_SMTExtensionOps1(
    ::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName,
    unsigned regionIndex) {
  if (!((::llvm::hasNItems(region, 1)))) {
    return op->emitOpError("region #") << regionIndex
        << (regionName.empty() ? " " : " ('" + regionName + "') ")
        << "failed to verify constraint: region with 1 blocks";
  }
  return ::mlir::success();
}

As far as I can tell, the line ensureTerminator(*bodyRegion, parser.getBuilder(), result.location); in the op's parser is generated because of SingleBlockImplicitTerminator<"::mlir::smt::YieldOp">, but nothing verifies that an op instance which is not constructed by the parser actually has this terminator.
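The difference between the two casts can be illustrated with a toy Python model. The `Op` class and the helper functions below are hypothetical, not MLIR or LLVM API: `cast_yield` aborts on a mismatched terminator, as llvm::cast does, whereas `dyn_cast_yield` returns nothing and lets the verifier turn the mismatch into a diagnostic. The model assumes, as observed in the note about empty regions and blocks, that ODS has already rejected those cases, so the block's last op exists.

```python
from dataclasses import dataclass, field
from typing import List, Optional

YIELD = "smt.yield"


@dataclass
class Op:
    """Hypothetical stand-in for an MLIR operation: just a name and operands."""
    name: str
    operands: List[str] = field(default_factory=list)


def cast_yield(op: Op) -> Op:
    # Like llvm::cast<YieldOp>: fails hard if op is not a yield.
    if op.name != YIELD:
        raise AssertionError("cast<Ty>() argument of incompatible type!")
    return op


def dyn_cast_yield(op: Op) -> Optional[Op]:
    # Like llvm::dyn_cast<YieldOp>: returns None instead of failing.
    return op if op.name == YIELD else None


def verify_terminator(block: List[Op]) -> Optional[str]:
    """Return a diagnostic, or None if the block ends in smt.yield.

    Assumes the region constraints already guaranteed a single non-empty
    block, so block[-1] exists.
    """
    if dyn_cast_yield(block[-1]) is None:
        return f"expected '{YIELD}' as terminator"
    return None


# A body built without a terminator, like the Python API example in this thread.
no_terminator = [Op("smt.int.add", ["%arg0", "%arg0"])]
print(verify_terminator(no_terminator))  # expected 'smt.yield' as terminator
print(verify_terminator(no_terminator + [Op(YIELD, ["%0"])]))  # None
```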

@rolfmorel
Contributor Author

rolfmorel commented Oct 9, 2025

Thanks for the review, @ftynse! I addressed your comments.

Note that my current understanding is that the terminator does need to be checked in verify(). Also note that the following error case is currently not triggerable, as the verifier of !transform.param<T> enforces that T is an integer type. When this is eventually relaxed, the following check will become triggerable and useful.

    if (isa<mlir::smt::IntType>(smtType)) {
      if (!isa<IntegerType>(typeWrappedByParam))
        return op->emitOpError()
               << "the type of " << smtDesc << " #" << idx
               << " is !smt.int though the corresponding " << paramDesc
               << " type (" << paramType << ") is not wrapping an integer type";
    }
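For reference, the param/SMT type-compatibility rules that verify() applies to block arguments and to smt.yield operands in this patch can be summarised as a small Python sketch. Types are modelled as plain strings, `check_param_vs_smt` is a hypothetical helper, and only signless `iN` integer types are recognised. The `!smt.int` mismatch is reachable here only because the model accepts non-integer wrapped types such as `f32`, which is the situation anticipated once `!transform.param<T>` is relaxed.

```python
import re
from typing import Optional

# Hypothetical string patterns; only signless iN integer types are modelled.
_PARAM = re.compile(r"!transform\.param<(.+)>")
_INT = re.compile(r"i(\d+)")
_BV = re.compile(r"!smt\.bv<(\d+)>")


def check_param_vs_smt(param_type: str, smt_type: str) -> Optional[str]:
    """Return a diagnostic if param_type cannot carry smt_type, else None."""
    if param_type == "!transform.any_param":
        return None  # no type checking for !transform.any_param
    param = _PARAM.fullmatch(param_type)
    if param is None:
        return "is not a !transform.param"
    int_match = _INT.fullmatch(param.group(1))
    width = int(int_match.group(1)) if int_match else None
    if smt_type == "!smt.int" and width is None:
        return "!smt.int requires an integer type"
    if smt_type == "!smt.bool" and width != 1:
        return "!smt.bool requires i1"
    bv = _BV.fullmatch(smt_type)
    if bv is not None and width != int(bv.group(1)):
        return f"{smt_type} requires an integer type of the same bitwidth"
    return None  # other SMT types are not checked


print(check_param_vs_smt("!transform.param<i64>", "!smt.bool"))  # !smt.bool requires i1
print(check_param_vs_smt("!transform.param<i8>", "!smt.bv<8>"))  # None
```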

Will land in a day or so, unless I hear otherwise.

By allowing `transform.smt.constrain_params`'s region to yield SMT vars,
op instances can declare relationships, through constraints, on incoming
params-as-SMT-vars and outgoing SMT-vars-as-params. This makes it
possible to express computations on params.

The semantics are that the yielded SMT-vars should be from any valid
satisfying assignment/model of the constraints in the region.
@rolfmorel rolfmorel force-pushed the smt.constrain_params branch from 778091d to b060217 October 18, 2025 23:16
@rolfmorel rolfmorel enabled auto-merge (squash) October 18, 2025 23:17
@rolfmorel rolfmorel merged commit 9351ad6 into llvm:main Oct 18, 2025
10 checks passed

Labels

mlir:python MLIR Python bindings mlir

3 participants