Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@ add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(SMTExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS SMTExtensionOps.td)
mlir_tablegen(SMTExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(SMTExtensionOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTransformDialectSMTExtensionOpsIncGen)

add_mlir_doc(SMTExtensionOps SMTExtensionOps Dialects/ -gen-op-doc)
27 changes: 27 additions & 0 deletions mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- SMTExtension.h - SMT extension for Transform dialect -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"

namespace mlir {
class DialectRegistry;

namespace transform {
/// Registers the SMT extension of the Transform dialect in the given registry.
void registerSMTExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSION_H
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- SMTExtensionOps.h - SMT extension for Transform dialect --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h.inc"

#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//===- SMTExtensionOps.td - Transform dialect operations ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
#define MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

def ConstrainParamsOp : Op<Transform_Dialect, "smt.constrain_params", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
NoTerminator
]> {
let cppNamespace = [{ mlir::transform::smt }];

let summary = "Express contraints on params interpreted as symbolic values";
let description = [{
Allows expressing constraints on params using the SMT dialect.

Each Transform dialect param provided as an operand has a corresponding
argument of SMT-type in the region. The SMT-Dialect ops in the region use
these arguments as operands.

The semantics of this op is that all the ops in the region together express
a constraint on the params-interpreted-as-smt-vars. The op fails in case the
expressed constraint is not satisfiable per SMTLIB semantics. Otherwise the
op succeeds.

---

TODO: currently the operational semantics per the Transform interpreter is
to always fail. The intention is build out support for hooking in your own
operational semantics so you can invoke your favourite solver to determine
satisfiability of the corresponding constraint problem.
}];

let arguments = (ins Variadic<TransformParamTypeInterface>:$params);
let regions = (region SizedRegion<1>:$body);
let assemblyFormat =
"`(` $params `)` attr-dict `:` type(operands) $body";

let hasVerifier = 1;
}

#endif // MLIR_DIALECT_TRANSFORM_SMTEXTENSION_SMTEXTENSIONOPS
25 changes: 15 additions & 10 deletions mlir/lib/Bindings/Python/DialectSMT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,26 @@ using namespace mlir::python::nanobind_adaptors;

static void populateDialectSMTSubmodule(nanobind::module_ &m) {

auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
.def_classmethod(
"get",
[](const nb::object &, MlirContext context) {
return mlirSMTTypeGetBool(context);
},
"cls"_a, "context"_a = nb::none());
auto smtBoolType =
mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool)
.def_staticmethod(
"get",
[](MlirContext context) { return mlirSMTTypeGetBool(context); },
"context"_a = nb::none());
auto smtBitVectorType =
mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector)
.def_classmethod(
.def_staticmethod(
"get",
[](const nb::object &, int32_t width, MlirContext context) {
[](int32_t width, MlirContext context) {
return mlirSMTTypeGetBitVector(context, width);
},
"cls"_a, "width"_a, "context"_a = nb::none());
"width"_a, "context"_a = nb::none());
auto smtIntType =
mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt)
.def_staticmethod(
"get",
[](MlirContext context) { return mlirSMTTypeGetInt(context); },
"context"_a = nb::none());

auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues,
bool indentLetBody) {
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_subdirectory(IR)
add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(SMTExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
add_subdirectory(Utils)
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Transform/SMTExtension/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_dialect_library(MLIRTransformSMTExtension
SMTExtension.cpp
SMTExtensionOps.cpp

DEPENDS
MLIRTransformDialectSMTExtensionOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
MLIRSMT
)
35 changes: 35 additions & 0 deletions mlir/lib/Dialect/Transform/SMTExtension/SMTExtension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===- SMTExtension.cpp - SMT extension for the Transform dialect ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class SMTExtension : public transform::TransformDialectExtension<SMTExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SMTExtension)

SMTExtension() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"
>();
}
};
} // namespace

void mlir::transform::registerSMTExtension(DialectRegistry &dialectRegistry) {
dialectRegistry.addExtensions<SMTExtension>();
}
55 changes: 55 additions & 0 deletions mlir/lib/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//===- SMTExtensionOps.cpp - SMT extension for the Transform dialect ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.h"
#include "mlir/Dialect/SMT/IR/SMTDialect.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"

