Skip to content

Commit bd08bc7

Browse files
author
Paolo Tranquilli
committed
Rust: address review
1 parent 248eb7f commit bd08bc7

File tree

4 files changed

+41
-10
lines changed

4 files changed

+41
-10
lines changed

misc/codegen/lib/schemadefs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import (
22
Callable as _Callable,
33
Dict as _Dict,
4-
List as _List,
4+
Iterable as _Iterable,
55
ClassVar as _ClassVar,
66
)
77
from misc.codegen.lib import schema as _schema
@@ -279,7 +279,7 @@ def __or__(self, other: _schema.PropertyModifier):
279279
drop = object()
280280

281281

282-
def annotate(annotated_cls: type, add_bases: _List[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]:
282+
def annotate(annotated_cls: type, add_bases: _Iterable[type] | None = None, replace_bases: _Dict[type, type] | None = None) -> _Callable[[type], _PropertyAnnotation]:
283283
"""
284284
Add or modify schema annotations after a class has been defined previously.
285285
@@ -297,7 +297,7 @@ def decorator(cls: type) -> _PropertyAnnotation:
297297
if replace_bases:
298298
annotated_cls.__bases__ = tuple(replace_bases.get(b, b) for b in annotated_cls.__bases__)
299299
if add_bases:
300-
annotated_cls.__bases__ = tuple(annotated_cls.__bases__) + tuple(add_bases)
300+
annotated_cls.__bases__ += tuple(add_bases)
301301
for a in dir(cls):
302302
if a.startswith(_schema.inheritable_pragma_prefix):
303303
setattr(annotated_cls, a, getattr(cls, a))

misc/codegen/test/test_schemaloader.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -914,6 +914,36 @@ class _:
914914
}
915915

916916

917+
def test_annotate_add_bases():
918+
@load
919+
class data:
920+
class Root:
921+
pass
922+
923+
class A(Root):
924+
pass
925+
926+
class B(Root):
927+
pass
928+
929+
class C(Root):
930+
pass
931+
932+
class Derived(A):
933+
pass
934+
935+
@defs.annotate(Derived, add_bases=(B, C))
936+
class _:
937+
pass
938+
assert data.classes == {
939+
"Root": schema.Class("Root", derived={"A", "B", "C"}),
940+
"A": schema.Class("A", bases=["Root"], derived={"Derived"}),
941+
"B": schema.Class("B", bases=["Root"], derived={"Derived"}),
942+
"C": schema.Class("C", bases=["Root"], derived={"Derived"}),
943+
"Derived": schema.Class("Derived", bases=["A", "B", "C"]),
944+
}
945+
946+
917947
def test_annotate_drop_field():
918948
@load
919949
class data:

rust/schema/annotations.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,13 +1741,6 @@ class _:
17411741
```
17421742
"""
17431743

1744-
class Callable(AstNode):
1745-
"""
1746-
A callable. Either a `Function` or a `ClosureExpr`.
1747-
"""
1748-
param_list: optional["ParamList"] | child
1749-
attrs: list["Attr"] | child
1750-
17511744
@annotate(Function, add_bases=[Callable])
17521745
class _:
17531746
param_list: drop

rust/schema/prelude.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,11 @@ class Unimplemented(Unextracted):
6363
The base class for unimplemented nodes. This is used to mark nodes that are not yet extracted.
6464
"""
6565
pass
66+
67+
68+
class Callable(AstNode):
69+
"""
70+
A callable. Either a `Function` or a `ClosureExpr`.
71+
"""
72+
param_list: optional["ParamList"] | child
73+
attrs: list["Attr"] | child

0 commit comments

Comments
 (0)