Skip to content

Commit ed21c92

Browse files
committed
[mlir] Introduce Python bindings for the PDL dialect
This change adds full python bindings for PDL, including types and operations with additional mixins to make operation construction more similar to the PDL syntax. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117458
1 parent a889099 commit ed21c92

File tree

12 files changed

+975
-2
lines changed

12 files changed

+975
-2
lines changed

mlir/include/mlir-c/Dialect/PDL.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type);
4949

5050
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType);
5151

52+
MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type);
53+
5254
//===---------------------------------------------------------------------===//
5355
// TypeType
5456
//===---------------------------------------------------------------------===//
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
//===- DialectPDL.cpp - 'pdl' dialect submodule ---------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir-c/Dialect/PDL.h"
10+
#include "mlir-c/IR.h"
11+
#include "mlir/Bindings/Python/PybindAdaptors.h"
12+
13+
namespace py = pybind11;
14+
using namespace llvm;
15+
using namespace mlir;
16+
using namespace mlir::python;
17+
using namespace mlir::python::adaptors;
18+
19+
void populateDialectPDLSubmodule(const pybind11::module &m) {
20+
//===-------------------------------------------------------------------===//
21+
// PDLType
22+
//===-------------------------------------------------------------------===//
23+
24+
auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType);
25+
26+
//===-------------------------------------------------------------------===//
27+
// AttributeType
28+
//===-------------------------------------------------------------------===//
29+
30+
auto attributeType =
31+
mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType);
32+
attributeType.def_classmethod(
33+
"get",
34+
[](py::object cls, MlirContext ctx) {
35+
return cls(mlirPDLAttributeTypeGet(ctx));
36+
},
37+
"Get an instance of AttributeType in given context.", py::arg("cls"),
38+
py::arg("context") = py::none());
39+
40+
//===-------------------------------------------------------------------===//
41+
// OperationType
42+
//===-------------------------------------------------------------------===//
43+
44+
auto operationType =
45+
mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType);
46+
operationType.def_classmethod(
47+
"get",
48+
[](py::object cls, MlirContext ctx) {
49+
return cls(mlirPDLOperationTypeGet(ctx));
50+
},
51+
"Get an instance of OperationType in given context.", py::arg("cls"),
52+
py::arg("context") = py::none());
53+
54+
//===-------------------------------------------------------------------===//
55+
// RangeType
56+
//===-------------------------------------------------------------------===//
57+
58+
auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType);
59+
rangeType.def_classmethod(
60+
"get",
61+
[](py::object cls, MlirType elementType) {
62+
return cls(mlirPDLRangeTypeGet(elementType));
63+
},
64+
"Gets an instance of RangeType in the same context as the provided "
65+
"element type.",
66+
py::arg("cls"), py::arg("element_type"));
67+
rangeType.def_property_readonly(
68+
"element_type",
69+
[](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
70+
"Get the element type.");
71+
72+
//===-------------------------------------------------------------------===//
73+
// TypeType
74+
//===-------------------------------------------------------------------===//
75+
76+
auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType);
77+
typeType.def_classmethod(
78+
"get",
79+
[](py::object cls, MlirContext ctx) {
80+
return cls(mlirPDLTypeTypeGet(ctx));
81+
},
82+
"Get an instance of TypeType in given context.", py::arg("cls"),
83+
py::arg("context") = py::none());
84+
85+
//===-------------------------------------------------------------------===//
86+
// ValueType
87+
//===-------------------------------------------------------------------===//
88+
89+
auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType);
90+
valueType.def_classmethod(
91+
"get",
92+
[](py::object cls, MlirContext ctx) {
93+
return cls(mlirPDLValueTypeGet(ctx));
94+
},
95+
"Get an instance of TypeType in given context.", py::arg("cls"),
96+
py::arg("context") = py::none());
97+
}
98+
99+
PYBIND11_MODULE(_mlirDialectsPDL, m) {
100+
m.doc() = "MLIR PDL dialect.";
101+
populateDialectPDLSubmodule(m);
102+
}

mlir/lib/CAPI/Dialect/PDL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) {
6060
return wrap(pdl::RangeType::get(unwrap(elementType)));
6161
}
6262