using namespace mlir;

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.cpp.inc"

//===----------------------------------------------------------------------===//
// ConstrainParamsOp
//===----------------------------------------------------------------------===//

void transform::smt::ConstrainParamsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
onlyReadsHandle(getParamsMutable(), effects);
}

DiagnosedSilenceableFailure
transform::smt::ConstrainParamsOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
// TODO: Proper operational semantics are to check the SMT problem in the body
// with a SMT solver with the arguments of the body constrained to the
// values passed into the op. Success or failure is then determined by
// the solver's result.
// One way to support this is to just promise the TransformOpInterface
// and allow for users to attach their own implementation, which would,
// e.g., translate the ops to SMTLIB and hand that over to the user's
// favourite solver. This requires changes to the dialect's verifier.
return emitDefiniteFailure() << "op does not have interpreted semantics yet";
}

LogicalResult transform::smt::ConstrainParamsOp::verify() {
if (getOperands().size() != getBody().getNumArguments())
return emitOpError(
"must have the same number of block arguments as operands");

for (auto &op : getBody().getOps()) {
if (!isa<mlir::smt::SMTDialect>(op.getDialect()))
return emitOpError(
"ops contained in region should belong to SMT-dialect");
}

return success();
}
2 changes: 2 additions & 0 deletions mlir/lib/RegisterAllExtensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/Transform/SMTExtension/SMTExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
Expand Down Expand Up @@ -108,6 +109,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) {
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
transform::registerSMTExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
Expand Down
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformSMTExtensionOps.td
SOURCES
dialects/transform/smt.py
DIALECT_NAME transform
EXTENSION_NAME transform_smt_extension)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
19 changes: 19 additions & 0 deletions mlir/python/mlir/dialects/TransformSMTExtensionOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the generated Python bindings for the SMT extension of the
// Transform dialect.
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS

include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td"

#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/smt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._smt_ops_gen import *
from ._smt_enum_gen import *

from .._mlir_libs._mlirDialectsSMT import *
from ..extras.meta import region_op
Expand Down
38 changes: 38 additions & 0 deletions mlir/python/mlir/dialects/transform/smt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Sequence

from ...ir import Type, Block
from .._transform_smt_extension_ops_gen import *
from .._transform_smt_extension_ops_gen import _Dialect
from ...dialects import transform

try:
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e


@_ods_cext.register_operation(_Dialect, replace=True)
class ConstrainParamsOp(ConstrainParamsOp):
def __init__(
self,
params: Sequence[transform.AnyParamType],
arg_types: Sequence[Type],
loc=None,
ip=None,
):
if len(params) != len(arg_types):
raise ValueError(f"{params=} not same length as {arg_types=}")
super().__init__(
params,
loc=loc,
ip=ip,
)
self.regions[0].blocks.append(*arg_types)

@property
def body(self) -> Block:
return self.regions[0].blocks[0]
30 changes: 30 additions & 0 deletions mlir/test/Dialect/Transform/test-smt-extension-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics

// CHECK-LABEL: @constraint_not_using_smt_ops
module attributes {transform.with_named_sequence} {
transform.named_sequence @constraint_not_using_smt_ops(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{ops contained in region should belong to SMT-dialect}}
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
^bb0(%param_as_smt_var: !smt.int):
%c4 = arith.constant 4 : i32
// This is the kind of thing one might think works:
//arith.remsi %param_as_smt_var, %c4 : i32
}
transform.yield
}
}

// -----

// CHECK-LABEL: @operands_not_one_to_one_with_vars
module attributes {transform.with_named_sequence} {
transform.named_sequence @operands_not_one_to_one_with_vars(%arg0: !transform.any_op {transform.readonly}) {
%param_as_param = transform.param.constant 42 -> !transform.param<i64>
// expected-error@below {{must have the same number of block arguments as operands}}
transform.smt.constrain_params(%param_as_param) : !transform.param<i64> {
^bb0(%param_as_smt_var: !smt.int, %param_as_another_smt_var: !smt.int):
}
transform.yield
}
}
Loading
Loading