Skip to content

Commit e5114a2

Browse files
authored
[MLIR][Python] Add python bindings for IRDL dialect (#158488)
In this PR we add basic python bindings for IRDL dialect, so that python users can create and load IRDL dialects in python. This allows users, to some extent, to define dialects in Python without having to modify MLIR’s CMake/TableGen/C++ code and rebuild, making prototyping more convenient. A basic example is shown below (and also in the added test case): ```python # create a module with IRDL dialects module = Module.create() with InsertionPoint(module.body): dialect = irdl.DialectOp("irdl_test") with InsertionPoint(dialect.body): op = irdl.OperationOp("test_op") with InsertionPoint(op.body): f32 = irdl.is_(TypeAttr.get(F32Type.get())) irdl.operands_([f32], ["input"], [irdl.Variadicity.single]) # load the module irdl.load_dialects(module) # use the op defined in IRDL m = Module.parse(""" module { %a = arith.constant 1.0 : f32 "irdl_test.test_op"(%a) : (f32) -> () } """) ```
1 parent b26b40b commit e5114a2

File tree

8 files changed

+236
-6
lines changed

8 files changed

+236
-6
lines changed

mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
1414
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
1515

16-
include "IRDL.td"
16+
include "mlir/Dialect/IRDL/IR/IRDL.td"
1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/IR/EnumAttr.td"
1919

mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
#ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
1414
#define MLIR_DIALECT_IRDL_IR_IRDLOPS
1515

16-
include "IRDL.td"
17-
include "IRDLAttributes.td"
18-
include "IRDLTypes.td"
19-
include "IRDLInterfaces.td"
16+
include "mlir/Dialect/IRDL/IR/IRDL.td"
17+
include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
18+
include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
19+
include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
2020
include "mlir/Interfaces/SideEffectInterfaces.td"
2121
include "mlir/Interfaces/InferTypeOpInterface.td"
2222
include "mlir/IR/SymbolInterfaces.td"

mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#define MLIR_DIALECT_IRDL_IR_IRDLTYPES
1515

1616
include "mlir/IR/AttrTypeBase.td"
17-
include "IRDL.td"
17+
include "mlir/Dialect/IRDL/IR/IRDL.td"
1818

1919
class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
2020
: TypeDef<IRDL_Dialect, name, traits> {
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/IRDL.h"
10+
#include "mlir-c/IR.h"
11+
#include "mlir-c/Support.h"
12+
#include "mlir/Bindings/Python/Nanobind.h"
13+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
14+
15+
namespace nb = nanobind;
16+
using namespace mlir;
17+
using namespace mlir::python;
18+
using namespace mlir::python::nanobind_adaptors;
19+
20+
static void populateDialectIRDLSubmodule(nb::module_ &m) {
21+
m.def(
22+
"load_dialects",
23+
[](MlirModule module) {
24+
if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
25+
throw std::runtime_error(
26+
"failed to load IRDL dialects from the input module");
27+
},
28+
nb::arg("module"), "Load IRDL dialects from the given module.");
29+
}
30+
31+
NB_MODULE(_mlirDialectsIRDL, m) {
32+
m.doc() = "MLIR IRDL dialect.";
33+
34+
populateDialectIRDLSubmodule(m);
35+
}

mlir/python/CMakeLists.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
470470
GEN_ENUM_BINDINGS_TD_FILE
471471
"dialects/VectorAttributes.td")
472472

473+
declare_mlir_dialect_python_bindings(
474+
ADD_TO_PARENT MLIRPythonSources.Dialects
475+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
476+
TD_FILE dialects/IRDLOps.td
477+
SOURCES dialects/irdl.py
478+
DIALECT_NAME irdl
479+
GEN_ENUM_BINDINGS
480+
)
481+
473482
################################################################################
474483
# Python extensions.
475484
# The sources for these are all in lib/Bindings/Python, but since they have to
@@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
645654
MLIRCAPITransformDialect
646655
)
647656

657+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
658+
MODULE_NAME _mlirDialectsIRDL
659+
ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
660+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
661+
PYTHON_BINDINGS_LIBRARY nanobind
662+
SOURCES
663+
DialectIRDL.cpp
664+
PRIVATE_LINK_LIBS
665+
LLVMSupport
666+
EMBED_CAPI_LINK_LIBS
667+
MLIRCAPIIR
668+
MLIRCAPIIRDL
669+
)
670+
648671
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
649672
MODULE_NAME _mlirAsyncPasses
650673
ADD_TO_PARENT MLIRPythonSources.Dialects.async
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- IRDLOps.td - Entry point for IRDL binding ----------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_IRDL_OPS
10+
#define PYTHON_BINDINGS_IRDL_OPS
11+
12+
include "mlir/Dialect/IRDL/IR/IRDLOps.td"
13+
14+
#endif

