Skip to content

Commit 3ec4809

Browse files
committed
[MLIR][Python] Add a DSL for defining IRDL dialects in Python bindings
1 parent 202d784 commit 3ec4809

File tree

4 files changed

+661
-7
lines changed

4 files changed

+661
-7
lines changed

mlir/python/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,9 @@ declare_mlir_dialect_python_bindings(
511511
ADD_TO_PARENT MLIRPythonSources.Dialects
512512
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
513513
TD_FILE dialects/IRDLOps.td
514-
SOURCES dialects/irdl.py
514+
SOURCES
515+
dialects/irdl/__init__.py
516+
dialects/irdl/dsl.py
515517
DIALECT_NAME irdl
516518
GEN_ENUM_BINDINGS
517519
)

mlir/python/mlir/dialects/irdl.py renamed to mlir/python/mlir/dialects/irdl/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
44

5-
from ._irdl_ops_gen import *
6-
from ._irdl_ops_gen import _Dialect
7-
from ._irdl_enum_gen import *
8-
from .._mlir_libs._mlirDialectsIRDL import *
9-
from ..ir import register_attribute_builder
10-
from ._ods_common import _cext as _ods_cext
5+
from .._irdl_ops_gen import *
6+
from .._irdl_ops_gen import _Dialect
7+
from .._irdl_enum_gen import *
8+
from ..._mlir_libs._mlirDialectsIRDL import *
9+
from ...ir import register_attribute_builder
10+
from .._ods_common import _cext as _ods_cext
1111
from typing import Union, Sequence
12+
from . import dsl
1213

