Description
Hi,
I have a use case where I need a custom field, defined in the base schema, for a generic type. For example:
```python
import typing
from dataclasses import dataclass

import marshmallow
import marshmallow_dataclass

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v

    def value(self) -> _T:
        return self._value


class CustomTypeField(marshmallow.fields.Field):
    def _serialize(self, value: CustomType, attr, obj, **kwargs):
        return {"value": value.value()}

    def _deserialize(self, value, attr, data, **kwargs):
        return CustomType(value["value"])
```

In this example, I want any instance of `CustomType` to use the field `CustomTypeField`. The natural approach would be to set it in the `TYPE_MAPPING` of a base schema:
```python
class BaseSchema(marshmallow.Schema):
    TYPE_MAPPING = {CustomType: CustomTypeField}


@dataclass
class Foo:
    x: CustomType
    y: CustomType[int]
    z: int


schema = marshmallow_dataclass.class_schema(Foo, base_schema=BaseSchema)()
obj = Foo(x=CustomType("aa"), y=CustomType(3), z=4)
schema.dump(obj)
```

With the current marshmallow_dataclass version, this does not work: field `y` has type `CustomType[int]`, which is not a key of `BaseSchema.TYPE_MAPPING` (only the bare `CustomType` is). The following error is produced:
```
Traceback (most recent call last):
  File "scratch_2.py", line 38, in main
    schema = marshmallow_dataclass.class_schema(Foo, base_schema=BaseSchema)()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "marshmallow_dataclass/__init__.py", line 462, in class_schema
    return _internal_class_schema(clazz, base_schema)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "marshmallow_dataclass/__init__.py", line 552, in _internal_class_schema
    attributes.update(
  File "marshmallow_dataclass/__init__.py", line 555, in <genexpr>
    _field_for_schema(
  File "marshmallow_dataclass/__init__.py", line 890, in _field_for_schema
    if issubclass(typ, Enum):
       ^^^^^^^^^^^^^^^^^^^^^
TypeError: issubclass() arg 1 must be a class
```
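The root cause can be reproduced with the standard library alone. The sketch below assumes Python 3.8+ (for `typing.get_origin`), and the string stored in the mapping is only a placeholder for `CustomTypeField`. It shows that `CustomType[int]` is a different object from `CustomType`, so a dict keyed by the bare class misses it even though its origin is that class, and that the subscripted alias is not a class, which is why `issubclass()` raises.

```python
import typing
from enum import Enum

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v


# Keyed by the bare generic class, like BaseSchema.TYPE_MAPPING
# (the string value is a placeholder for CustomTypeField).
type_mapping = {CustomType: "CustomTypeField"}
subscripted = CustomType[int]

# The subscripted alias is a distinct object, so the direct lookup misses it...
found_directly = subscripted in type_mapping
# ...but its origin is the bare class, which is in the mapping.
origin_is_bare_class = typing.get_origin(subscripted) is CustomType

# The alias is not a class, so issubclass() raises TypeError,
# matching the last frame of the traceback.
try:
    issubclass(subscripted, Enum)
    raised_type_error = False
except TypeError:
    raised_type_error = True

print(found_directly, origin_is_bare_class, raised_type_error)  # False True True
```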
Proposed changes
My current workaround is to add a new dictionary called `GENERIC_TYPE_MAPPING` to the base schema. It contains field overrides keyed by the unsubscripted generic type, so any generic argument is discarded during the lookup:

```python
class BaseSchema(marshmallow.Schema):
    GENERIC_TYPE_MAPPING = {CustomType: CustomTypeField}
```

The lookup is implemented as follows:
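```python
from typing import Any, Optional, Type, Union

import marshmallow
import typing_inspect


def _field_by_generic_type(
    typ: Union[type, Any], base_schema: Optional[Type[marshmallow.Schema]]
) -> Optional[Type[marshmallow.fields.Field]]:
    # For a subscripted generic such as CustomType[int], look up its origin
    # (CustomType); otherwise look up the type itself.
    origin = typing_inspect.get_origin(typ)
    type_mapping = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
    if origin is not None:
        return type_mapping.get(origin)
    else:
        return type_mapping.get(typ)
```

And the `_field_for_schema` function is modified to check this: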
```python
# Inside _field_for_schema, before the issubclass(typ, Enum) check:
field = _field_by_generic_type(typ, base_schema)
if field:
    return field(**metadata)
```

My questions are:
- Do you think the use case described is relevant for marshmallow_dataclass?
- Does my implementation seem fine? I am not that familiar with marshmallow, so there is probably a better way to do it.
- Do you want me to open a PR to integrate this feature?
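To make the second question concrete, here is a dependency-free sketch of the same lookup. It is an illustration under assumptions, not marshmallow_dataclass code: `typing.get_origin` (Python 3.8+) stands in for `typing_inspect.get_origin` and is assumed to behave the same for these cases, and `CustomTypeField` is a hypothetical plain class standing in for the marshmallow field so the snippet runs without marshmallow installed.

```python
import typing
from typing import Any, Optional

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v


class CustomTypeField:
    """Hypothetical stand-in for the marshmallow field class."""

    def __init__(self, **metadata: Any):
        self.metadata = metadata


class BaseSchema:
    # Keyed by the bare generic class; generic arguments are ignored.
    GENERIC_TYPE_MAPPING = {CustomType: CustomTypeField}


def field_by_generic_type(typ: Any, base_schema: Optional[type]) -> Optional[type]:
    # get_origin() returns None for unsubscripted types (CustomType, int),
    # in which case the type itself is used as the key.
    mapping = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
    origin = typing.get_origin(typ)
    return mapping.get(origin if origin is not None else typ)


for typ in (CustomType, CustomType[int], CustomType[str], int):
    print(typ, "->", field_by_generic_type(typ, BaseSchema))
```

Both the bare and the subscripted forms resolve to the same field, while unrelated types such as `int` return `None`, so the existing handling in `_field_for_schema` still applies to them.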
Thanks in advance,
Thomas.