
Support for type-generic custom fields #269

@thomashk0

Description


Hi,

I have a use case where a generic type needs a custom field, and I would like to declare that field once in the base schema. For example:

import typing
from dataclasses import dataclass

import marshmallow
import marshmallow_dataclass

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v

    def value(self) -> _T:
        return self._value


class CustomTypeField(marshmallow.fields.Field):
    def _serialize(self, value: CustomType, attr, obj, **kwargs):
        return {"value": value.value()}

    def _deserialize(self, value, attr, data, **kwargs):
        return CustomType(value["value"])

In this example, I want every field annotated with CustomType, whether bare or parameterized (e.g. CustomType[int]), to use CustomTypeField. The natural approach would be to set it in the TYPE_MAPPING of a base schema:

class BaseSchema(marshmallow.Schema):
    TYPE_MAPPING = {CustomType: CustomTypeField}

@dataclass
class Foo:
    x: CustomType
    y: CustomType[int]
    z: int

schema = marshmallow_dataclass.class_schema(Foo, base_schema=BaseSchema)()
obj = Foo(x=CustomType("aa"), y=CustomType(3), z=4)
schema.dump(obj)

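With the current marshmallow_dataclass version, this does not work. Field y is annotated as CustomType[int], a parameterized alias that is not the same object as CustomType, so the exact-key lookup in BaseSchema.TYPE_MAPPING misses. The type then reaches a check that expects a class, and the following error is produced: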

Traceback (most recent call last):
  File "scratch_2.py", line 38, in main
    schema = marshmallow_dataclass.class_schema(Foo, base_schema=BaseSchema)()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "marshmallow_dataclass/__init__.py", line 462, in class_schema
    return _internal_class_schema(clazz, base_schema)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "marshmallow_dataclass/__init__.py", line 552, in _internal_class_schema
    attributes.update(
  File "marshmallow_dataclass/__init__.py", line 555, in <genexpr>
    _field_for_schema(
  File "marshmallow_dataclass/__init__.py", line 890, in _field_for_schema
    if issubclass(typ, Enum):
       ^^^^^^^^^^^^^^^^^^^^^
TypeError: issubclass() arg 1 must be a class
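The failure can be reproduced without marshmallow at all. The sketch below uses only the standard library; the plain dict `type_mapping` is a hypothetical stand-in for `BaseSchema.TYPE_MAPPING`, with a placeholder string instead of the real field class. It shows the two effects behind the traceback: the parameterized alias misses an exact-key lookup, and it is rejected by `issubclass()` because it is not a class.

```python
import typing
from enum import Enum

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    def __init__(self, v: _T):
        self._value = v


# Hypothetical stand-in for BaseSchema.TYPE_MAPPING; the value is a placeholder
# string, not the real marshmallow field class.
type_mapping = {CustomType: "CustomTypeField"}

alias = CustomType[int]

# Exact-key lookup: the parameterized alias is a different object from CustomType.
print(type_mapping.get(CustomType))  # CustomTypeField
print(type_mapping.get(alias))       # None

# The unparameterized class is still reachable through the alias's origin.
print(typing.get_origin(alias) is CustomType)  # True

# The alias is not a class, so issubclass() rejects it, as in the traceback.
try:
    issubclass(alias, Enum)
    raised = False
except TypeError as exc:
    raised = True
    print(exc)
```

The `typing.get_origin` line hints at the fix: look up the origin of the annotation rather than the annotation itself.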

Proposed changes

My current workaround is to add a new dictionary to the base schema, called GENERIC_TYPE_MAPPING, that maps an unparameterized generic type to a field class; any generic arguments are discarded during the lookup.

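class BaseSchema(marshmallow.Schema):
    # Keyed by the unparameterized type: CustomType, CustomType[int], ... all map here.
    GENERIC_TYPE_MAPPING = {CustomType: CustomTypeField}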

The lookup is implemented as follows:

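from typing import Any, Optional, Type, Union

import marshmallow
import typing_inspect


def _field_by_generic_type(
    typ: Union[type, Any], base_schema: Optional[Type[marshmallow.Schema]]
) -> Optional[Type[marshmallow.fields.Field]]:
    # For CustomType[int] the origin is CustomType; for a plain class it is None.
    origin = typing_inspect.get_origin(typ)
    # getattr with a default also tolerates base_schema being None.
    type_mapping = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
    if origin is not None:
        return type_mapping.get(origin)
    else:
        return type_mapping.get(typ)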
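For experimenting with the lookup outside marshmallow_dataclass, here is a dependency-free restatement. It swaps `typing_inspect.get_origin` for the standard library's `typing.get_origin` (Python 3.8+), and the names `field_by_generic_type`, `CustomTypeField` and `BaseSchema` here are hypothetical stand-ins, not the library's code.

```python
import typing
from typing import Any, Optional

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    pass


class CustomTypeField:
    """Hypothetical stand-in for the marshmallow field class."""


class BaseSchema:
    """Stand-in schema: only GENERIC_TYPE_MAPPING matters for this lookup."""

    GENERIC_TYPE_MAPPING = {CustomType: CustomTypeField}


def field_by_generic_type(typ: Any, base_schema: Optional[type]) -> Optional[type]:
    # get_origin(CustomType[int]) is CustomType; for a bare class it returns None.
    origin = typing.get_origin(typ)
    type_mapping = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
    return type_mapping.get(origin if origin is not None else typ)


print(field_by_generic_type(CustomType, BaseSchema) is CustomTypeField)       # True
print(field_by_generic_type(CustomType[int], BaseSchema) is CustomTypeField)  # True
print(field_by_generic_type(int, BaseSchema))                                 # None
print(field_by_generic_type(CustomType[int], None))                           # None
print(field_by_generic_type(Optional[CustomType[int]], BaseSchema))           # None
```

The last line is worth noting: the origin of `Optional[CustomType[int]]` is `typing.Union`, so this lookup alone does not match it; whether such a field still resolves depends on how the surrounding code unwraps `Optional` before calling the helper.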

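And the _field_for_schema function is modified to check this mapping before any code that assumes typ is a class (such as the issubclass(typ, Enum) check in the traceback above):

# Inside _field_for_schema, ahead of the class-only checks
field = _field_by_generic_type(typ, base_schema)
if field:
    return field(**metadata)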
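Why the placement matters can be seen with a toy dispatcher. This is a hypothetical, heavily simplified model of the lookup order, not marshmallow_dataclass code: field classes are replaced by placeholder strings, and the `generic_first` flag exists only to compare running the generic lookup before or not at all ahead of the class-only `issubclass` check.

```python
import typing
from enum import Enum
from typing import Any, Optional

_T = typing.TypeVar("_T")


class CustomType(typing.Generic[_T]):
    pass


class BaseSchema:
    # Placeholder strings stand in for field classes to keep this dependency-free.
    TYPE_MAPPING = {int: "Integer"}
    GENERIC_TYPE_MAPPING = {CustomType: "CustomTypeField"}


def toy_field_for_schema(typ: Any, base_schema: Optional[type],
                         generic_first: bool = True) -> str:
    """Hypothetical, simplified dispatcher illustrating check order only."""
    if generic_first:
        origin = typing.get_origin(typ)
        generic = getattr(base_schema, "GENERIC_TYPE_MAPPING", {})
        field = generic.get(origin if origin is not None else typ)
        if field is not None:
            return field
    field = getattr(base_schema, "TYPE_MAPPING", {}).get(typ)
    if field is not None:
        return field
    # Mirrors the check from the traceback: it assumes typ is a class.
    if issubclass(typ, Enum):
        return "Enum"
    raise TypeError(f"no field for {typ!r}")


print(toy_field_for_schema(CustomType[int], BaseSchema))  # CustomTypeField
print(toy_field_for_schema(int, BaseSchema))              # Integer
try:
    toy_field_for_schema(CustomType[int], BaseSchema, generic_first=False)
except TypeError as exc:
    print("without the early check:", exc)
```

Without the early generic lookup, the parameterized alias falls through to `issubclass` and fails exactly as in the reported traceback.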

My questions are:

  • Do you think this use case is relevant for marshmallow_dataclass?
  • Does my implementation seem fine? I am not that familiar with marshmallow, so there is probably a better way to do it.
  • Do you want me to open a PR to integrate this feature?

Thanks in advance,
Thomas.
