Skip to content

Commit 8cbbb6b

Browse files
committed
[MLIR][Python] Add python bindings for IRDL dialect
1 parent ad9d551 commit 8cbbb6b

File tree

8 files changed

+167
-6
lines changed

8 files changed

+167
-6
lines changed

mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
1414
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
1515

16-
include "IRDL.td"
16+
include "mlir/Dialect/IRDL/IR/IRDL.td"
1717
include "mlir/IR/AttrTypeBase.td"
1818
include "mlir/IR/EnumAttr.td"
1919

mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
#ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
1414
#define MLIR_DIALECT_IRDL_IR_IRDLOPS
1515

16-
include "IRDL.td"
17-
include "IRDLAttributes.td"
18-
include "IRDLTypes.td"
19-
include "IRDLInterfaces.td"
16+
include "mlir/Dialect/IRDL/IR/IRDL.td"
17+
include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
18+
include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
19+
include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
2020
include "mlir/Interfaces/SideEffectInterfaces.td"
2121
include "mlir/Interfaces/InferTypeOpInterface.td"
2222
include "mlir/IR/SymbolInterfaces.td"

mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#define MLIR_DIALECT_IRDL_IR_IRDLTYPES
1515

1616
include "mlir/IR/AttrTypeBase.td"
17-
include "IRDL.td"
17+
include "mlir/Dialect/IRDL/IR/IRDL.td"
1818

1919
class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
2020
: TypeDef<IRDL_Dialect, name, traits> {
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/IRDL.h"
10+
#include "mlir-c/IR.h"
11+
#include "mlir-c/Support.h"
12+
#include "mlir/Bindings/Python/Nanobind.h"
13+
#include "mlir/Bindings/Python/NanobindAdaptors.h"
14+
15+
namespace nb = nanobind;
16+
using namespace llvm;
17+
using namespace mlir;
18+
using namespace mlir::python;
19+
using namespace mlir::python::nanobind_adaptors;
20+
21+
static void populateDialectIRDLSubmodule(nb::module_ &m) {
22+
m.def(
23+
"load_dialects",
24+
[](MlirModule module) {
25+
if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
26+
throw std::runtime_error(
27+
"failed to load IRDL dialects from the input module");
28+
},
29+
nb::arg("module"), "Load IRDL dialects from the given module.");
30+
}
31+
32+
NB_MODULE(_mlirDialectsIRDL, m) {
33+
m.doc() = "MLIR IRDL dialect.";
34+
35+
populateDialectIRDLSubmodule(m);
36+
}

mlir/python/CMakeLists.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
470470
GEN_ENUM_BINDINGS_TD_FILE
471471
"dialects/VectorAttributes.td")
472472

473+
declare_mlir_dialect_python_bindings(
474+
ADD_TO_PARENT MLIRPythonSources.Dialects
475+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
476+
TD_FILE dialects/IRDLOps.td
477+
SOURCES dialects/irdl.py
478+
DIALECT_NAME irdl
479+
GEN_ENUM_BINDINGS
480+
)
481+
473482
################################################################################
474483
# Python extensions.
475484
# The sources for these are all in lib/Bindings/Python, but since they have to
@@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
645654
MLIRCAPITransformDialect
646655
)
647656

657+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
658+
MODULE_NAME _mlirDialectsIRDL
659+
ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
660+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
661+
PYTHON_BINDINGS_LIBRARY nanobind
662+
SOURCES
663+
DialectIRDL.cpp
664+
PRIVATE_LINK_LIBS
665+
LLVMSupport
666+
EMBED_CAPI_LINK_LIBS
667+
MLIRCAPIIR
668+
MLIRCAPIIRDL
669+
)
670+
648671
declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
649672
MODULE_NAME _mlirAsyncPasses
650673
ADD_TO_PARENT MLIRPythonSources.Dialects.async
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_IRDL_OPS
10+
#define PYTHON_BINDINGS_IRDL_OPS
11+
12+
include "mlir/Dialect/IRDL/IR/IRDLOps.td"
13+
14+
#endif

mlir/python/mlir/dialects/irdl.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ._irdl_ops_gen import *
6+
from ._irdl_ops_gen import _Dialect
7+
from ._irdl_enum_gen import *
8+
from .._mlir_libs._mlirDialectsIRDL import *
9+
from ..ir import register_attribute_builder
10+
from ._ods_common import (
11+
get_op_result_or_value as _get_value,
12+
get_op_results_or_values as _get_values,
13+
_cext as _ods_cext,
14+
)
15+
from ..extras.meta import region_op
16+
17+
@_ods_cext.register_operation(_Dialect, replace=True)
18+
class DialectOp(DialectOp):
19+
"""Specialization for the dialect op class."""
20+
21+
def __init__(self, sym_name, *, loc=None, ip=None):
22+
super().__init__(sym_name, loc=loc, ip=ip)
23+
self.regions[0].blocks.append()
24+
25+
@property
26+
def body(self):
27+
return self.regions[0].blocks[0]
28+
29+
@_ods_cext.register_operation(_Dialect, replace=True)
30+
class OperationOp(OperationOp):
31+
"""Specialization for the operation op class."""
32+
33+
def __init__(self, sym_name, *, loc=None, ip=None):
34+
super().__init__(sym_name, loc=loc, ip=ip)
35+
self.regions[0].blocks.append()
36+
37+
@property
38+
def body(self):
39+
return self.regions[0].blocks[0]
40+
41+
@register_attribute_builder("VariadicityArrayAttr")
42+
def _variadicity_array_attr(x, context):
43+
return _ods_cext.ir.Attribute.parse(f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>")

mlir/test/python/dialects/irdl.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# RUN: %PYTHON %s 2>&1 | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import irdl
5+
import sys
6+
7+
8+
def run(f):
9+
print("\nTEST:", f.__name__, file=sys.stderr)
10+
f()
11+
12+
13+
# CHECK: TEST: testIRDL
14+
@run
15+
def testIRDL():
16+
with Context() as ctx, Location.unknown():
17+
module = Module.create()
18+
with InsertionPoint(module.body):
19+
dialect = irdl.DialectOp("irdl_test")
20+
with InsertionPoint(dialect.body):
21+
op = irdl.OperationOp("test_op")
22+
with InsertionPoint(op.body):
23+
f32 = irdl.is_(TypeAttr.get(F32Type.get()))
24+
irdl.operands_([f32], ["input"], [irdl.Variadicity.single])
25+
26+
# CHECK: module {
27+
# CHECK: irdl.dialect @irdl_test {
28+
# CHECK: irdl.operation @test_op {
29+
# CHECK: %0 = irdl.is f32
30+
# CHECK: irdl.operands(input: %0)
31+
# CHECK: }
32+
# CHECK: }
33+
# CHECK: }
34+
module.dump()
35+
36+
irdl.load_dialects(module)
37+
38+
m = Module.parse("""
39+
module {
40+
%a = arith.constant 1.0 : f32
41+
"irdl_test.test_op"(%a) : (f32) -> ()
42+
}
43+
""")
44+
# CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
45+
m.dump()

0 commit comments

Comments
 (0)