Skip to content

Commit 4e20c1f

Browse files
committed
fix(typing): add flexibility in type checking
1 parent aa46d41 commit 4e20c1f

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

examples/type_check.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from trame_dataclass.v2 import (
2+
ServerOnly,
3+
StateDataModel,
4+
)
5+
6+
7+
class MixFields(StateDataModel):
8+
integer = ServerOnly(int, type_checking="ignore")
9+
number = ServerOnly(float)
10+
11+
12+
fields = MixFields()
13+
fields.integer = 1
14+
fields.number = 1
15+
fields.integer = 1.2
16+
fields.number = 1.2

src/trame_dataclass/v2.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,15 @@ def get_instance(instance_id: str):
408408

409409

410410
class ServerOnly:
411-
def __init__(self, _type, default=None, convert=None, has_dataclass=False):
411+
def __init__(
412+
self,
413+
_type,
414+
default=None,
415+
convert=None,
416+
has_dataclass=False,
417+
type_checking="warning", # error, warning, ignore
418+
):
419+
self._type_checking = type_checking
412420
self._type = get_origin(_type) or _type
413421
if self._type in (Union, types.UnionType):
414422
self._type = get_args(_type)[0]
@@ -438,18 +446,30 @@ def __set_name__(self, owner, name):
438446
if not hasattr(owner, "FIELD_NAMES"):
439447
owner.FIELD_NAMES = set()
440448

449+
if not hasattr(owner, "TYPE_CHECKING"):
450+
owner.TYPE_CHECKING = {}
451+
441452
self._name = name
442453
owner.FIELD_NAMES.add(name)
454+
owner.TYPE_CHECKING[name] = self._type_checking
443455

444456
def __get__(self, instance, owner):
445457
if self._name not in instance._server_state:
446458
instance._server_state[self._name] = self._default
447459
return instance._server_state.get(self._name)
448460

449461
def __set__(self, instance, value):
450-
if value is not None and not isinstance(value, self._type):
451-
msg = f"{self._name} must be {self._type}"
452-
raise TypeError(msg)
462+
type_check = instance.TYPE_CHECKING[self._name]
463+
if (
464+
type_check in {"error", "warning"}
465+
and value is not None
466+
and not isinstance(value, self._type)
467+
):
468+
msg = f"{self._name} must be {self._type} instead of {type(value)}"
469+
if type_check == "error":
470+
raise TypeError(msg)
471+
472+
logger.warning(msg)
453473

454474
instance._dirty_set.add(self._name)
455475
instance._server_state[self._name] = value

0 commit comments

Comments
 (0)