mlir/python/mlir/dialects/irdl.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._irdl_ops_gen import *
6+
from ._irdl_ops_gen import _Dialect
7+
from ._irdl_enum_gen import *
8+
from .._mlir_libs._mlirDialectsIRDL import *
9+
from ..ir import register_attribute_builder
10+
from ._ods_common import _cext as _ods_cext
11+
from typing import Union, Sequence
12+
13+
_ods_ir = _ods_cext.ir
14+
15+
16+
@_ods_cext.register_operation(_Dialect, replace=True)
17+
class DialectOp(DialectOp):
18+
__doc__ = DialectOp.__doc__
19+
20+
def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
21+
super().__init__(sym_name, loc=loc, ip=ip)
22+
self.regions[0].blocks.append()
23+
24+
@property
25+
def body(self) -> _ods_ir.Block:
26+
return self.regions[0].blocks[0]
27+
28+
29+
def dialect(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> DialectOp:
30+
return DialectOp(sym_name=sym_name, loc=loc, ip=ip)
31+
32+
33+
@_ods_cext.register_operation(_Dialect, replace=True)
34+
class OperationOp(OperationOp):
35+
__doc__ = OperationOp.__doc__
36+
37+
def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
38+
super().__init__(sym_name, loc=loc, ip=ip)
39+
self.regions[0].blocks.append()
40+
41+
@property
42+
def body(self) -> _ods_ir.Block:
43+
return self.regions[0].blocks[0]
44+
45+
46+
def operation_(
47+
sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None
48+
) -> OperationOp:
49+
return OperationOp(sym_name=sym_name, loc=loc, ip=ip)
50+
51+
52+
@_ods_cext.register_operation(_Dialect, replace=True)
53+
class TypeOp(TypeOp):
54+
__doc__ = TypeOp.__doc__
55+
56+
def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
57+
super().__init__(sym_name, loc=loc, ip=ip)
58+
self.regions[0].blocks.append()
59+
60+
@property
61+
def body(self) -> _ods_ir.Block:
62+
return self.regions[0].blocks[0]
63+
64+
65+
def type_(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> TypeOp:
66+
return TypeOp(sym_name=sym_name, loc=loc, ip=ip)
67+
68+
69+
@_ods_cext.register_operation(_Dialect, replace=True)
70+
class AttributeOp(AttributeOp):
71+
__doc__ = AttributeOp.__doc__
72+
73+
def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
74+
super().__init__(sym_name, loc=loc, ip=ip)
75+
self.regions[0].blocks.append()
76+
77+
@property
78+
def body(self) -> _ods_ir.Block:
79+
return self.regions[0].blocks[0]
80+
81+
82+
def attribute(
83+
sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None
84+
) -> AttributeOp:
85+
return AttributeOp(sym_name=sym_name, loc=loc, ip=ip)
86+
87+
88+
@register_attribute_builder("VariadicityArrayAttr")
89+
def _variadicity_array_attr(x: Sequence[Variadicity], context) -> _ods_ir.Attribute:
90+
return _ods_ir.Attribute.parse(
91+
f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>", context
92+
)

mlir/test/python/dialects/irdl.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects.irdl import *
5+
import sys
6+
7+
8+
def run(f):
9+
print("\nTEST:", f.__name__, file=sys.stderr)
10+
f()
11+
12+
13+
# CHECK: TEST: testIRDL
14+
@run
15+
def testIRDL():
16+
with Context() as ctx, Location.unknown():
17+
module = Module.create()
18+
with InsertionPoint(module.body):
19+
irdl_test = dialect("irdl_test")
20+
with InsertionPoint(irdl_test.body):
21+
op = operation_("test_op")
22+
with InsertionPoint(op.body):
23+
f32 = is_(TypeAttr.get(F32Type.get()))
24+
operands_([f32], ["input"], [Variadicity.single])
25+
type1 = type_("type1")
26+
with InsertionPoint(type1.body):
27+
f32 = is_(TypeAttr.get(F32Type.get()))
28+
parameters([f32], ["val"])
29+
attr1 = attribute("attr1")
30+
with InsertionPoint(attr1.body):
31+
test = is_(StringAttr.get("test"))
32+
parameters([test], ["val"])
33+
34+
# CHECK: module {
35+
# CHECK: irdl.dialect @irdl_test {
36+
# CHECK: irdl.operation @test_op {
37+
# CHECK: %0 = irdl.is f32
38+
# CHECK: irdl.operands(input: %0)
39+
# CHECK: }
40+
# CHECK: irdl.type @type1 {
41+
# CHECK: %0 = irdl.is f32
42+
# CHECK: irdl.parameters(val: %0)
43+
# CHECK: }
44+
# CHECK: irdl.attribute @attr1 {
45+
# CHECK: %0 = irdl.is "test"
46+
# CHECK: irdl.parameters(val: %0)
47+
# CHECK: }
48+
# CHECK: }
49+
# CHECK: }
50+
module.operation.verify()
51+
module.dump()
52+
53+
load_dialects(module)
54+
55+
m = Module.parse(
56+
"""
57+
module {
58+
%a = arith.constant 1.0 : f32
59+
"irdl_test.test_op"(%a) : (f32) -> ()
60+
}
61+
"""
62+
)
63+
# CHECK: module {
64+
# CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
65+
# CHECK: }
66+
m.dump()

0 commit comments

Comments
 (0)