|
| 1 | +from types import MethodType |
| 2 | +import pydantic |
| 3 | +from amaranth.lib import meta |
| 4 | +from typing import Any, Annotated, NamedTuple, Self, TypeVar |
| 5 | +from typing_extensions import TypedDict, is_typeddict |
| 6 | +_T_TypedDict = TypeVar('_T_TypedDict') |
| 7 | + |
| 8 | +def amaranth_annotate(modeltype: type['_T_TypedDict'], schema_id: str, member='__chipflow_annotation__', decorate_object = False): |
| 9 | + if not is_typeddict(modeltype): |
| 10 | + raise TypeError(f'''amaranth_annotate must be passed a TypedDict, not {modeltype}''') |
| 11 | + |
| 12 | + # interesting pydantic issue gets hit if arbitrary_types_allowed is False |
| 13 | + if hasattr(modeltype, '__pydantic_config__'): |
| 14 | + config = getattr(modeltype, '__pydantic_config__') |
| 15 | + config['arbitrary_types_allowed'] = True |
| 16 | + else: |
| 17 | + config = pydantic.ConfigDict() |
| 18 | + config['arbitrary_types_allowed'] = True |
| 19 | + setattr(modeltype, '__pydantic_config__', config) |
| 20 | + PydanticModel = pydantic.TypeAdapter(modeltype) |
| 21 | + |
| 22 | + def annotation_schema(): |
| 23 | + schema = PydanticModel.json_schema() |
| 24 | + schema['$schema'] = 'https://json-schema.org/draft/2020-12/schema' |
| 25 | + schema['$id'] = schema_id |
| 26 | + return schema |
| 27 | + |
| 28 | + class Annotation: |
| 29 | + 'Generated annotation class' |
| 30 | + schema = annotation_schema() |
| 31 | + |
| 32 | + def __init__(self, parent): |
| 33 | + self.parent = parent |
| 34 | + |
| 35 | + def origin(self): |
| 36 | + return self.parent |
| 37 | + |
| 38 | + def as_json(self): |
| 39 | + return PydanticModel.dump_python(getattr(self.parent, member)) |
| 40 | + |
| 41 | + def decorate_class(klass): |
| 42 | + if hasattr(klass, 'annotations'): |
| 43 | + old_annotations = klass.annotations |
| 44 | + else: |
| 45 | + old_annotations = None |
| 46 | + |
| 47 | + def annotations(self, obj): |
| 48 | + if old_annotations: |
| 49 | + annotations = old_annotations(self, obj) |
| 50 | + else: |
| 51 | + annotations = super(klass, obj).annotations(obj) |
| 52 | + annotation = Annotation(self) |
| 53 | + return annotations + (annotation,) |
| 54 | + |
| 55 | + klass.annotations = annotations |
| 56 | + return klass |
| 57 | + |
| 58 | + def decorate_obj(obj): |
| 59 | + if hasattr(obj, 'annotations'): |
| 60 | + old_annotations = obj.annotations |
| 61 | + else: |
| 62 | + old_annotations = None |
| 63 | + |
| 64 | + def annotations(self = None, origin = None): |
| 65 | + if old_annotations: |
| 66 | + annotations = old_annotations(origin) |
| 67 | + else: |
| 68 | + annotations = super(obj.__class__, obj).annotations(obj) |
| 69 | + annotation = Annotation(self) |
| 70 | + return annotations + (annotation,) |
| 71 | + |
| 72 | + setattr(obj, 'annotations', MethodType(annotations, obj)) |
| 73 | + return obj |
| 74 | + |
| 75 | + if decorate_object: |
| 76 | + return decorate_obj |
| 77 | + else: |
| 78 | + return decorate_class |
| 79 | + |
0 commit comments