1314
_ods_ir = _ods_cext.ir
1415

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from ...dialects import irdl as _irdl
6+
from .._ods_common import (
7+
_cext as _ods_cext,
8+
segmented_accessor as _ods_segmented_accessor,
9+
)
10+
from . import Variadicity
11+
from typing import Dict, List, Union, Callable, Tuple
12+
from dataclasses import dataclass
13+
from inspect import Parameter as _Parameter, Signature as _Signature
14+
from types import SimpleNamespace as _SimpleNameSpace
15+
16+
_ods_ir = _ods_cext.ir
17+
18+
19+
class ConstraintExpr:
20+
def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value:
21+
raise NotImplementedError()
22+
23+
def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr":
24+
return AnyOf(self, other)
25+
26+
def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr":
27+
return AllOf(self, other)
28+
29+
30+
class ConstraintLoweringContext:
31+
def __init__(self):
32+
# Cache so that the same ConstraintExpr instance reuses its SSA value.
33+
self._cache: Dict[int, _ods_ir.Value] = {}
34+
35+
def lower(self, expr: ConstraintExpr) -> _ods_ir.Value:
36+
key = id(expr)
37+
if key in self._cache:
38+
return self._cache[key]
39+
v = expr._lower(self)
40+
self._cache[key] = v
41+
return v
42+
43+
44+
class Is(ConstraintExpr):
45+
def __init__(self, attr: _ods_ir.Attribute):
46+
self.attr = attr
47+
48+
def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
49+
return _irdl.is_(self.attr)
50+
51+
52+
class IsType(Is):
53+
def __init__(self, typ: _ods_ir.Type):
54+
super().__init__(_ods_ir.TypeAttr.get(typ))
55+
56+
57+
class AnyOf(ConstraintExpr):
58+
def __init__(self, *exprs: ConstraintExpr):
59+
self.exprs = exprs
60+
61+
def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
62+
return _irdl.any_of(ctx.lower(expr) for expr in self.exprs)
63+
64+
65+
class AllOf(ConstraintExpr):
66+
def __init__(self, *exprs: ConstraintExpr):
67+
self.exprs = exprs
68+
69+
def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
70+
return _irdl.all_of(ctx.lower(expr) for expr in self.exprs)
71+
72+
73+
class Any(ConstraintExpr):
74+
def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
75+
return _irdl.any()
76+
77+
78+
class BaseName(ConstraintExpr):
79+
def __init__(self, name: str):
80+
self.name = name
81+
82+
def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
83+
return _irdl.base(base_name=self.name)
84+
85+
86+
class BaseRef(ConstraintExpr):
87+
def __init__(self, ref):
88+
self.ref = ref
89+
90+
def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value:
91+
return _irdl.base(base_ref=self.ref)
92+
93+
94+
class FieldDef:
95+
def __set_name__(self, owner, name: str):
96+
self.name = name
97+
98+
99+
@dataclass
100+
class Operand(FieldDef):
101+
constraint: ConstraintExpr
102+
variadicity: Variadicity = Variadicity.single
103+
104+
105+
@dataclass
106+
class Result(FieldDef):
107+
constraint: ConstraintExpr
108+
variadicity: Variadicity = Variadicity.single
109+
110+
111+
@dataclass
112+
class Attribute(FieldDef):
113+
constraint: ConstraintExpr
114+
115+
def __post_init__(self):
116+
# just for unified processing,
117+
# currently optional attribute is not supported by IRDL
118+
self.variadicity = Variadicity.single
119+
120+
121+
@dataclass
122+
class Operation:
123+
dialect_name: str
124+
name: str
125+
# We store operands and attributes into one list to maintain relative orders
126+
# among them for generating OpView class.
127+
operands_and_attrs: List[Union[Operand, Attribute]]
128+
results: List[Result]
129+
130+
def _emit(self) -> None:
131+
op = _irdl.operation_(self.name)
132+
ctx = ConstraintLoweringContext()
133+
134+
operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
135+
attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
136+
137+
with _ods_ir.InsertionPoint(op.body):
138+
if operands:
139+
_irdl.operands_(
140+
[ctx.lower(i.constraint) for i in operands],
141+
[i.name for i in operands],
142+
[i.variadicity for i in operands],
143+
)
144+
if attrs:
145+
_irdl.attributes_(
146+
[ctx.lower(i.constraint) for i in attrs],
147+
[i.name for i in attrs],
148+
)
149+
if self.results:
150+
_irdl.results_(
151+
[ctx.lower(i.constraint) for i in self.results],
152+
[i.name for i in self.results],
153+
[i.variadicity for i in self.results],
154+
)
155+
156+
def _make_op_view_and_builder(self) -> Tuple[type, Callable]:
157+
operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)]
158+
attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)]
159+
160+
def variadicity_to_segment(variadicity: Variadicity) -> int:
161+
if variadicity == Variadicity.variadic:
162+
return -1
163+
if variadicity == Variadicity.optional:
164+
return 0
165+
return 1
166+
167+
operand_segments = None
168+
if any(i.variadicity != Variadicity.single for i in operands):
169+
operand_segments = [variadicity_to_segment(i.variadicity) for i in operands]
170+
171+
result_segments = None
172+
if any(i.variadicity != Variadicity.single for i in self.results):
173+
result_segments = [
174+
variadicity_to_segment(i.variadicity) for i in self.results
175+
]
176+
177+
args = self.results + self.operands_and_attrs
178+
positional_args = [
179+
i.name for i in args if i.variadicity != Variadicity.optional
180+
]
181+
optional_args = [i.name for i in args if i.variadicity == Variadicity.optional]
182+
183+
params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)]
184+
for i in positional_args:
185+
params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD))
186+
for i in optional_args:
187+
params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None))
188+
params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None))
189+
params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None))
190+
191+
sig = _Signature(params)
192+
op = self
193+
194+
class _OpView(_ods_ir.OpView):
195+
OPERATION_NAME = f"{op.dialect_name}.{op.name}"
196+
_ODS_REGIONS = (0, True)
197+
_ODS_OPERAND_SEGMENTS = operand_segments
198+
_ODS_RESULT_SEGMENTS = result_segments
199+
200+
def __init__(*args, **kwargs):
201+
bound = sig.bind(*args, **kwargs)
202+
bound.apply_defaults()
203+
args = bound.arguments
204+
205+
_operands = [args[operand.name] for operand in operands]
206+
_results = [args[result.name] for result in op.results]
207+
_attributes = dict(
208+
(attr.name, args[attr.name])
209+
for attr in attrs
210+
if args[attr.name] is not None
211+
)
212+
_regions = None
213+
_ods_successors = None
214+
self = args["self"]
215+
super(_OpView, self).__init__(
216+
self.OPERATION_NAME,
217+
self._ODS_REGIONS,
218+
self._ODS_OPERAND_SEGMENTS,
219+
self._ODS_RESULT_SEGMENTS,
220+
attributes=_attributes,
221+
results=_results,
222+
operands=_operands,
223+
successors=_ods_successors,
224+
regions=_regions,
225+
loc=args["loc"],
226+
ip=args["ip"],
227+
)
228+
229+
__init__.__signature__ = sig
230+
231+
for attr in attrs:
232+
setattr(
233+
_OpView,
234+
attr.name,
235+
property(lambda self, name=attr.name: self.attributes[name]),
236+
)
237+
238+
def value_range_getter(
239+
value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList],
240+
variadicity: Variadicity,
241+
):
242+
if variadicity == Variadicity.single:
243+
return value_range[0]
244+
if variadicity == Variadicity.optional:
245+
return value_range[0] if len(value_range) > 0 else None
246+
return value_range
247+
248+
for i, operand in enumerate(operands):
249+
if operand_segments:
250+
251+
def getter(self, i=i, operand=operand):
252+
operand_range = _ods_segmented_accessor(
253+
self.operation.operands,
254+
self.operation.attributes["operandSegmentSizes"],
255+
i,
256+
)
257+
return value_range_getter(operand_range, operand.variadicity)
258+
259+
setattr(_OpView, operand.name, property(getter))
260+
else:
261+
setattr(
262+
_OpView, operand.name, property(lambda self, i=i: self.operands[i])
263+
)
264+
for i, result in enumerate(self.results):
265+
if result_segments:
266+
267+
def getter(self, i=i, result=result):
268+
result_range = _ods_segmented_accessor(
269+
self.operation.results,
270+
self.operation.attributes["resultSegmentSizes"],
271+
i,
272+
)
273+
return value_range_getter(result_range, result.variadicity)
274+
275+
setattr(_OpView, result.name, property(getter))
276+
else:
277+
setattr(
278+
_OpView, result.name, property(lambda self, i=i: self.results[i])
279+
)
280+
281+
def _builder(*args, **kwargs) -> _OpView:
282+
return _OpView(*args, **kwargs)
283+
284+
_builder.__signature__ = _Signature(params[1:])
285+
286+
return _OpView, _builder
287+
288+
289+
class Dialect:
290+
def __init__(self, name: str):
291+
self.name = name
292+
self.operations: List[Operation] = []
293+
self.namespace = _SimpleNameSpace()
294+
295+
def _emit(self) -> None:
296+
d = _irdl.dialect(self.name)
297+
with _ods_ir.InsertionPoint(d.body):
298+
for op in self.operations:
299+
op._emit()
300+
301+
def _make_module(self) -> _ods_ir.Module:
302+
with _ods_ir.Location.unknown():
303+
m = _ods_ir.Module.create()
304+
with _ods_ir.InsertionPoint(m.body):
305+
self._emit()
306+
return m
307+
308+
def _make_dialect_class(self) -> type:
309+
class _Dialect(_ods_ir.Dialect):
310+
DIALECT_NAMESPACE = self.name
311+
312+
return _Dialect
313+
314+
def load(self) -> _SimpleNameSpace:
315+
_irdl.load_dialects(self._make_module())
316+
dialect_class = self._make_dialect_class()
317+
_ods_cext.register_dialect(dialect_class)
318+
for op in self.operations:
319+
_ods_cext.register_operation(dialect_class)(op.op_view)
320+
return self.namespace
321+
322+
def op(self, name: str) -> Callable[[type], type]:
323+
def decorator(cls: type) -> type:
324+
operands_and_attrs: List[Union[Operand, Attribute]] = []
325+
results: List[Result] = []
326+
327+
for field in cls.__dict__.values():
328+
if isinstance(field, Operand) or isinstance(field, Attribute):
329+
operands_and_attrs.append(field)
330+
elif isinstance(field, Result):
331+
results.append(field)
332+
333+
op_def = Operation(self.name, name, operands_and_attrs, results)
334+
op_view, builder = op_def._make_op_view_and_builder()
335+
setattr(op_def, "op_view", op_view)
336+
setattr(op_def, "builder", builder)
337+
self.operations.append(op_def)
338+
self.namespace.__dict__[cls.__name__] = op_view
339+
op_view.__name__ = cls.__name__
340+
self.namespace.__dict__[name.replace(".", "_")] = builder
341+
return cls
342+
343+
return decorator

0 commit comments

Comments
 (0)