Skip to content

Commit 792e244

Browse files
committed
feat(v2): improve encoding and type checking
1 parent c75a36b commit 792e244

File tree

3 files changed

+146
-47
lines changed

3 files changed

+146
-47
lines changed

src/trame_dataclass/v2.py

Lines changed: 60 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import weakref
66
from collections.abc import Awaitable, Sequence
77
from dataclasses import dataclass, field
8+
from enum import Enum, auto
89
from typing import Any, Callable, Union, get_args, get_origin
910

1011
from loguru import logger
@@ -117,6 +118,17 @@ def _save_field(name, src, dst, encoder=None):
117118
dst[name] = value
118119

119120

121+
def _setup_class_fields(owner):
122+
# set
123+
for key in ["FIELD_NAMES", "DATACLASS_NAMES", "CLIENT_NAMES", "CLIENT_ONLY_NAMES"]:
124+
if not hasattr(owner, key):
125+
setattr(owner, key, set())
126+
# dict
127+
for key in ["ENCODERS", "TYPE_CHECKING"]:
128+
if not hasattr(owner, key):
129+
setattr(owner, key, {})
130+
131+
120132
# -----------------------------------------------------------------------------
121133
# Dataclass builder
122134
# -----------------------------------------------------------------------------
@@ -126,6 +138,12 @@ def __init__(self, encoder, decoder):
126138
self.decoder = decoder
127139

128140

141+
class TypeValidation(Enum):
142+
STRICT = auto()
143+
WARNING = auto()
144+
SKIP = auto()
145+
146+
129147
class StateDataModel:
130148
def __init__(self, trame_server=None, **kwargs):
131149
self.__id = _next_id()
@@ -232,7 +250,6 @@ def clear_watchers(self):
232250

233251
def clone(self):
234252
other = self.__class__(trame_server=self.server)
235-
print(other)
236253
state = getattr(self, "_server_state", {})
237254
other.update(**state)
238255
return other
@@ -362,18 +379,38 @@ def decode_dataclass_item(item):
362379
def encode_dataclass_list(items):
363380
if items is None:
364381
return None
365-
print("encode list", items)
382+
# print("encode list", items)
366383
return [item._id for item in items]
367384

368385

369386
def decode_dataclass_list(items):
370387
# print("decode_dataclass_list", items)
371388
if items is None:
372389
return None
373-
print("decode list", items)
390+
# print("decode list", items)
374391
return list(map(get_instance, items))
375392

376393

