diff --git a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td index 3143ab7de1b14..99b22e5609c74 100644 --- a/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td +++ b/mlir/include/mlir/Dialect/SMT/IR/SMTOps.td @@ -220,8 +220,6 @@ def YieldOp : SMTOp<"yield", [ Pure, Terminator, ReturnLike, - ParentOneOf<["smt::SolverOp", "smt::CheckOp", - "smt::ForallOp", "smt::ExistsOp"]>, ]> { let summary = "terminator operation for various regions of SMT operations"; let arguments = (ins Variadic:$values); diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h index fc69b039f24ff..f6353a995d747 100644 --- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/SMT/IR/SMTOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" diff --git a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td index b987cb31e54bb..9d9783aa66ed9 100644 --- a/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td +++ b/mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td @@ -16,7 +16,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td" def ConstrainParamsOp : Op, DeclareOpInterfaceMethods, - NoTerminator + SingleBlockImplicitTerminator<"::mlir::smt::YieldOp"> ]> { let cppNamespace = [{ mlir::transform::smt }]; @@ -24,14 +24,20 @@ def ConstrainParamsOp : Op:$params); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$body); let assemblyFormat = - "`(` $params `)` attr-dict `:` type(operands) $body"; + "`(` $params `)` attr-dict `:` functional-type(operands, results) $body"; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp index 8e7af05353de7..abc131639fb3a 100644 --- a/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp @@ -8,8 +8,8 @@ #include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h" #include "mlir/Dialect/SMT/IR/SMTDialect.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h" +#include "mlir/Dialect/SMT/IR/SMTOps.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" using namespace mlir; @@ -23,6 +23,7 @@ using namespace mlir; void transform::smt::ConstrainParamsOp::getEffects( SmallVectorImpl &effects) { onlyReadsHandle(getParamsMutable(), effects); + producesHandle(getResults(), effects); } DiagnosedSilenceableFailure @@ -37,19 +38,95 @@ transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter, // and allow for users to attach their own implementation, which would, // e.g., translate the ops to SMTLIB and hand that over to the user's // favourite solver. This requires changes to the dialect's verifier. - return emitDefiniteFailure() << "op does not have interpreted semantics yet"; + return emitSilenceableFailure(getLoc()) + << "op does not have interpreted semantics yet"; } LogicalResult transform::smt::ConstrainParamsOp::verify() { + auto yieldTerminator = + dyn_cast(getRegion().front().back()); + if (!yieldTerminator) + return emitOpError() << "expected '" + << mlir::smt::YieldOp::getOperationName() + << "' as terminator"; + + auto checkTypes = [](size_t idx, Type smtType, StringRef smtDesc, + Type paramType, StringRef paramDesc, + auto *atOp) -> InFlightDiagnostic { + if (!isa( + smtType)) + return atOp->emitOpError() << "the type of " << smtDesc << " #" << idx + << " is expected to be either a !smt.bool, a " + "!smt.int, or a !smt.bv"; + + assert(isa(paramType) && + "ODS specifies params' type should implement param interface"); + if (isa(paramType)) + return {}; // No further checks can be done. + + // NB: This cast must succeed as long as the only implementors of + // TransformParamTypeInterface are AnyParamType and ParamType. + Type typeWrappedByParam = cast(paramType).getType(); + + if (isa(smtType)) { + if (!isa(typeWrappedByParam)) + return atOp->emitOpError() + << "the type of " << smtDesc << " #" << idx + << " is !smt.int though the corresponding " << paramDesc + << " type (" << paramType << ") is not wrapping an integer type"; + } else if (isa(smtType)) { + auto wrappedIntType = dyn_cast(typeWrappedByParam); + if (!wrappedIntType || wrappedIntType.getWidth() != 1) + return atOp->emitOpError() + << "the type of " << smtDesc << " #" << idx + << " is !smt.bool though the corresponding " << paramDesc + << " type (" << paramType << ") is not wrapping i1"; + } else if (auto bvSmtType = dyn_cast(smtType)) { + auto wrappedIntType = dyn_cast(typeWrappedByParam); + if (!wrappedIntType || wrappedIntType.getWidth() != bvSmtType.getWidth()) + return atOp->emitOpError() + << "the type of " << smtDesc << " #" << idx << " is " << smtType + << " though the corresponding " << paramDesc << " type (" + << paramType + << ") is not wrapping an integer type of the same bitwidth"; + } + + return {}; + }; + if (getOperands().size() != getBody().getNumArguments()) return emitOpError( "must have the same number of block arguments as operands"); + for (auto [idx, operandType, blockArgType] : + llvm::enumerate(getOperandTypes(), getBody().getArgumentTypes())) { + InFlightDiagnostic typeCheckResult = + checkTypes(idx, blockArgType, "block arg", operandType, "operand", + /*atOp=*/this); + if (LogicalResult(typeCheckResult).failed()) + return typeCheckResult; + } + for (auto &op : getBody().getOps()) { if (!isa(op.getDialect())) return emitOpError( "ops contained in region should belong to SMT-dialect"); } + if (yieldTerminator->getNumOperands() != getNumResults()) + return yieldTerminator.emitOpError() + << "expected terminator to have as many operands as the parent op " + "has results"; + + for (auto [idx, termOperandType, resultType] : llvm::enumerate( + yieldTerminator->getOperands().getType(), getResultTypes())) { + InFlightDiagnostic typeCheckResult = + checkTypes(idx, termOperandType, "terminator operand", + cast(resultType), "result", + /*atOp=*/&yieldTerminator); + if (LogicalResult(typeCheckResult).failed()) + return typeCheckResult; + } + return success(); } diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py index 1f0b7f066118c..af88fffcd3bba 100644 --- a/mlir/python/mlir/dialects/transform/smt.py +++ b/mlir/python/mlir/dialects/transform/smt.py @@ -19,6 +19,7 @@ class ConstrainParamsOp(ConstrainParamsOp): def __init__( self, + results: Sequence[Type], params: Sequence[transform.AnyParamType], arg_types: Sequence[Type], loc=None, @@ -27,6 +28,7 @@ def __init__( if len(params) != len(arg_types): raise ValueError(f"{params=} not same length as {arg_types=}") super().__init__( + results, params, loc=loc, ip=ip, @@ -36,3 +38,13 @@ def __init__( @property def body(self) -> Block: return self.regions[0].blocks[0] + + +def constrain_params( + results: Sequence[Type], + params: Sequence[transform.AnyParamType], + arg_types: Sequence[Type], + loc=None, + ip=None, +): + return ConstrainParamsOp(results, params, arg_types, loc=loc, ip=ip) diff --git a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir index 314b8d493c5d4..d91d69a756458 100644 --- a/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir +++ b/mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir @@ -1,11 +1,40 @@ // RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics +// CHECK-LABEL: @incorrect terminator +module attributes {transform.with_named_sequence} { + transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + // expected-error@below {{op expected 'smt.yield' as terminator}} + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> () { + ^bb0(%param_as_smt_var: !smt.int): + transform.yield + } + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @operands_not_one_to_one_with_vars +module attributes {transform.with_named_sequence} { + transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + // expected-error@below {{must have the same number of block arguments as operands}} + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> () { + ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int): + } + transform.yield + } +} + +// ----- + // CHECK-LABEL: @constraint_not_using_smt_ops module attributes {transform.with_named_sequence} { transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) { %param_as_param = transform.param.constant 42 -> !transform.param // expected-error@below {{ops contained in region should belong to SMT-dialect}} - transform.smt.constrain_params(%param_as_param) : !transform.param { + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> () { ^bb0(%param_as_smt_var: !smt.int): %c4 = arith.constant 4 : i32 // This is the kind of thing one might think works: @@ -17,13 +46,90 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: @operands_not_one_to_one_with_vars +// CHECK-LABEL: @results_not_one_to_one_with_vars module attributes {transform.with_named_sequence} { - transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) { + transform.named_sequence @results_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) { %param_as_param = transform.param.constant 42 -> !transform.param - // expected-error@below {{must have the same number of block arguments as operands}} - transform.smt.constrain_params(%param_as_param) : !transform.param { + transform.smt.constrain_params(%param_as_param, %param_as_param) : (!transform.param, !transform.param) -> () { ^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int): + // expected-error@below {{expected terminator to have as many operands as the parent op has results}} + smt.yield %param_as_smt_var : !smt.int + } + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @non_smt_type_block_args +module attributes {transform.with_named_sequence} { + transform.named_sequence @non_smt_type_block_args(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + // expected-error@below {{the type of block arg #0 is expected to be either a !smt.bool, a !smt.int, or a !smt.bv}} + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> (!transform.param) { + ^bb0(%param_as_smt_var: !transform.param): + smt.yield %param_as_smt_var : !transform.param + } + transform.yield + } +} + + +// ----- + +// CHECK-LABEL: @mismatched_arg_type_bool +module attributes {transform.with_named_sequence} { + transform.named_sequence @mismatched_arg_type_bool(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + // expected-error@below {{the type of block arg #0 is !smt.bool though the corresponding operand type ('!transform.param') is not wrapping i1}} + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> (!transform.param) { + ^bb0(%param_as_smt_var: !smt.bool): + smt.yield %param_as_smt_var : !smt.bool + } + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @mismatched_arg_type_bitvector +module attributes {transform.with_named_sequence} { + transform.named_sequence @mismatched_arg_type_bitvector(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + // expected-error@below {{the type of block arg #0 is '!smt.bv<8>' though the corresponding operand type ('!transform.param') is not wrapping an integer type of the same bitwidth}} + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> (!transform.param) { + ^bb0(%param_as_smt_var: !smt.bv<8>): + smt.yield %param_as_smt_var : !smt.bv<8> + } + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @mismatched_result_type_bool +module attributes {transform.with_named_sequence} { + transform.named_sequence @mismatched_result_type_bool(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 1 -> !transform.param + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> (!transform.param) { + ^bb0(%param_as_smt_var: !smt.bool): + // expected-error@below {{the type of terminator operand #0 is !smt.bool though the corresponding result type ('!transform.param') is not wrapping i1}} + smt.yield %param_as_smt_var : !smt.bool + } + transform.yield + } +} + +// ----- + +// CHECK-LABEL: @mismatched_result_type_bitvector +module attributes {transform.with_named_sequence} { + transform.named_sequence @mismatched_result_type_bitvector(%arg0: !transform.any_op {transform.readonly}) { + %param_as_param = transform.param.constant 42 -> !transform.param + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> (!transform.param) { + ^bb0(%param_as_smt_var: !smt.bv<8>): + // expected-error@below {{the type of terminator operand #0 is '!smt.bv<8>' though the corresponding result type ('!transform.param') is not wrapping an integer type of the same bitwidth}} + smt.yield %param_as_smt_var : !smt.bv<8> } transform.yield } diff --git a/mlir/test/Dialect/Transform/test-smt-extension.mlir b/mlir/test/Dialect/Transform/test-smt-extension.mlir index 29d15175ae4ec..6cc41dd52473e 100644 --- a/mlir/test/Dialect/Transform/test-smt-extension.mlir +++ b/mlir/test/Dialect/Transform/test-smt-extension.mlir @@ -7,7 +7,7 @@ module attributes {transform.with_named_sequence} { %param_as_param = transform.param.constant 42 -> !transform.param // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]]) - transform.smt.constrain_params(%param_as_param) : !transform.param { + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> () { // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int): ^bb0(%param_as_smt_var: !smt.int): // CHECK: %[[C0:.*]] = smt.int.constant 0 @@ -31,18 +31,20 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: @schedule_with_constraint_on_multiple_params +// CHECK-LABEL: @schedule_with_constraint_on_multiple_params_returning_computed_value module attributes {transform.with_named_sequence} { - transform.named_sequence @schedule_with_constraint_on_multiple_params(%arg0: !transform.any_op {transform.readonly}) { + transform.named_sequence @schedule_with_constraint_on_multiple_params_returning_computed_value(%arg0: !transform.any_op {transform.readonly}) { // CHECK: %[[PARAM_A:.*]] = transform.param.constant %param_a = transform.param.constant 4 -> !transform.param // CHECK: %[[PARAM_B:.*]] = transform.param.constant - %param_b = transform.param.constant 16 -> !transform.param + %param_b = transform.param.constant 32 -> !transform.param // CHECK: transform.smt.constrain_params(%[[PARAM_A]], %[[PARAM_B]]) - transform.smt.constrain_params(%param_a, %param_b) : !transform.param, !transform.param { + %divisor = transform.smt.constrain_params(%param_a, %param_b) : (!transform.param, !transform.param) -> (!transform.param) { // CHECK: ^bb{{.*}}(%[[VAR_A:.*]]: !smt.int, %[[VAR_B:.*]]: !smt.int): ^bb0(%var_a: !smt.int, %var_b: !smt.int): + // CHECK: %[[DIV:.*]] = smt.int.div %[[VAR_B]], %[[VAR_A]] + %divisor = smt.int.div %var_b, %var_a // CHECK: %[[C0:.*]] = smt.int.constant 0 %c0 = smt.int.constant 0 // CHECK: %[[REMAINDER:.*]] = smt.int.mod %[[VAR_B]], %[[VAR_A]] @@ -51,8 +53,11 @@ module attributes {transform.with_named_sequence} { %eq = smt.eq %remainder, %c0 : !smt.int // CHECK: smt.assert %[[EQ]] smt.assert %eq + // CHECK: smt.yield %[[DIV]] + smt.yield %divisor : !smt.int } - // NB: from here can rely on that %param_a is a divisor of %param_b + // NB: from here can rely on that %param_a is a divisor of %param_b and + // that the relevant factor, 8, got associated to %divisor. transform.yield } } @@ -63,10 +68,10 @@ module attributes {transform.with_named_sequence} { module attributes {transform.with_named_sequence} { transform.named_sequence @schedule_with_param_as_a_bool(%arg0: !transform.any_op {transform.readonly}) { // CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant - %param_as_param = transform.param.constant true -> !transform.any_param + %param_as_param = transform.param.constant true -> !transform.param // CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]]) - transform.smt.constrain_params(%param_as_param) : !transform.any_param { + transform.smt.constrain_params(%param_as_param) : (!transform.param) -> () { // CHECK: ^bb{{.*}}(%[[PARAM_AS_SMT_VAR:.*]]: !smt.bool): ^bb0(%param_as_smt_var: !smt.bool): // CHECK: %[[C0:.*]] = smt.int.constant 0 diff --git a/mlir/test/python/dialects/transform_smt_ext.py b/mlir/test/python/dialects/transform_smt_ext.py index 3692fd92344a6..e28c56f277439 100644 --- a/mlir/test/python/dialects/transform_smt_ext.py +++ b/mlir/test/python/dialects/transform_smt_ext.py @@ -25,26 +25,44 @@ def run(f): # CHECK-LABEL: TEST: testConstrainParamsOp @run def testConstrainParamsOp(target): - dummy_value = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42) + c42_attr = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 42) # CHECK: %[[PARAM_AS_PARAM:.*]] = transform.param.constant - symbolic_value = transform.ParamConstantOp( - transform.AnyParamType.get(), dummy_value + symbolic_value_as_param = transform.ParamConstantOp( + transform.AnyParamType.get(), c42_attr ) # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]]) constrain_params = transform_smt.ConstrainParamsOp( - [symbolic_value], [smt.IntType.get()] + [], [symbolic_value_as_param], [smt.IntType.get()] ) # CHECK-NEXT: ^bb{{.*}}(%[[PARAM_AS_SMT_SYMB:.*]]: !smt.int): with ir.InsertionPoint(constrain_params.body): + symbolic_value_as_smt_var = constrain_params.body.arguments[0] # CHECK: %[[C0:.*]] = smt.int.constant 0 c0 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 0)) # CHECK: %[[C43:.*]] = smt.int.constant 43 c43 = smt.IntConstantOp(ir.IntegerAttr.get(ir.IntegerType.get_signless(32), 43)) # CHECK: %[[LB:.*]] = smt.int.cmp le %[[C0]], %[[PARAM_AS_SMT_SYMB]] - lb = smt.IntCmpOp(smt.IntPredicate.le, c0, constrain_params.body.arguments[0]) + lb = smt.IntCmpOp(smt.IntPredicate.le, c0, symbolic_value_as_smt_var) # CHECK: %[[UB:.*]] = smt.int.cmp le %[[PARAM_AS_SMT_SYMB]], %[[C43]] - ub = smt.IntCmpOp(smt.IntPredicate.le, constrain_params.body.arguments[0], c43) + ub = smt.IntCmpOp(smt.IntPredicate.le, symbolic_value_as_smt_var, c43) # CHECK: %[[BOUNDED:.*]] = smt.and %[[LB]], %[[UB]] bounded = smt.AndOp([lb, ub]) # CHECK: smt.assert %[[BOUNDED:.*]] smt.AssertOp(bounded) + smt.YieldOp([]) + + # CHECK: transform.smt.constrain_params(%[[PARAM_AS_PARAM]]) + compute_with_params = transform_smt.ConstrainParamsOp( + [transform.ParamType.get(ir.IntegerType.get_signless(32))], + [symbolic_value_as_param], + [smt.IntType.get()], + ) + # CHECK-NEXT: ^bb{{.*}}(%[[SMT_SYMB:.*]]: !smt.int): + with ir.InsertionPoint(compute_with_params.body): + symbolic_value_as_smt_var = compute_with_params.body.arguments[0] + # CHECK: %[[TWICE:.*]] = smt.int.add %[[SMT_SYMB]], %[[SMT_SYMB]] + twice_symb = smt.IntAddOp( + [symbolic_value_as_smt_var, symbolic_value_as_smt_var] + ) + # CHECK: smt.yield %[[TWICE]] + smt.YieldOp([twice_symb])