-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][Python] Add python bindings for IRDL dialect #158488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
8cbbb6b
9937529
5bafd42
fc7728d
deb2db1
df14b40
4357853
0cbf520
e76fa65
0c7af71
4be1ac3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir-c/Dialect/IRDL.h" | ||
#include "mlir-c/IR.h" | ||
#include "mlir-c/Support.h" | ||
#include "mlir/Bindings/Python/Nanobind.h" | ||
#include "mlir/Bindings/Python/NanobindAdaptors.h" | ||
|
||
namespace nb = nanobind; | ||
using namespace llvm; | ||
using namespace mlir; | ||
using namespace mlir::python; | ||
using namespace mlir::python::nanobind_adaptors; | ||
|
||
static void populateDialectIRDLSubmodule(nb::module_ &m) { | ||
m.def( | ||
"load_dialects", | ||
[](MlirModule module) { | ||
if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module))) | ||
throw std::runtime_error( | ||
"failed to load IRDL dialects from the input module"); | ||
}, | ||
nb::arg("module"), "Load IRDL dialects from the given module."); | ||
} | ||
|
||
NB_MODULE(_mlirDialectsIRDL, m) { | ||
m.doc() = "MLIR IRDL dialect."; | ||
|
||
populateDialectIRDLSubmodule(m); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
//===-- IRDLOps.td - Entry point for IRDL bind ---------*- tablegen -*-===// | ||
PragmaTwice marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef PYTHON_BINDINGS_IRDL_OPS | ||
#define PYTHON_BINDINGS_IRDL_OPS | ||
|
||
include "mlir/Dialect/IRDL/IR/IRDLOps.td" | ||
|
||
#endif |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -0,0 +1,74 @@ | ||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
# See https://llvm.org/LICENSE.txt for license information. | ||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
|
||||
from ._irdl_ops_gen import * | ||||
from ._irdl_ops_gen import _Dialect | ||||
from ._irdl_enum_gen import * | ||||
from .._mlir_libs._mlirDialectsIRDL import * | ||||
from ..ir import register_attribute_builder | ||||
from ._ods_common import ( | ||||
get_op_result_or_value as _get_value, | ||||
get_op_results_or_values as _get_values, | ||||
_cext as _ods_cext, | ||||
) | ||||
from ..extras.meta import region_op | ||||
PragmaTwice marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True) | ||||
class DialectOp(DialectOp): | ||||
"""Specialization for the dialect op class.""" | ||||
|
||||
def __init__(self, sym_name, *, loc=None, ip=None): | ||||
PragmaTwice marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
super().__init__(sym_name, loc=loc, ip=ip) | ||||
self.regions[0].blocks.append() | ||||
|
||||
@property | ||||
def body(self): | ||||
PragmaTwice marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
PragmaTwice marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
return self.regions[0].blocks[0] | ||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True) | ||||
class OperationOp(OperationOp): | ||||
PragmaTwice marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
"""Specialization for the operation op class.""" | ||||
|
||||
def __init__(self, sym_name, *, loc=None, ip=None): | ||||
super().__init__(sym_name, loc=loc, ip=ip) | ||||
self.regions[0].blocks.append() | ||||
|
||||
@property | ||||
def body(self): | ||||
return self.regions[0].blocks[0] | ||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True) | ||||
class TypeOp(TypeOp): | ||||
"""Specialization for the type op class.""" | ||||
PragmaTwice marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||
|
||||
def __init__(self, sym_name, *, loc=None, ip=None): | ||||
super().__init__(sym_name, loc=loc, ip=ip) | ||||
self.regions[0].blocks.append() | ||||
|
||||
@property | ||||
def body(self): | ||||
return self.regions[0].blocks[0] | ||||
|
||||
|
||||
@_ods_cext.register_operation(_Dialect, replace=True) | ||||
class AttributeOp(AttributeOp): | ||||
"""Specialization for the attribute op class.""" | ||||
|
||||
def __init__(self, sym_name, *, loc=None, ip=None): | ||||
super().__init__(sym_name, loc=loc, ip=ip) | ||||
self.regions[0].blocks.append() | ||||
|
||||
@property | ||||
def body(self): | ||||
return self.regions[0].blocks[0] | ||||
|
||||
|
||||
@register_attribute_builder("VariadicityArrayAttr") | ||||
def _variadicity_array_attr(x, context): | ||||
return _ods_cext.ir.Attribute.parse( | ||||
f"#irdl<variadicity_array [{', '.join(str(i) for i in x)}]>" | ||||
) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not an expert on MLIR Python bindings, do you mind explaining what this function is for? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The generated builders have hooks for constructing attrs from arguments to the builders themselves ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So is using the parser to build them the intended way? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's one supported way for sure. E.g.
Is it the "best" way? Probably not - probably the best way is to bind the attrs using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tried the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ya it's fine - There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some projects go through string representaitons for types, but it is a hack that is frowned upon (at least by me). It is certainly not the intended way to create types, one is expected to provide bindings. It is acceptable to use as a workaround until proper bindings are available. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
# RUN: %PYTHON %s 2>&1 | FileCheck %s | ||
|
||
from mlir.ir import * | ||
from mlir.dialects import irdl | ||
import sys | ||
|
||
|
||
def run(f): | ||
print("\nTEST:", f.__name__, file=sys.stderr) | ||
f() | ||
|
||
|
||
# CHECK: TEST: testIRDL | ||
@run | ||
def testIRDL(): | ||
with Context() as ctx, Location.unknown(): | ||
module = Module.create() | ||
with InsertionPoint(module.body): | ||
dialect = irdl.DialectOp("irdl_test") | ||
with InsertionPoint(dialect.body): | ||
op = irdl.OperationOp("test_op") | ||
with InsertionPoint(op.body): | ||
f32 = irdl.is_(TypeAttr.get(F32Type.get())) | ||
irdl.operands_([f32], ["input"], [irdl.Variadicity.single]) | ||
type1 = irdl.TypeOp("type1") | ||
PragmaTwice marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
with InsertionPoint(type1.body): | ||
f32 = irdl.is_(TypeAttr.get(F32Type.get())) | ||
irdl.parameters([f32], ["val"]) | ||
attr1 = irdl.AttributeOp("attr1") | ||
with InsertionPoint(attr1.body): | ||
test = irdl.is_(StringAttr.get("test")) | ||
irdl.parameters([test], ["val"]) | ||
|
||
# CHECK: module { | ||
# CHECK: irdl.dialect @irdl_test { | ||
# CHECK: irdl.operation @test_op { | ||
# CHECK: %0 = irdl.is f32 | ||
# CHECK: irdl.operands(input: %0) | ||
# CHECK: } | ||
# CHECK: irdl.type @type1 { | ||
# CHECK: %0 = irdl.is f32 | ||
# CHECK: irdl.parameters(val: %0) | ||
# CHECK: } | ||
# CHECK: irdl.attribute @attr1 { | ||
# CHECK: %0 = irdl.is "test" | ||
# CHECK: irdl.parameters(val: %0) | ||
# CHECK: } | ||
# CHECK: } | ||
# CHECK: } | ||
module.operation.verify() | ||
module.dump() | ||
|
||
irdl.load_dialects(module) | ||
|
||
m = Module.parse( | ||
""" | ||
module { | ||
%a = arith.constant 1.0 : f32 | ||
"irdl_test.test_op"(%a) : (f32) -> () | ||
} | ||
""" | ||
) | ||
# CHECK: module { | ||
# CHECK: "irdl_test.test_op"(%cst) : (f32) -> () | ||
# CHECK: } | ||
m.dump() |
Uh oh!
There was an error while loading. Please reload this page.