Skip to content

Commit d231ed0

Browse files
Thibault-Pelletierjourdain
authored andcommitted
fix(typed_state): fix decoder with annotations
- Fix TypedState decoder when used with dataclasses in files containing from __future__ import annotations. Import triggers a lazy evaluation for the dataclass fields method and field.type contains strings instead of actual types. - Change default encode / decode behavior to raise a TypeError when Serialization fails instead of silently returning unchanged object.
1 parent 69e4548 commit d231ed0

File tree

3 files changed

+70
-9
lines changed

3 files changed

+70
-9
lines changed

tests/test_typed_state.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,21 @@ def test_handles_list_of_union(state):
419419
typed_state.data.my_optional_enum_list = [b_str, b_enum, None]
420420
assert typed_state.data.my_optional_enum_list == [b_enum, b_enum, None]
421421
assert state[typed_state.name.my_optional_enum_list] == [b_str, b_str, None]
422+
423+
424+
def test_failure_to_encode_raises_type_error(state):
425+
class RaiseEncode(DefaultEncoderDecoder):
426+
def encode(self, _obj):
427+
raise AssertionError()
428+
429+
def decode(self, _obj, _obj_type: type):
430+
raise AssertionError()
431+
432+
typed_state = TypedState(state, DataWithTypes, encoders=[RaiseEncode()])
433+
state.setdefault(typed_state.name.my_enum, MyEnum.A.value)
434+
435+
with pytest.raises(TypeError):
436+
typed_state.data.my_enum = MyEnum.A
437+
438+
with pytest.raises(TypeError):
439+
print(typed_state.data.my_enum)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations # triggers lazy evaluation of field.type
2+
3+
from dataclasses import dataclass
4+
from enum import Enum, auto
5+
6+
import pytest
7+
8+
from trame_server import Server
9+
from trame_server.utils.typed_state import TypedState
10+
11+
12+
class AnEnum(Enum):
13+
A = auto()
14+
B = auto()
15+
16+
17+
@dataclass
18+
class DataWithTypesAnnotations:
19+
my_enum: AnEnum
20+
21+
22+
@pytest.fixture
23+
def state():
24+
server = Server()
25+
server.state.ready()
26+
return server.state
27+
28+
29+
def test_is_compatible_with_from_future_annotations(state):
30+
typed_state = TypedState(state, DataWithTypesAnnotations)
31+
typed_state.data.my_enum = AnEnum.B
32+
assert typed_state.data.my_enum == AnEnum.B
33+
assert state[typed_state.name.my_enum] == AnEnum.B.value

trame_server/utils/typed_state.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
cast,
1818
get_args,
1919
get_origin,
20+
get_type_hints,
2021
)
2122
from uuid import UUID
2223

@@ -124,7 +125,9 @@ def encode(self, obj):
124125
val = self._try_serialize(encoder.encode, obj)
125126
if self.is_serialization_success(val):
126127
return val
127-
return obj
128+
129+
_error_msg = f"Failed to encode object {obj}. No appropriate encoder in {self._encoders}."
130+
raise TypeError(_error_msg)
128131

129132
@classmethod
130133
def _is_iterable(cls, obj):
@@ -140,7 +143,9 @@ def decode(self, obj, obj_type: type):
140143
val = self._try_decode(obj, obj_type)
141144
if self.is_serialization_success(val):
142145
return val
143-
return obj
146+
147+
_error_msg = f"Failed to decode object {obj} of type {obj_type}. No appropriate decoder in {self._encoders}."
148+
raise TypeError(_error_msg)
144149

145150
def _try_decode(self, obj, obj_type: type):
146151
for decode in self._decode_strategies():
@@ -361,14 +366,14 @@ def _create_state_proxy(
361366
"""
362367
encoder = encoder or CollectionEncoderDecoder(None)
363368

364-
def handler(state_id: str, field: Field):
369+
def handler(state_id: str, field: Field, field_type: type):
365370
return _ProxyField(
366371
state=state,
367372
state_id=state_id,
368373
name=field.name,
369374
default=field.default,
370375
default_factory=field.default_factory,
371-
field_type=field.type,
376+
field_type=field_type,
372377
state_encoder=encoder,
373378
)
374379

@@ -385,7 +390,7 @@ def _create_state_names_proxy(cls, dataclass_type: Type[T], *, namespace="") ->
385390
namespace prefix.
386391
"""
387392

388-
def handler(state_id: str, _field: Field):
393+
def handler(state_id: str, _field: Field, _field_type: type):
389394
return _NameField(state_id=state_id)
390395

391396
return cls._build_proxy_cls(dataclass_type, namespace, handler, "__ProxyName")
@@ -395,7 +400,7 @@ def _build_proxy_cls(
395400
cls,
396401
dataclass_type: Type[T],
397402
prefix: str,
398-
handler: Callable[[str, Field], Any],
403+
handler: Callable[[str, Field, type], Any],
399404
cls_suffix: str,
400405
proxy_field_dict: dict | None = None,
401406
) -> T:
@@ -415,14 +420,19 @@ def _build_proxy_cls(
415420
class_name = dataclass_type.__name__
416421
inner_field_dict = {}
417422
prefix = f"{prefix}__{class_name}" if prefix else class_name
423+
424+
# Use type hints instead of field.type to avoid lazy evaluation of field.type when used in files containing
425+
# from __future__ import annotations header.
426+
field_types = get_type_hints(dataclass_type)
418427
for f in fields(dataclass_type):
419428
state_id = f"{prefix}__{f.name}"
420-
if is_dataclass(f.type):
429+
f_type = field_types[f.name]
430+
if is_dataclass(f_type):
421431
field = cls._build_proxy_cls(
422-
f.type, state_id, handler, cls_suffix, inner_field_dict
432+
f_type, state_id, handler, cls_suffix, inner_field_dict
423433
)
424434
else:
425-
field = handler(state_id, f)
435+
field = handler(state_id, f, f_type)
426436

427437
inner_field_dict[cls.get_state_id(field, state_id)] = field
428438
namespace[f.name] = field

0 commit comments

Comments
 (0)