|
| 1 | +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 2 | +# See https://llvm.org/LICENSE.txt for license information. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 4 | + |
| 5 | +from ...dialects import irdl as _irdl |
| 6 | +from .._ods_common import ( |
| 7 | + _cext as _ods_cext, |
| 8 | + segmented_accessor as _ods_segmented_accessor, |
| 9 | +) |
| 10 | +from . import Variadicity |
| 11 | +from typing import Dict, List, Union, Callable, Tuple |
| 12 | +from dataclasses import dataclass |
| 13 | +from inspect import Parameter as _Parameter, Signature as _Signature |
| 14 | +from types import SimpleNamespace as _SimpleNameSpace |
| 15 | + |
| 16 | +_ods_ir = _ods_cext.ir |
| 17 | + |
| 18 | + |
| 19 | +class ConstraintExpr: |
| 20 | + def _lower(self, ctx: "ConstraintLoweringContext") -> _ods_ir.Value: |
| 21 | + raise NotImplementedError() |
| 22 | + |
| 23 | + def __or__(self, other: "ConstraintExpr") -> "ConstraintExpr": |
| 24 | + return AnyOf(self, other) |
| 25 | + |
| 26 | + def __and__(self, other: "ConstraintExpr") -> "ConstraintExpr": |
| 27 | + return AllOf(self, other) |
| 28 | + |
| 29 | + |
| 30 | +class ConstraintLoweringContext: |
| 31 | + def __init__(self): |
| 32 | + # Cache so that the same ConstraintExpr instance reuses its SSA value. |
| 33 | + self._cache: Dict[int, _ods_ir.Value] = {} |
| 34 | + |
| 35 | + def lower(self, expr: ConstraintExpr) -> _ods_ir.Value: |
| 36 | + key = id(expr) |
| 37 | + if key in self._cache: |
| 38 | + return self._cache[key] |
| 39 | + v = expr._lower(self) |
| 40 | + self._cache[key] = v |
| 41 | + return v |
| 42 | + |
| 43 | + |
| 44 | +class Is(ConstraintExpr): |
| 45 | + def __init__(self, attr: _ods_ir.Attribute): |
| 46 | + self.attr = attr |
| 47 | + |
| 48 | + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: |
| 49 | + return _irdl.is_(self.attr) |
| 50 | + |
| 51 | + |
| 52 | +class IsType(Is): |
| 53 | + def __init__(self, typ: _ods_ir.Type): |
| 54 | + super().__init__(_ods_ir.TypeAttr.get(typ)) |
| 55 | + |
| 56 | + |
| 57 | +class AnyOf(ConstraintExpr): |
| 58 | + def __init__(self, *exprs: ConstraintExpr): |
| 59 | + self.exprs = exprs |
| 60 | + |
| 61 | + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: |
| 62 | + return _irdl.any_of(ctx.lower(expr) for expr in self.exprs) |
| 63 | + |
| 64 | + |
| 65 | +class AllOf(ConstraintExpr): |
| 66 | + def __init__(self, *exprs: ConstraintExpr): |
| 67 | + self.exprs = exprs |
| 68 | + |
| 69 | + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: |
| 70 | + return _irdl.all_of(ctx.lower(expr) for expr in self.exprs) |
| 71 | + |
| 72 | + |
| 73 | +class Any(ConstraintExpr): |
| 74 | + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: |
| 75 | + return _irdl.any() |
| 76 | + |
| 77 | + |
| 78 | +class BaseName(ConstraintExpr): |
| 79 | + def __init__(self, name: str): |
| 80 | + self.name = name |
| 81 | + |
| 82 | + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: |
| 83 | + return _irdl.base(base_name=self.name) |
| 84 | + |
| 85 | + |
| 86 | +class BaseRef(ConstraintExpr): |
| 87 | + def __init__(self, ref): |
| 88 | + self.ref = ref |
| 89 | + |
| 90 | + def _lower(self, ctx: ConstraintLoweringContext) -> _ods_ir.Value: |
| 91 | + return _irdl.base(base_ref=self.ref) |
| 92 | + |
| 93 | + |
| 94 | +class FieldDef: |
| 95 | + def __set_name__(self, owner, name: str): |
| 96 | + self.name = name |
| 97 | + |
| 98 | + |
| 99 | +@dataclass |
| 100 | +class Operand(FieldDef): |
| 101 | + constraint: ConstraintExpr |
| 102 | + variadicity: Variadicity = Variadicity.single |
| 103 | + |
| 104 | + |
| 105 | +@dataclass |
| 106 | +class Result(FieldDef): |
| 107 | + constraint: ConstraintExpr |
| 108 | + variadicity: Variadicity = Variadicity.single |
| 109 | + |
| 110 | + |
| 111 | +@dataclass |
| 112 | +class Attribute(FieldDef): |
| 113 | + constraint: ConstraintExpr |
| 114 | + |
| 115 | + def __post_init__(self): |
| 116 | + # just for unified processing, |
| 117 | + # currently optional attribute is not supported by IRDL |
| 118 | + self.variadicity = Variadicity.single |
| 119 | + |
| 120 | + |
| 121 | +@dataclass |
| 122 | +class Operation: |
| 123 | + dialect_name: str |
| 124 | + name: str |
| 125 | + # We store operands and attributes into one list to maintain relative orders |
| 126 | + # among them for generating OpView class. |
| 127 | + operands_and_attrs: List[Union[Operand, Attribute]] |
| 128 | + results: List[Result] |
| 129 | + |
| 130 | + def _emit(self) -> None: |
| 131 | + op = _irdl.operation_(self.name) |
| 132 | + ctx = ConstraintLoweringContext() |
| 133 | + |
| 134 | + operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)] |
| 135 | + attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)] |
| 136 | + |
| 137 | + with _ods_ir.InsertionPoint(op.body): |
| 138 | + if operands: |
| 139 | + _irdl.operands_( |
| 140 | + [ctx.lower(i.constraint) for i in operands], |
| 141 | + [i.name for i in operands], |
| 142 | + [i.variadicity for i in operands], |
| 143 | + ) |
| 144 | + if attrs: |
| 145 | + _irdl.attributes_( |
| 146 | + [ctx.lower(i.constraint) for i in attrs], |
| 147 | + [i.name for i in attrs], |
| 148 | + ) |
| 149 | + if self.results: |
| 150 | + _irdl.results_( |
| 151 | + [ctx.lower(i.constraint) for i in self.results], |
| 152 | + [i.name for i in self.results], |
| 153 | + [i.variadicity for i in self.results], |
| 154 | + ) |
| 155 | + |
| 156 | + def _make_op_view_and_builder(self) -> Tuple[type, Callable]: |
| 157 | + operands = [i for i in self.operands_and_attrs if isinstance(i, Operand)] |
| 158 | + attrs = [i for i in self.operands_and_attrs if isinstance(i, Attribute)] |
| 159 | + |
| 160 | + def variadicity_to_segment(variadicity: Variadicity) -> int: |
| 161 | + if variadicity == Variadicity.variadic: |
| 162 | + return -1 |
| 163 | + if variadicity == Variadicity.optional: |
| 164 | + return 0 |
| 165 | + return 1 |
| 166 | + |
| 167 | + operand_segments = None |
| 168 | + if any(i.variadicity != Variadicity.single for i in operands): |
| 169 | + operand_segments = [variadicity_to_segment(i.variadicity) for i in operands] |
| 170 | + |
| 171 | + result_segments = None |
| 172 | + if any(i.variadicity != Variadicity.single for i in self.results): |
| 173 | + result_segments = [ |
| 174 | + variadicity_to_segment(i.variadicity) for i in self.results |
| 175 | + ] |
| 176 | + |
| 177 | + args = self.results + self.operands_and_attrs |
| 178 | + positional_args = [ |
| 179 | + i.name for i in args if i.variadicity != Variadicity.optional |
| 180 | + ] |
| 181 | + optional_args = [i.name for i in args if i.variadicity == Variadicity.optional] |
| 182 | + |
| 183 | + params = [_Parameter("self", _Parameter.POSITIONAL_ONLY)] |
| 184 | + for i in positional_args: |
| 185 | + params.append(_Parameter(i, _Parameter.POSITIONAL_OR_KEYWORD)) |
| 186 | + for i in optional_args: |
| 187 | + params.append(_Parameter(i, _Parameter.KEYWORD_ONLY, default=None)) |
| 188 | + params.append(_Parameter("loc", _Parameter.KEYWORD_ONLY, default=None)) |
| 189 | + params.append(_Parameter("ip", _Parameter.KEYWORD_ONLY, default=None)) |
| 190 | + |
| 191 | + sig = _Signature(params) |
| 192 | + op = self |
| 193 | + |
| 194 | + class _OpView(_ods_ir.OpView): |
| 195 | + OPERATION_NAME = f"{op.dialect_name}.{op.name}" |
| 196 | + _ODS_REGIONS = (0, True) |
| 197 | + _ODS_OPERAND_SEGMENTS = operand_segments |
| 198 | + _ODS_RESULT_SEGMENTS = result_segments |
| 199 | + |
| 200 | + def __init__(*args, **kwargs): |
| 201 | + bound = sig.bind(*args, **kwargs) |
| 202 | + bound.apply_defaults() |
| 203 | + args = bound.arguments |
| 204 | + |
| 205 | + _operands = [args[operand.name] for operand in operands] |
| 206 | + _results = [args[result.name] for result in op.results] |
| 207 | + _attributes = dict( |
| 208 | + (attr.name, args[attr.name]) |
| 209 | + for attr in attrs |
| 210 | + if args[attr.name] is not None |
| 211 | + ) |
| 212 | + _regions = None |
| 213 | + _ods_successors = None |
| 214 | + self = args["self"] |
| 215 | + super(_OpView, self).__init__( |
| 216 | + self.OPERATION_NAME, |
| 217 | + self._ODS_REGIONS, |
| 218 | + self._ODS_OPERAND_SEGMENTS, |
| 219 | + self._ODS_RESULT_SEGMENTS, |
| 220 | + attributes=_attributes, |
| 221 | + results=_results, |
| 222 | + operands=_operands, |
| 223 | + successors=_ods_successors, |
| 224 | + regions=_regions, |
| 225 | + loc=args["loc"], |
| 226 | + ip=args["ip"], |
| 227 | + ) |
| 228 | + |
| 229 | + __init__.__signature__ = sig |
| 230 | + |
| 231 | + for attr in attrs: |
| 232 | + setattr( |
| 233 | + _OpView, |
| 234 | + attr.name, |
| 235 | + property(lambda self, name=attr.name: self.attributes[name]), |
| 236 | + ) |
| 237 | + |
| 238 | + def value_range_getter( |
| 239 | + value_range: Union[_ods_ir.OpOperandList, _ods_ir.OpResultList], |
| 240 | + variadicity: Variadicity, |
| 241 | + ): |
| 242 | + if variadicity == Variadicity.single: |
| 243 | + return value_range[0] |
| 244 | + if variadicity == Variadicity.optional: |
| 245 | + return value_range[0] if len(value_range) > 0 else None |
| 246 | + return value_range |
| 247 | + |
| 248 | + for i, operand in enumerate(operands): |
| 249 | + if operand_segments: |
| 250 | + |
| 251 | + def getter(self, i=i, operand=operand): |
| 252 | + operand_range = _ods_segmented_accessor( |
| 253 | + self.operation.operands, |
| 254 | + self.operation.attributes["operandSegmentSizes"], |
| 255 | + i, |
| 256 | + ) |
| 257 | + return value_range_getter(operand_range, operand.variadicity) |
| 258 | + |
| 259 | + setattr(_OpView, operand.name, property(getter)) |
| 260 | + else: |
| 261 | + setattr( |
| 262 | + _OpView, operand.name, property(lambda self, i=i: self.operands[i]) |
| 263 | + ) |
| 264 | + for i, result in enumerate(self.results): |
| 265 | + if result_segments: |
| 266 | + |
| 267 | + def getter(self, i=i, result=result): |
| 268 | + result_range = _ods_segmented_accessor( |
| 269 | + self.operation.results, |
| 270 | + self.operation.attributes["resultSegmentSizes"], |
| 271 | + i, |
| 272 | + ) |
| 273 | + return value_range_getter(result_range, result.variadicity) |
| 274 | + |
| 275 | + setattr(_OpView, result.name, property(getter)) |
| 276 | + else: |
| 277 | + setattr( |
| 278 | + _OpView, result.name, property(lambda self, i=i: self.results[i]) |
| 279 | + ) |
| 280 | + |
| 281 | + def _builder(*args, **kwargs) -> _OpView: |
| 282 | + return _OpView(*args, **kwargs) |
| 283 | + |
| 284 | + _builder.__signature__ = _Signature(params[1:]) |
| 285 | + |
| 286 | + return _OpView, _builder |
| 287 | + |
| 288 | + |
| 289 | +class Dialect: |
| 290 | + def __init__(self, name: str): |
| 291 | + self.name = name |
| 292 | + self.operations: List[Operation] = [] |
| 293 | + self.namespace = _SimpleNameSpace() |
| 294 | + |
| 295 | + def _emit(self) -> None: |
| 296 | + d = _irdl.dialect(self.name) |
| 297 | + with _ods_ir.InsertionPoint(d.body): |
| 298 | + for op in self.operations: |
| 299 | + op._emit() |
| 300 | + |
| 301 | + def _make_module(self) -> _ods_ir.Module: |
| 302 | + with _ods_ir.Location.unknown(): |
| 303 | + m = _ods_ir.Module.create() |
| 304 | + with _ods_ir.InsertionPoint(m.body): |
| 305 | + self._emit() |
| 306 | + return m |
| 307 | + |
| 308 | + def _make_dialect_class(self) -> type: |
| 309 | + class _Dialect(_ods_ir.Dialect): |
| 310 | + DIALECT_NAMESPACE = self.name |
| 311 | + |
| 312 | + return _Dialect |
| 313 | + |
| 314 | + def load(self) -> _SimpleNameSpace: |
| 315 | + _irdl.load_dialects(self._make_module()) |
| 316 | + dialect_class = self._make_dialect_class() |
| 317 | + _ods_cext.register_dialect(dialect_class) |
| 318 | + for op in self.operations: |
| 319 | + _ods_cext.register_operation(dialect_class)(op.op_view) |
| 320 | + return self.namespace |
| 321 | + |
| 322 | + def op(self, name: str) -> Callable[[type], type]: |
| 323 | + def decorator(cls: type) -> type: |
| 324 | + operands_and_attrs: List[Union[Operand, Attribute]] = [] |
| 325 | + results: List[Result] = [] |
| 326 | + |
| 327 | + for field in cls.__dict__.values(): |
| 328 | + if isinstance(field, Operand) or isinstance(field, Attribute): |
| 329 | + operands_and_attrs.append(field) |
| 330 | + elif isinstance(field, Result): |
| 331 | + results.append(field) |
| 332 | + |
| 333 | + op_def = Operation(self.name, name, operands_and_attrs, results) |
| 334 | + op_view, builder = op_def._make_op_view_and_builder() |
| 335 | + setattr(op_def, "op_view", op_view) |
| 336 | + setattr(op_def, "builder", builder) |
| 337 | + self.operations.append(op_def) |
| 338 | + self.namespace.__dict__[cls.__name__] = op_view |
| 339 | + op_view.__name__ = cls.__name__ |
| 340 | + self.namespace.__dict__[name.replace(".", "_")] = builder |
| 341 | + return cls |
| 342 | + |
| 343 | + return decorator |
0 commit comments