394+
def decode_dataclass_set(items):
395+
# print("decode_dataclass_list", items)
396+
if items is None:
397+
return None
398+
# print("decode list", items)
399+
return set(map(get_instance, items))
400+
401+
402+
def encode_set(items):
403+
if items is None:
404+
return None
405+
return list(items)
406+
407+
408+
def decode_set(items):
409+
if items is None:
410+
return None
411+
return set(items)
412+
413+
377414
def encode_dataclass_dict(data):
378415
if data is None:
379416
return None
@@ -396,6 +433,7 @@ def decode_dataclass_dict(data):
396433
"ServerOnly",
397434
"StateDataModel",
398435
"Sync",
436+
"TypeValidation",
399437
"get_instance",
400438
"watch",
401439
]
@@ -412,9 +450,9 @@ def __init__(
412450
self,
413451
_type,
414452
default=None,
415-
convert=None,
416-
has_dataclass=False,
417-
type_checking="warning", # error, warning, ignore
453+
convert: FieldEncoder = None,
454+
has_dataclass: bool = False,
455+
type_checking: TypeValidation = TypeValidation.WARNING,
418456
):
419457
self._type_checking = type_checking
420458
self._type = get_origin(_type) or _type
@@ -433,6 +471,9 @@ def __init__(
433471
if self._type is list:
434472
encoder = encode_dataclass_list
435473
decoder = decode_dataclass_list
474+
elif self._type is set:
475+
encoder = encode_dataclass_list
476+
decoder = decode_dataclass_set
436477
elif self._type is dict:
437478
encoder = encode_dataclass_dict
438479
decoder = decode_dataclass_dict
@@ -442,13 +483,11 @@ def __init__(
442483

443484
self._convert = FieldEncoder(encoder, decoder)
444485

445-
def __set_name__(self, owner, name):
446-
if not hasattr(owner, "FIELD_NAMES"):
447-
owner.FIELD_NAMES = set()
448-
449-
if not hasattr(owner, "TYPE_CHECKING"):
450-
owner.TYPE_CHECKING = {}
486+
if not self._convert and self._type is set:
487+
self._convert = FieldEncoder(encode_set, decode_set)
451488

489+
def __set_name__(self, owner, name):
490+
_setup_class_fields(owner)
452491
self._name = name
453492
owner.TYPE_CHECKING[name] = self._type_checking
454493
owner.FIELD_NAMES.add(name)
@@ -461,37 +500,25 @@ def __get__(self, instance, owner):
461500
def __set__(self, instance, value):
462501
type_check = instance.TYPE_CHECKING[self._name]
463502
if (
464-
type_check in {"error", "warning"}
503+
type_check in {TypeValidation.STRICT, TypeValidation.WARNING}
465504
and value is not None
466505
and not isinstance(value, self._type)
467506
):
468-
msg = f"{self._name} must be {self._type} instead of {type(value)}"
469-
if type_check == "error":
507+
msg = f"{self._name} must be {self._type} instead of {type(value)} for class {instance.__class__}"
508+
if type_check == TypeValidation.STRICT:
470509
raise TypeError(msg)
471510

472511
logger.warning(msg)
473512

474-
instance._dirty_set.add(self._name)
475-
instance._server_state[self._name] = value
476-
instance._on_dirty()
513+
if instance._server_state.get(self._name) != value:
514+
instance._dirty_set.add(self._name)
515+
instance._server_state[self._name] = value
516+
instance._on_dirty()
477517

478518

479519
class Sync(ServerOnly):
480520
def __set_name__(self, owner, name):
481-
if not hasattr(owner, "FIELD_NAMES"):
482-
owner.FIELD_NAMES = set()
483-
484-
if not hasattr(owner, "TYPE_CHECKING"):
485-
owner.TYPE_CHECKING = {}
486-
487-
if not hasattr(owner, "CLIENT_NAMES"):
488-
owner.CLIENT_NAMES = set()
489-
490-
if not hasattr(owner, "ENCODERS"):
491-
owner.ENCODERS = {}
492-
493-
if not hasattr(owner, "DATACLASS_NAMES"):
494-
owner.DATACLASS_NAMES = set()
521+
_setup_class_fields(owner)
495522

496523
if self._has_dataclass:
497524
owner.DATACLASS_NAMES.add(name)
@@ -507,18 +534,7 @@ def __set_name__(self, owner, name):
507534

508535
class ClientOnly(ServerOnly):
509536
def __set_name__(self, owner, name):
510-
if not hasattr(owner, "FIELD_NAMES"):
511-
owner.FIELD_NAMES = set()
512-
513-
if not hasattr(owner, "TYPE_CHECKING"):
514-
owner.TYPE_CHECKING = {}
515-
516-
if not hasattr(owner, "CLIENT_NAMES"):
517-
owner.CLIENT_NAMES = set()
518-
519-
if not hasattr(owner, "CLIENT_ONLY_NAMES"):
520-
owner.CLIENT_ONLY_NAMES = set()
521-
537+
_setup_class_fields(owner)
522538
self._name = name
523539
owner.TYPE_CHECKING[name] = self._type_checking
524540
owner.FIELD_NAMES.add(name)

tests/test_dataclass.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
from contextlib import suppress
32
from pathlib import Path
43

54
import pytest
@@ -41,12 +40,12 @@ class Composite(StateDataModel):
4140

4241

4342
def test_complex_type():
44-
with suppress(NonSerializableType):
43+
with pytest.raises(NonSerializableType):
4544

4645
class ErrorData(StateDataModel):
4746
path: Path
4847

49-
pytest.fail("Should trigger a NonSerializableType exception")
48+
# pytest.fail("Should trigger a NonSerializableType exception")
5049

5150

5251
def test_watch():

tests/test_v2.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import pytest
2+
3+
from trame_dataclass.v2 import (
4+
ServerOnly,
5+
StateDataModel,
6+
Sync,
7+
TypeValidation,
8+
)
9+
10+
11+
def test_input_validation():
12+
class BasicTypeValidation(StateDataModel):
13+
a = ServerOnly(int, 0, type_checking=TypeValidation.STRICT)
14+
b = ServerOnly(float, 0.0, type_checking=TypeValidation.STRICT)
15+
c = ServerOnly(str, "", type_checking=TypeValidation.STRICT)
16+
d = ServerOnly(dict, dict, type_checking=TypeValidation.STRICT)
17+
e = ServerOnly(list[int], list, type_checking=TypeValidation.STRICT)
18+
f = ServerOnly(set, set, type_checking=TypeValidation.STRICT)
19+
g = ServerOnly(str, "", type_checking=TypeValidation.SKIP)
20+
h = ServerOnly(str, "", type_checking=TypeValidation.WARNING)
21+
22+
data = BasicTypeValidation()
23+
data.a = 1
24+
data.b = 1.1
25+
data.c = "hello"
26+
data.d = {"a": 1}
27+
data.e = [1, 2, 3]
28+
data.f = {"a", "b", "c"}
29+
data.g = 5
30+
data.h = 5
31+
32+
with pytest.raises(TypeError):
33+
data.a = 1.2
34+
35+
with pytest.raises(TypeError):
36+
data.b = 5
37+
38+
with pytest.raises(TypeError):
39+
data.c = 0
40+
41+
with pytest.raises(TypeError):
42+
data.d = []
43+
44+
with pytest.raises(TypeError):
45+
data.e = {}
46+
47+
with pytest.raises(TypeError):
48+
data.f = []
49+
50+
51+
def test_serialization():
52+
class Dummy(StateDataModel):
53+
a = Sync(int, 0)
54+
55+
class BasicSerial(StateDataModel):
56+
a = Sync(int, 123)
57+
b = Sync(Dummy, has_dataclass=True)
58+
c = Sync(list[Dummy], list, has_dataclass=True)
59+
d = Sync(dict[str, Dummy], dict, has_dataclass=True)
60+
e = Sync(set, set)
61+
f = Sync(set, set, has_dataclass=True)
62+
63+
data = BasicSerial()
64+
data.b = Dummy()
65+
data.c = [Dummy(), Dummy()]
66+
data.d = {"a": Dummy(), "b": Dummy()}
67+
data.e = {"a", "b", "c"}
68+
data.f = {Dummy(), Dummy(), Dummy()}
69+
70+
state = data.client_state
71+
72+
assert state.get("a") == 123
73+
assert state.get("b") == data.b._id
74+
assert state.get("c") == [i._id for i in data.c]
75+
assert state.get("d") == {k: v._id for k, v in data.d.items()}
76+
77+
set_as_list = state.get("e")
78+
assert isinstance(set_as_list, list)
79+
assert set(set_as_list) == {"a", "b", "c"}
80+
81+
set_as_list = state.get("f")
82+
assert isinstance(set_as_list, list)
83+
for obj in data.f:
84+
assert obj._id in set_as_list

0 commit comments

Comments
 (0)