55import weakref
66from collections .abc import Awaitable , Sequence
77from dataclasses import dataclass , field
8+ from enum import Enum , auto
89from typing import Any , Callable , Union , get_args , get_origin
910
1011from loguru import logger
@@ -117,6 +118,17 @@ def _save_field(name, src, dst, encoder=None):
117118 dst [name ] = value
118119
119120
121+ def _setup_class_fields (owner ):
122+ # set
123+ for key in ["FIELD_NAMES" , "DATACLASS_NAMES" , "CLIENT_NAMES" , "CLIENT_ONLY_NAMES" ]:
124+ if not hasattr (owner , key ):
125+ setattr (owner , key , set ())
126+ # dict
127+ for key in ["ENCODERS" , "TYPE_CHECKING" ]:
128+ if not hasattr (owner , key ):
129+ setattr (owner , key , {})
130+
131+
120132# -----------------------------------------------------------------------------
121133# Dataclass builder
122134# -----------------------------------------------------------------------------
@@ -126,6 +138,12 @@ def __init__(self, encoder, decoder):
126138 self .decoder = decoder
127139
128140
141+ class TypeValidation (Enum ):
142+ STRICT = auto ()
143+ WARNING = auto ()
144+ SKIP = auto ()
145+
146+
129147class StateDataModel :
130148 def __init__ (self , trame_server = None , ** kwargs ):
131149 self .__id = _next_id ()
@@ -232,7 +250,6 @@ def clear_watchers(self):
232250
233251 def clone (self ):
234252 other = self .__class__ (trame_server = self .server )
235- print (other )
236253 state = getattr (self , "_server_state" , {})
237254 other .update (** state )
238255 return other
@@ -362,18 +379,38 @@ def decode_dataclass_item(item):
362379def encode_dataclass_list (items ):
363380 if items is None :
364381 return None
365- print ("encode list" , items )
382+ # print("encode list", items)
366383 return [item ._id for item in items ]
367384
368385
369386def decode_dataclass_list (items ):
370387 # print("decode_dataclass_list", items)
371388 if items is None :
372389 return None
373- print ("decode list" , items )
390+ # print("decode list", items)
374391 return list (map (get_instance , items ))
375392
376393
394+ def decode_dataclass_set (items ):
395+ # print("decode_dataclass_list", items)
396+ if items is None :
397+ return None
398+ # print("decode list", items)
399+ return set (map (get_instance , items ))
400+
401+
402+ def encode_set (items ):
403+ if items is None :
404+ return None
405+ return list (items )
406+
407+
408+ def decode_set (items ):
409+ if items is None :
410+ return None
411+ return set (items )
412+
413+
377414def encode_dataclass_dict (data ):
378415 if data is None :
379416 return None
@@ -396,6 +433,7 @@ def decode_dataclass_dict(data):
396433 "ServerOnly" ,
397434 "StateDataModel" ,
398435 "Sync" ,
436+ "TypeValidation" ,
399437 "get_instance" ,
400438 "watch" ,
401439]
@@ -412,9 +450,9 @@ def __init__(
412450 self ,
413451 _type ,
414452 default = None ,
415- convert = None ,
416- has_dataclass = False ,
417- type_checking = "warning" , # error, warning, ignore
453+ convert : FieldEncoder = None ,
454+ has_dataclass : bool = False ,
455+ type_checking : TypeValidation = TypeValidation . WARNING ,
418456 ):
419457 self ._type_checking = type_checking
420458 self ._type = get_origin (_type ) or _type
@@ -433,6 +471,9 @@ def __init__(
433471 if self ._type is list :
434472 encoder = encode_dataclass_list
435473 decoder = decode_dataclass_list
474+ elif self ._type is set :
475+ encoder = encode_dataclass_list
476+ decoder = decode_dataclass_set
436477 elif self ._type is dict :
437478 encoder = encode_dataclass_dict
438479 decoder = decode_dataclass_dict
@@ -442,13 +483,11 @@ def __init__(
442483
443484 self ._convert = FieldEncoder (encoder , decoder )
444485
445- def __set_name__ (self , owner , name ):
446- if not hasattr (owner , "FIELD_NAMES" ):
447- owner .FIELD_NAMES = set ()
448-
449- if not hasattr (owner , "TYPE_CHECKING" ):
450- owner .TYPE_CHECKING = {}
486+ if not self ._convert and self ._type is set :
487+ self ._convert = FieldEncoder (encode_set , decode_set )
451488
489+ def __set_name__ (self , owner , name ):
490+ _setup_class_fields (owner )
452491 self ._name = name
453492 owner .TYPE_CHECKING [name ] = self ._type_checking
454493 owner .FIELD_NAMES .add (name )
@@ -461,37 +500,25 @@ def __get__(self, instance, owner):
461500 def __set__ (self , instance , value ):
462501 type_check = instance .TYPE_CHECKING [self ._name ]
463502 if (
464- type_check in {"error" , "warning" }
503+ type_check in {TypeValidation . STRICT , TypeValidation . WARNING }
465504 and value is not None
466505 and not isinstance (value , self ._type )
467506 ):
468- msg = f"{ self ._name } must be { self ._type } instead of { type (value )} "
469- if type_check == "error" :
507+ msg = f"{ self ._name } must be { self ._type } instead of { type (value )} for class { instance . __class__ } "
508+ if type_check == TypeValidation . STRICT :
470509 raise TypeError (msg )
471510
472511 logger .warning (msg )
473512
474- instance ._dirty_set .add (self ._name )
475- instance ._server_state [self ._name ] = value
476- instance ._on_dirty ()
513+ if instance ._server_state .get (self ._name ) != value :
514+ instance ._dirty_set .add (self ._name )
515+ instance ._server_state [self ._name ] = value
516+ instance ._on_dirty ()
477517
478518
479519class Sync (ServerOnly ):
480520 def __set_name__ (self , owner , name ):
481- if not hasattr (owner , "FIELD_NAMES" ):
482- owner .FIELD_NAMES = set ()
483-
484- if not hasattr (owner , "TYPE_CHECKING" ):
485- owner .TYPE_CHECKING = {}
486-
487- if not hasattr (owner , "CLIENT_NAMES" ):
488- owner .CLIENT_NAMES = set ()
489-
490- if not hasattr (owner , "ENCODERS" ):
491- owner .ENCODERS = {}
492-
493- if not hasattr (owner , "DATACLASS_NAMES" ):
494- owner .DATACLASS_NAMES = set ()
521+ _setup_class_fields (owner )
495522
496523 if self ._has_dataclass :
497524 owner .DATACLASS_NAMES .add (name )
@@ -507,18 +534,7 @@ def __set_name__(self, owner, name):
507534
508535class ClientOnly (ServerOnly ):
509536 def __set_name__ (self , owner , name ):
510- if not hasattr (owner , "FIELD_NAMES" ):
511- owner .FIELD_NAMES = set ()
512-
513- if not hasattr (owner , "TYPE_CHECKING" ):
514- owner .TYPE_CHECKING = {}
515-
516- if not hasattr (owner , "CLIENT_NAMES" ):
517- owner .CLIENT_NAMES = set ()
518-
519- if not hasattr (owner , "CLIENT_ONLY_NAMES" ):
520- owner .CLIENT_ONLY_NAMES = set ()
521-
537+ _setup_class_fields (owner )
522538 self ._name = name
523539 owner .TYPE_CHECKING [name ] = self ._type_checking
524540 owner .FIELD_NAMES .add (name )
0 commit comments