63+
MlirType mlirPDLRangeTypeGetElementType(MlirType type) {
64+
return wrap(unwrap(type).cast<pdl::RangeType>().getElementType());
65+
}
66+
6367
//===---------------------------------------------------------------------===//
6468
// TypeType
6569
//===---------------------------------------------------------------------===//

mlir/python/CMakeLists.txt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,15 @@ declare_mlir_python_sources(
123123
dialects/quant.py
124124
_mlir_libs/_mlir/dialects/quant.pyi)
125125

126+
declare_mlir_python_sources(
127+
MLIRPythonSources.Dialects.pdl
128+
ADD_TO_PARENT MLIRPythonSources.Dialects
129+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
130+
SOURCES
131+
dialects/pdl.py
132+
dialects/_pdl_ops_ext.py
133+
_mlir_libs/_mlir/dialects/pdl.pyi)
134+
126135
declare_mlir_dialect_python_bindings(
127136
ADD_TO_PARENT MLIRPythonSources.Dialects
128137
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
@@ -243,6 +252,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind
243252
MLIRCAPIQuant
244253
)
245254

255+
declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind
256+
MODULE_NAME _mlirDialectsPDL
257+
ADD_TO_PARENT MLIRPythonSources.Dialects.pdl
258+
ROOT_DIR "${PYTHON_SOURCE_DIR}"
259+
SOURCES
260+
DialectPDL.cpp
261+
PRIVATE_LINK_LIBS
262+
LLVMSupport
263+
EMBED_CAPI_LINK_LIBS
264+
MLIRCAPIIR
265+
MLIRCAPIPDL
266+
)
267+
246268
declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind
247269
MODULE_NAME _mlirDialectsSparseTensor
248270
ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from typing import Optional
6+
7+
from mlir.ir import Type, Context
8+
9+
__all__ = [
10+
'PDLType',
11+
'AttributeType',
12+
'OperationType',
13+
'RangeType',
14+
'TypeType',
15+
'ValueType',
16+
]
17+
18+
19+
class PDLType(Type):
20+
@staticmethod
21+
def isinstance(type: Type) -> bool: ...
22+
23+
24+
class AttributeType(Type):
25+
@staticmethod
26+
def isinstance(type: Type) -> bool: ...
27+
28+
@staticmethod
29+
def get(context: Optional[Context] = None) -> AttributeType: ...
30+
31+
32+
class OperationType(Type):
33+
@staticmethod
34+
def isinstance(type: Type) -> bool: ...
35+
36+
@staticmethod
37+
def get(context: Optional[Context] = None) -> OperationType: ...
38+
39+
40+
class RangeType(Type):
41+
@staticmethod
42+
def isinstance(type: Type) -> bool: ...
43+
44+
@staticmethod
45+
def get(element_type: Type) -> RangeType: ...
46+
47+
@property
48+
def element_type(self) -> Type: ...
49+
50+
51+
class TypeType(Type):
52+
@staticmethod
53+
def isinstance(type: Type) -> bool: ...
54+
55+
@staticmethod
56+
def get(context: Optional[Context] = None) -> TypeType: ...
57+
58+
59+
class ValueType(Type):
60+
@staticmethod
61+
def isinstance(type: Type) -> bool: ...
62+
63+
@staticmethod
64+
def get(context: Optional[Context] = None) -> ValueType: ...

mlir/python/mlir/dialects/PDLOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//===-- PDLOps.td - Entry point for PDLOps bind ------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef PYTHON_BINDINGS_PDL_OPS
10+
#define PYTHON_BINDINGS_PDL_OPS
11+
12+
include "mlir/Bindings/Python/Attributes.td"
13+
include "mlir/Dialect/PDL/IR/PDLOps.td"
14+
15+
#endif

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ def get_op_result_or_value(
144144

145145

146146
def get_op_results_or_values(
147-
arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]]
147+
arg: _Union[_cext.ir.OpView, _cext.ir.Operation,
148+
_Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]]
148149
) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]:
149150
"""Returns the given sequence of values or the results of the given op.
150151
@@ -157,4 +158,4 @@ def get_op_results_or_values(
157158
elif isinstance(arg, _cext.ir.Operation):
158159
return arg.results
159160
else:
160-
return arg
161+
return [get_op_result_or_value(element) for element in arg]

0 commit comments

Comments
 (0)