Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/IRDL/IR/IRDLAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES
#define MLIR_DIALECT_IRDL_IR_IRDLATTRIBUTES

include "IRDL.td"
include "mlir/Dialect/IRDL/IR/IRDL.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/EnumAttr.td"

Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/IRDL/IR/IRDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
#ifndef MLIR_DIALECT_IRDL_IR_IRDLOPS
#define MLIR_DIALECT_IRDL_IR_IRDLOPS

include "IRDL.td"
include "IRDLAttributes.td"
include "IRDLTypes.td"
include "IRDLInterfaces.td"
include "mlir/Dialect/IRDL/IR/IRDL.td"
include "mlir/Dialect/IRDL/IR/IRDLAttributes.td"
include "mlir/Dialect/IRDL/IR/IRDLTypes.td"
include "mlir/Dialect/IRDL/IR/IRDLInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/IRDL/IR/IRDLTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#define MLIR_DIALECT_IRDL_IR_IRDLTYPES

include "mlir/IR/AttrTypeBase.td"
include "IRDL.td"
include "mlir/Dialect/IRDL/IR/IRDL.td"

class IRDL_Type<string name, string typeMnemonic, list<Trait> traits = []>
: TypeDef<IRDL_Dialect, name, traits> {
Expand Down
36 changes: 36 additions & 0 deletions mlir/lib/Bindings/Python/DialectIRDL.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/IRDL.h"
#include "mlir-c/IR.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"

namespace nb = nanobind;
using namespace llvm;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;

static void populateDialectIRDLSubmodule(nb::module_ &m) {
m.def(
"load_dialects",
[](MlirModule module) {
if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module)))
throw std::runtime_error(
"failed to load IRDL dialects from the input module");
},
nb::arg("module"), "Load IRDL dialects from the given module.");
}

NB_MODULE(_mlirDialectsIRDL, m) {
m.doc() = "MLIR IRDL dialect.";

populateDialectIRDLSubmodule(m);
}
23 changes: 23 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,15 @@ declare_mlir_dialect_python_bindings(
GEN_ENUM_BINDINGS_TD_FILE
"dialects/VectorAttributes.td")

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/IRDLOps.td
SOURCES dialects/irdl.py
DIALECT_NAME irdl
GEN_ENUM_BINDINGS
)

################################################################################
# Python extensions.
# The sources for these are all in lib/Bindings/Python, but since they have to
Expand Down Expand Up @@ -645,6 +654,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind
MLIRCAPITransformDialect
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind
MODULE_NAME _mlirDialectsIRDL
ADD_TO_PARENT MLIRPythonSources.Dialects.irdl
ROOT_DIR "${PYTHON_SOURCE_DIR}"
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectIRDL.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIIRDL
)

declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses
MODULE_NAME _mlirAsyncPasses
ADD_TO_PARENT MLIRPythonSources.Dialects.async
Expand Down
14 changes: 14 additions & 0 deletions mlir/python/mlir/dialects/IRDLOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_IRDL_OPS
#define PYTHON_BINDINGS_IRDL_OPS

include "mlir/Dialect/IRDL/IR/IRDLOps.td"

#endif
48 changes: 48 additions & 0 deletions mlir/python/mlir/dialects/irdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._irdl_ops_gen import *
from ._irdl_ops_gen import _Dialect
from ._irdl_enum_gen import *
from .._mlir_libs._mlirDialectsIRDL import *
from ..ir import register_attribute_builder
from ._ods_common import (
get_op_result_or_value as _get_value,
get_op_results_or_values as _get_values,
_cext as _ods_cext,
)
from ..extras.meta import region_op


@_ods_cext.register_operation(_Dialect, replace=True)
class DialectOp(DialectOp):
"""Specialization for the dialect op class."""

def __init__(self, sym_name, *, loc=None, ip=None):
super().__init__(sym_name, loc=loc, ip=ip)
self.regions[0].blocks.append()

@property
def body(self):
return self.regions[0].blocks[0]


@_ods_cext.register_operation(_Dialect, replace=True)
class OperationOp(OperationOp):
"""Specialization for the operation op class."""

def __init__(self, sym_name, *, loc=None, ip=None):
super().__init__(sym_name, loc=loc, ip=ip)
self.regions[0].blocks.append()

@property
def body(self):
return self.regions[0].blocks[0]


@register_attribute_builder("VariadicityArrayAttr")
def _variadicity_array_attr(x, context):
return _ods_cext.ir.Attribute.parse(
f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not an expert on MLIR Python bindings, do you mind explaining what this function is for?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The generated builders have hooks for constructing attrs from arguments to the builders themselves (__init__). Just grep for AttrBuilder in the generated code you'll see what's up.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So is using the parser to build them the intended way?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's one supported way for sure. E.g.

return Attribute.parse(f"#transform.param_operand<index={x}>", context=context)

Is it the "best" way? Probably not - probably the best way is to bind the attrs using mlir_attribute_subclass but I intend to actually deprecate those pure_subclass things soon so this is probably not bad for being "forward compatible" with that deprecation.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried the mlir_attribute_subclass way in the first and it requires more C++ code. By register_attribute_builder we can just construct the attribute in a simple python function, although a parse (and string join) is required which looks a little inelegant.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ya it's fine - mlir_attribute_subclass is useful if you want to actually query components of the attribute itself.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some projects go through string representaitons for types, but it is a hack that is frowned upon (at least by me). It is certainly not the intended way to create types, one is expected to provide bindings. It is acceptable to use as a workaround until proper bindings are available.

49 changes: 49 additions & 0 deletions mlir/test/python/dialects/irdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

from mlir.ir import *
from mlir.dialects import irdl
import sys


def run(f):
print("\nTEST:", f.__name__, file=sys.stderr)
f()


# CHECK: TEST: testIRDL
@run
def testIRDL():
with Context() as ctx, Location.unknown():
module = Module.create()
with InsertionPoint(module.body):
dialect = irdl.DialectOp("irdl_test")
with InsertionPoint(dialect.body):
op = irdl.OperationOp("test_op")
with InsertionPoint(op.body):
f32 = irdl.is_(TypeAttr.get(F32Type.get()))
irdl.operands_([f32], ["input"], [irdl.Variadicity.single])

# CHECK: module {
# CHECK: irdl.dialect @irdl_test {
# CHECK: irdl.operation @test_op {
# CHECK: %0 = irdl.is f32
# CHECK: irdl.operands(input: %0)
# CHECK: }
# CHECK: }
# CHECK: }
module.dump()

irdl.load_dialects(module)

m = Module.parse(
"""
module {
%a = arith.constant 1.0 : f32
"irdl_test.test_op"(%a) : (f32) -> ()
}
"""
)
# CHECK: module {
# CHECK: "irdl_test.test_op"(%cst) : (f32) -> ()
# CHECK: }
m.dump()
Loading