Skip to content

Commit e2294f5

Browse files
authored
Fix type compatibility with marshmallow v3 and v4 (#659)
1 parent 2631978 commit e2294f5

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

src/marshmallow_sqlalchemy/fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,4 @@ def get_primary_keys(model: type[DeclarativeMeta]) -> list[MapperProperty]:
160160

161161

162162
def ensure_list(value: Any) -> list:
163-
return value if is_iterable_but_not_string(value) else [value]
163+
return list(value) if is_iterable_but_not_string(value) else [value]

src/marshmallow_sqlalchemy/load_instance_mixin.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from __future__ import annotations
1010

11-
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
11+
import importlib.metadata
12+
from collections.abc import Iterable, Mapping, Sequence
13+
from typing import Any, Generic, TypeVar, Union, cast
1214

1315
import marshmallow as ma
1416
from sqlalchemy.ext.declarative import DeclarativeMeta
@@ -17,12 +19,18 @@
1719

1820
from .fields import get_primary_keys
1921

20-
if TYPE_CHECKING:
21-
from collections.abc import Iterable, Mapping
22-
22+
_LoadDataV3 = Union[Mapping[str, Any], Iterable[Mapping[str, Any]]]
23+
_LoadDataV4 = Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]
24+
_LoadDataT = TypeVar("_LoadDataT", _LoadDataV3, _LoadDataV4)
2325
_ModelType = TypeVar("_ModelType", bound=DeclarativeMeta)
2426

2527

28+
def _cast_data(data):
29+
if int(importlib.metadata.version("marshmallow")[0]) >= 4:
30+
return cast(_LoadDataV4, data)
31+
return cast(_LoadDataV3, data)
32+
33+
2634
class LoadInstanceMixin:
2735
class Opts:
2836
model: type[DeclarativeMeta] | None
@@ -114,7 +122,7 @@ def make_instance(self, data, **kwargs) -> _ModelType:
114122

115123
def load(
116124
self,
117-
data: Mapping[str, Any] | Iterable[Mapping[str, Any]],
125+
data: _LoadDataT,
118126
*,
119127
session: Session | None = None,
120128
instance: _ModelType | None = None,
@@ -136,13 +144,13 @@ def load(
136144
raise ValueError("Deserialization requires a session")
137145
self.instance = instance or self.instance
138146
try:
139-
return cast(ma.Schema, super()).load(data, **kwargs)
147+
return cast(ma.Schema, super()).load(_cast_data(data), **kwargs)
140148
finally:
141149
self.instance = None
142150

143151
def validate(
144152
self,
145-
data: Mapping[str, Any] | Iterable[Mapping[str, Any]],
153+
data: _LoadDataT,
146154
*,
147155
session: Session | None = None,
148156
**kwargs,
@@ -151,7 +159,7 @@ def validate(
151159
self._session = session or self._session
152160
if not (self.transient or self.session):
153161
raise ValueError("Validation requires a session")
154-
return cast(ma.Schema, super()).validate(data, **kwargs)
162+
return cast(ma.Schema, super()).validate(_cast_data(data), **kwargs)
155163

156164
def _split_model_kwargs_association(self, data):
157165
"""Split serialized attrs to ensure association proxies are passed separately.

0 commit comments

Comments
 (0)