-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][Transform][Python] expose transform.debug extension in Python #145550
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir Author: Rolf Morel (rolfmorel) ChangesFull diff: https://github.com/llvm/llvm-project/pull/145550.diff 4 Files Affected:
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index ee07081246fc7..b2daabb2a5957 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_pdl_extension)
+declare_mlir_dialect_extension_python_bindings(
+ADD_TO_PARENT MLIRPythonSources.Dialects
+ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+ TD_FILE dialects/TransformDebugExtensionOps.td
+ SOURCES
+ dialects/transform/debug.py
+ DIALECT_NAME transform
+ EXTENSION_NAME transform_debug_extension)
+
declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
diff --git a/mlir/python/mlir/dialects/TransformDebugExtensionOps.td b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td
new file mode 100644
index 0000000000000..22a85d2366994
--- /dev/null
+++ b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td
@@ -0,0 +1,19 @@
+//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Entry point of the generated Python bindings for the Debug extension of the
+// Transform dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
+#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
+
+include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td"
+
+#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
diff --git a/mlir/python/mlir/dialects/transform/debug.py b/mlir/python/mlir/dialects/transform/debug.py
new file mode 100644
index 0000000000000..738c556b1d362
--- /dev/null
+++ b/mlir/python/mlir/dialects/transform/debug.py
@@ -0,0 +1,86 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from typing import Optional
+
+from ...ir import Attribute, Operation, Value, StringAttr
+from .._transform_debug_extension_ops_gen import *
+from .._transform_pdl_extension_ops_gen import _Dialect
+
+try:
+ from .._ods_common import _cext as _ods_cext
+except ImportError as e:
+ raise RuntimeError("Error loading imports from extension module") from e
+
+from typing import Union
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DebugEmitParamAsRemarkOp(DebugEmitParamAsRemarkOp):
+ def __init__(
+ self,
+ param: Attribute,
+ *,
+ anchor: Optional[Operation] = None,
+ message: Optional[Union[StringAttr, str]] = None,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(message, str):
+ message = StringAttr.get(message)
+
+ super().__init__(
+ param,
+ anchor=anchor,
+ message=message,
+ loc=loc,
+ ip=ip,
+ )
+
+
+def emit_param_as_remark(
+ param: Attribute,
+ *,
+ anchor: Optional[Operation] = None,
+ message: Optional[Union[StringAttr, str]] = None,
+ loc=None,
+ ip=None,
+):
+ return DebugEmitParamAsRemarkOp(
+ param, anchor=anchor, message=message, loc=loc, ip=ip
+ )
+
+del debug_emit_param_as_remark
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DebugEmitRemarkAtOp(DebugEmitRemarkAtOp):
+ def __init__(
+ self,
+ at: Union[Operation, Value],
+ message: Optional[Union[StringAttr, str]] = None,
+ *,
+ loc=None,
+ ip=None,
+ ):
+ if isinstance(message, str):
+ message = StringAttr.get(message)
+
+ super().__init__(
+ at,
+ message,
+ loc=loc,
+ ip=ip,
+ )
+
+
+def emit_remark_at(
+ at: Union[Operation, Value],
+ message: Optional[Union[StringAttr, str]] = None,
+ *,
+ loc=None,
+ ip=None,
+):
+ return DebugEmitRemarkAtOp(at, message, loc=loc, ip=ip)
+
+del debug_emit_remark_at
diff --git a/mlir/test/python/dialects/transform_debug_ext.py b/mlir/test/python/dialects/transform_debug_ext.py
new file mode 100644
index 0000000000000..c96e7e66e03d1
--- /dev/null
+++ b/mlir/test/python/dialects/transform_debug_ext.py
@@ -0,0 +1,47 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import transform
+from mlir.dialects.transform import debug
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ with Context(), Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate,
+ [],
+ transform.AnyOpType.get(),
+ )
+ with InsertionPoint(sequence.body):
+ f(sequence.bodyTarget)
+ transform.YieldOp()
+ print(module)
+ return f
+
+
+@run
+def testDebugEmitParamAsRemark(target):
+ i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
+ i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
+ debug.emit_param_as_remark(i0_param)
+ debug.emit_param_as_remark(i0_param, anchor=target, message="some text")
+ # CHECK-LABEL: TEST: testDebugEmitParamAsRemark
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: %[[PARAM:.*]] = transform.param.constant
+ # CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
+ # CHECK: transform.debug.emit_param_as_remark %[[PARAM]]
+ # CHECK-SAME: "some text"
+ # CHECK-SAME: at %[[ARG0]]
+
+
+@run
+def testDebugEmitRemarkAtOp(target):
+ i0 = IntegerAttr.get(IntegerType.get_signless(32), 0)
+ i0_param = transform.ParamConstantOp(transform.AnyParamType.get(), i0)
+ debug.emit_remark_at(target, "some text")
+ # CHECK-LABEL: TEST: testDebugEmitRemarkAtOp
+ # CHECK: ^{{.*}}(%[[ARG0:.+]]: !transform.any_op):
+ # CHECK: transform.debug.emit_remark_at %[[ARG0]], "some text"
|
|
✅ With the latest revision this PR passed the Python code formatter. |
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
ftynse
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with comments addressed.
mlir/include/mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td
Outdated
Show resolved
Hide resolved
…5550) Removes the Debug... prefix on the ops in tablegen, in line with pretty much all other Transform-dialect extension ops. This means that the ops in Python look like `debug.EmitParamAsRemarkOp`/`debug.emit_param_as_remark` instead of `debug.DebugEmitParamAsRemarkOp`/`debug.debug_emit_param_as_remark`.
Removes the Debug... prefix on the ops in tablegen, in line with pretty much all other Transform-dialect extension ops. This means that the ops in Python look like
debug.EmitParamAsRemarkOp/debug.emit_param_as_remarkinstead ofdebug.DebugEmitParamAsRemarkOp/debug.debug_emit_param_as_remark.