[MLIR][Python] Add python bindings for IRDL dialect #158488
@llvm/pr-subscribers-mlir-irdl @llvm/pr-subscribers-mlir

Author: Twice (PragmaTwice)

Changes

In this PR we add basic python bindings for IRDL dialect, so that python users can create and load IRDL dialects in python. This allows users, to some extent, to define dialects in Python without having to modify MLIR’s CMake/TableGen/C++ code and rebuild, making prototyping more convenient.

A basic example is shown below (and also in the added test case):

    # create a module with IRDL dialects
    module = Module.create()
    with InsertionPoint(module.body):
        dialect = irdl.DialectOp("irdl_test")
        with InsertionPoint(dialect.body):
            op = irdl.OperationOp("test_op")
            with InsertionPoint(op.body):
                f32 = irdl.is_(TypeAttr.get(F32Type.get()))
                irdl.operands_([f32], ["input"], [irdl.Variadicity.single])

    # load the module
    irdl.load_dialects(module)

    # use the op defined in IRDL
    m = Module.parse("""
    module {
      %a = arith.constant 1.0 : f32
      "irdl_test.test_op"(%a) : (f32) -> ()
    }
    """)

Full diff: https://github.com/llvm/llvm-project/pull/158488.diff

8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
index 1bfa7f5cb894b..2f568e8b6c42a 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
@@ -13,7 +13,7 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
index 4a83eb62fba32..3b6b09973645c 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
@@ -13,10 +13,10 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
#define MLIR_DIALECT_IRDL_IR_IRDLOPS
-include "IRDL.td"
-include "IRDLAttributes.td"
-include "IRDLTypes.td"
-include "IRDLInterfaces.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
+include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
+include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
diff --git a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
index 9b17bf23df182..9cde433cf33a6 100644
--- a/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
+++ b/mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
@@ -14,7 +14,7 @@
#define MLIR_DIALECT_IRDL_IR_IRDLTYPES
include "mlir/IR/AttrTypeBase.td"
-include "IRDL.td"
+include "mlir/Dialect/IRDL/IR/IRDL.td"
class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<IRDL_Dialect, name, traits> {
diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp
new file mode 100644
index 0000000000000..8264d21d4fa03
--- /dev/null
+++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp
@@ -0,0 +1,36 @@
+//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/IRDL.h"
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
+
+namespace nb = nanobind;
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::python;
+using namespace mlir::python::nanobind_adaptors;
+
+static void populateDialectIRDLSubmodule(nb::module_ &m) {
+  m.def(
+      "load_dialects",
+      [](MlirModule module) {
+        if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
+          throw std::runtime_error(
+              "failed to load IRDL dialects from the input module");
+      },
+      nb::arg("module"), "Load IRDL dialects from the given module.");
+}
+
+NB_MODULE(_mlirDialectsIRDL, m) {
+  m.doc() = "MLIR IRDL dialect.";
+
+  populateDialectIRDLSubmodule(m);
+}
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index c983914722ce1..7b2e1b8c36f25 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")
+declare_mlir_dialect_python_bindings(
+  ADD_TO_PARENT MLIRPythonSources.Dialects
+  ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
+  TD_FILE dialects/IRDLOps.td
+  SOURCES dialects/irdl.py
+  DIALECT_NAME irdl
+  GEN_ENUM_BINDINGS
+)
+
################################################################################
# Python extensions.
# The sources for these are all in lib/Bindings/Python, but since they have to
@@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MLIRCAPITransformDialect
)
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
+  MODULE_NAME _mlirDialectsIRDL
+  ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
+  ROOT_DIR "${PYTHON_SOURCE_DIR}"
+  PYTHON_BINDINGS_LIBRARY nanobind
+  SOURCES
+    DialectIRDL.cpp
+  PRIVATE_LINK_LIBS
+    LLVMSupport
+  EMBED_CAPI_LINK_LIBS
+    MLIRCAPIIR
+    MLIRCAPIIRDL
+)
+
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async
diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td
new file mode 100644
index 0000000000000..695839f1aa08b
--- /dev/null
+++ b/mlir/python/mlir/dialects/IRDLOps.td
@@ -0,0 +1,14 @@
+//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef PYTHON_BINDINGS_IRDL_OPS
+#define PYTHON_BINDINGS_IRDL_OPS
+
+include "mlir/Dialect/IRDL/IR/IRDLOps.td"
+
+#endif
diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py
new file mode 100644
index 0000000000000..2314ee99950e0
--- /dev/null
+++ b/mlir/python/mlir/dialects/irdl.py
@@ -0,0 +1,43 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._irdl_ops_gen import *
+from ._irdl_ops_gen import _Dialect
+from ._irdl_enum_gen import *
+from .._mlir_libs._mlirDialectsIRDL import *
+from ..ir import register_attribute_builder
+from ._ods_common import (
+    get_op_result_or_value as _get_value,
+    get_op_results_or_values as _get_values,
+    _cext as _ods_cext,
+)
+from ..extras.meta import region_op
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class DialectOp(DialectOp):
+    """Specialization for the dialect op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class OperationOp(OperationOp):
+    """Specialization for the operation op class."""
+
+    def __init__(self, sym_name, *, loc=None, ip=None):
+        super().__init__(sym_name, loc=loc, ip=ip)
+        self.regions[0].blocks.append()
+
+    @property
+    def body(self):
+        return self.regions[0].blocks[0]
+
+@register_attribute_builder("VariadicityArrayAttr")
+def _variadicity_array_attr(x, context):
+    return _ods_cext.ir.Attribute.parse(f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>")
diff --git a/mlir/test/python/dialects/irdl.py b/mlir/test/python/dialects/irdl.py
new file mode 100644
index 0000000000000..30983af302a52
--- /dev/null
+++ b/mlir/test/python/dialects/irdl.py
@@ -0,0 +1,45 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.dialects import irdl
+import sys
+
+
+def run(f):
+    print("\nTEST:", f.__name__, file=sys.stderr)
+    f()
+
+
+# CHECK: TEST: testIRDL
+@run
+def testIRDL():
+    with Context() as ctx, Location.unknown():
+        module = Module.create()
+        with InsertionPoint(module.body):
+            dialect = irdl.DialectOp("irdl_test")
+            with InsertionPoint(dialect.body):
+                op = irdl.OperationOp("test_op")
+                with InsertionPoint(op.body):
+                    f32 = irdl.is_(TypeAttr.get(F32Type.get()))
+                    irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
+
+        # CHECK: module {
+        # CHECK: irdl.dialect @irdl_test {
+        # CHECK: irdl.operation @test_op {
+        # CHECK: %0 = irdl.is f32
+        # CHECK: irdl.operands(input: %0)
+        # CHECK: }
+        # CHECK: }
+        # CHECK: }
+        module.dump()
+
+        irdl.load_dialects(module)
+
+        m = Module.parse("""
+        module {
+          %a = arith.constant 1.0 : f32
+          "irdl_test.test_op"(%a) : (f32) -> ()
+        }
+        """)
+        # CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
+        m.dump()
✅ With the latest revision this PR passed the Python code formatter.
ping @Moxinilian
Sorry forgot about this. Will look at it shortly.
Thanks! Other than the comments below this looks good to me, with the caveat that I have never touched the Python bindings so I don't know if this is the right way to do it.
    @register_attribute_builder("VariadicityArrayAttr")
    def _variadicity_array_attr(x, context):
        return _ods_cext.ir.Attribute.parse(
            f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>"
        )
I'm not an expert on MLIR Python bindings, do you mind explaining what this function is for?
The generated builders have hooks for constructing attrs from arguments to the builders themselves (__init__). Just grep for AttrBuilder in the generated code and you'll see what's up.
So is using the parser to build them the intended way?
It's one supported way for sure. E.g.
    return Attribute.parse(f"#transform.param_operand<index={x}>", context=context)
Is it the "best" way? Probably not - probably the best way is to bind the attrs using mlir_attribute_subclass, but I intend to actually deprecate those pure_subclass things soon, so this is probably not bad for being "forward compatible" with that deprecation.
I tried the mlir_attribute_subclass way at first, and it requires more C++ code. With register_attribute_builder we can just construct the attribute in a simple Python function, although a parse (and a string join) is required, which looks a little inelegant.
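To make the "parse and string join" step concrete, here is a small sketch of just the text that the builder in this PR hands to Attribute.parse. The `Variadicity` enum below is a hypothetical stand-in for the generated `irdl.Variadicity` (its numeric values are assumptions), and `variadicity_array_text` is an illustrative helper name, so the snippet runs without MLIR installed.

```python
from enum import Enum


class Variadicity(Enum):
    """Hypothetical stand-in for the generated irdl.Variadicity enum."""

    single = 0
    optional = 1
    variadic = 2

    def __str__(self):
        # Assumed here: str() yields the keyword used in the textual IR.
        return self.name


def variadicity_array_text(items):
    """Build the attribute text using the same f-string as the PR's builder."""
    return f"#irdl<variadicity_array [{', '.join(str(i) for i in items)}]>"


text = variadicity_array_text([Variadicity.single, Variadicity.variadic])
print(text)  # prints: #irdl<variadicity_array [single, variadic]>
```

In the PR itself this string is then passed to the attribute parser, which is the round trip through text that the reviewers are discussing.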
Ya it's fine - mlir_attribute_subclass is useful if you want to actually query components of the attribute itself.
Some projects go through string representations for types, but it is a hack that is frowned upon (at least by me). It is certainly not the intended way to create types; one is expected to provide bindings. It is acceptable to use as a workaround until proper bindings are available.
Modulo minor things in comments, LGTM. Thanks, @PragmaTwice !
Review suggestions addressed : ) Please check when convenient.
LGTM on the IRDL front. I would have appreciated the more-C++ option instead of ad hoc parser calls, but I don't want to block the PR over this.
Thank you all! I'll merge it soon.
This PR sets up build rules for the Python bindings of the IRDL dialect introduced by llvm#158488. The absence of them does not break the bazel build but some downstream users rely on them. Signed-off-by: Ingo Müller <[email protected]>
This is a nice feature to have! As I was thinking of using it, it got me to the next natural question: how do I build these IRDL-defined ops from Python? In particular, how do I create new op instances of these ops among IR that's being built using the Python bindings for non-IRDL-defined ops? Do I have to use the generic op interface ( And taking that as a starting point: Is there a way we could use some Python magic to generate a Builder class for each IRDL-in-Python-defined op and have an API for these ops that naturally integrates with the Python API for statically defined ops?
Currently that's right. We need to use Operation.create (or Operation.parse) to construct these IRDL-defined operations. And yes, we can generate constructors for these IRDL-defined operations in a pythonic way. (This can be a TODO of IRDL python bindings/follow-up of this PR.)
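A minimal sketch of the Operation.create route mentioned in this reply, mixing an IRDL-defined op with a statically defined arith op. It assumes an MLIR Python build that includes the bindings from this PR; the import is guarded so the snippet degrades to a no-op when the `mlir` package is unavailable. The helper name `build_with_generic_api` is illustrative and not part of the bindings.

```python
try:
    from mlir.ir import (
        Context,
        F32Type,
        InsertionPoint,
        Location,
        Module,
        Operation,
        TypeAttr,
    )
    from mlir.dialects import arith, irdl

    HAVE_MLIR = True
except ImportError:
    # Assumption: no MLIR Python package with IRDL bindings is installed.
    HAVE_MLIR = False


def build_with_generic_api():
    """Define irdl_test.test_op via IRDL, then create it with Operation.create."""
    with Context(), Location.unknown():
        # Same dialect definition as in the PR description.
        irdl_module = Module.create()
        with InsertionPoint(irdl_module.body):
            dialect = irdl.DialectOp("irdl_test")
            with InsertionPoint(dialect.body):
                op = irdl.OperationOp("test_op")
                with InsertionPoint(op.body):
                    f32 = irdl.is_(TypeAttr.get(F32Type.get()))
                    irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
        irdl.load_dialects(irdl_module)

        # Build IR with the regular bindings and the generic op interface.
        module = Module.create()
        with InsertionPoint(module.body):
            a = arith.ConstantOp(F32Type.get(), 1.0)
            Operation.create("irdl_test.test_op", operands=[a.result])
        return str(module)


printed = build_with_generic_api() if HAVE_MLIR else None
if printed is not None:
    print(printed)
```

The generic interface takes the op name as a string and the operands as a list, so it has no per-op keyword arguments or type hints; that gap is what the generated-constructor follow-up would address.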