Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES

include "IRDL.td"
include "mlir/Dialect/IRDL/IR/IRDL.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
#define MLIR_DIALECT_IRDL_IR_IRDLOPS

include "IRDL.td"
include "IRDLAttributes.td"
include "IRDLTypes.td"
include "IRDLInterfaces.td"
include "mlir/Dialect/IRDL/IR/IRDL.td"
include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#define MLIR_DIALECT_IRDL_IR_IRDLTYPES

include "mlir/IR/AttrTypeBase.td"
include "IRDL.td"
include "mlir/Dialect/IRDL/IR/IRDL.td"

class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<IRDL_Dialect, name, traits> {
Expand Down
35 changes: 35 additions & 0 deletions mlir/lib/Bindings/Python/DialectIRDL.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/IRDL.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;

static void populateDialectIRDLSubmodule(nb::module_ &m) {
m.def(
"load_dialects",
[](MlirModule module) {
if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
throw std::runtime_error(
"failed to load IRDL dialects from the input module");
},
nb::arg("module"), "Load IRDL dialects from the given module.");
}

NB_MODULE(_mlirDialectsIRDL, m) {
m.doc() = "MLIR IRDL dialect.";

populateDialectIRDLSubmodule(m);
}
23 changes: 23 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IRDLOps.td
SOURCES dialects/irdl.py
DIALECT_NAME irdl
GEN_ENUM_BINDINGS
)

################################################################################
# Python extensions.
# The sources for these are all in lib/Bindings/Python, but since they have to
Expand Down Expand Up @@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MLIRCAPITransformDialect
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
MODULE_NAME _mlirDialectsIRDL
ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
ROOT_DIR "${PYTHON_SOURCE_DIR}"
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectIRDL.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIIRDL
)

declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/dialects/IRDLOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===-- IRDLOps.td - Entry point for IRDL binding ----------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_IRDL_OPS
#define PYTHON_BINDINGS_IRDL_OPS

include "mlir/Dialect/IRDL/IR/IRDLOps.td"

#endif
92 changes: 92 additions & 0 deletions mlir/python/mlir/dialects/irdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._irdl_ops_gen import *
from ._irdl_ops_gen import _Dialect
from ._irdl_enum_gen import *
from .._mlir_libs._mlirDialectsIRDL import *
from ..ir import register_attribute_builder
from ._ods_common import _cext as _ods_cext
from typing import Union, Sequence

_ods_ir = _ods_cext.ir


@_ods_cext.register_operation(_Dialect, replace=True)
class DialectOp(DialectOp):
__doc__ = DialectOp.__doc__

def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
super().__init__(sym_name, loc=loc, ip=ip)
self.regions[0].blocks.append()

@property
def body(self) -> _ods_ir.Block:
return self.regions[0].blocks[0]


def dialect(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> DialectOp:
return DialectOp(sym_name=sym_name, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class OperationOp(OperationOp):
__doc__ = OperationOp.__doc__

def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
super().__init__(sym_name, loc=loc, ip=ip)
self.regions[0].blocks.append()

@property
def body(self) -> _ods_ir.Block:
return self.regions[0].blocks[0]


def operation_(
sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None
) -> OperationOp:
return OperationOp(sym_name=sym_name, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class TypeOp(TypeOp):
__doc__ = TypeOp.__doc__

def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
super().__init__(sym_name, loc=loc, ip=ip)
self.regions[0].blocks.append()

@property
def body(self) -> _ods_ir.Block:
return self.regions[0].blocks[0]


def type_(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> TypeOp:
return TypeOp(sym_name=sym_name, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class AttributeOp(AttributeOp):
__doc__ = AttributeOp.__doc__

def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None):
super().__init__(sym_name, loc=loc, ip=ip)
self.regions[0].blocks.append()

@property
def body(self) -> _ods_ir.Block:
return self.regions[0].blocks[0]


def attribute(
sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None
) -> AttributeOp:
return AttributeOp(sym_name=sym_name, loc=loc, ip=ip)


@register_attribute_builder("VariadicityArrayAttr")
def _variadicity_array_attr(x: Sequence[Variadicity], context) -> _ods_ir.Attribute:
return _ods_ir.Attribute.parse(
f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>", context
)
66 changes: 66 additions & 0 deletions mlir/test/python/dialects/irdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.ir import *
from mlir.dialects.irdl import *
import sys


def run(f):
print("\nTEST:", f.__name__, file=sys.stderr)
f()


# CHECK: TEST: testIRDL
@run
def testIRDL():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
irdl_test = dialect("irdl_test")
with InsertionPoint(irdl_test.body):
op = operation_("test_op")
with InsertionPoint(op.body):
f32 = is_(TypeAttr.get(F32Type.get()))
operands_([f32], ["input"], [Variadicity.single])
type1 = type_("type1")
with InsertionPoint(type1.body):
f32 = is_(TypeAttr.get(F32Type.get()))
parameters([f32], ["val"])
attr1 = attribute("attr1")
with InsertionPoint(attr1.body):
test = is_(StringAttr.get("test"))
parameters([test], ["val"])

# CHECK: module {
# CHECK: irdl.dialect @irdl_test {
# CHECK: irdl.operation @test_op {
# CHECK: %0 = irdl.is f32
# CHECK: irdl.operands(input: %0)
# CHECK: }
# CHECK: irdl.type @type1 {
# CHECK: %0 = irdl.is f32
# CHECK: irdl.parameters(val: %0)
# CHECK: }
# CHECK: irdl.attribute @attr1 {
# CHECK: %0 = irdl.is "test"
# CHECK: irdl.parameters(val: %0)
# CHECK: }
# CHECK: }
# CHECK: }
module.operation.verify()
module.dump()

load_dialects(module)

m = Module.parse(
"""
module {
%a = arith.constant 1.0 : f32
"irdl_test.test_op"(%a) : (f32) -> ()
}
"""
)
# CHECK: module {
# CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
# CHECK: }
m.dump()