Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions fast_depends/msgspec/serializer.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,62 @@
import inspect
import re
from collections.abc import Iterator, Sequence
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from typing import Any
from typing import Any, TypeVar

import msgspec

from fast_depends.exceptions import ValidationError
from fast_depends.library.serializer import OptionItem, Serializer, SerializerProto

T = TypeVar("T")


class MsgSpecSerializer(SerializerProto):
__slots__ = ("use_fastdepends_errors",)
__slots__ = ("use_fastdepends_errors", "dec_hook")

def __init__(
self,
use_fastdepends_errors: bool = True,
dec_hook: Callable[[type[T], Any], T] | None = None,
) -> None:
self.use_fastdepends_errors = use_fastdepends_errors
self.dec_hook = dec_hook

def __call__(
self,
*,
name: str,
options: list[OptionItem],
response_type: Any,
) -> "Serializer":
) -> "_MsgSpecSerializer":
if self.use_fastdepends_errors:
if response_type is not inspect.Parameter.empty:
return _MsgSpecWrappedSerializerWithResponse(
name=name,
options=options,
response_type=response_type,
dec_hook=self.dec_hook,
)

return _MsgSpecWrappedSerializer(
name=name,
options=options,
dec_hook=self.dec_hook,
)

if response_type is not inspect.Parameter.empty:
return _MsgSpecSerializerWithResponse(
name=name,
options=options,
response_type=response_type,
dec_hook=self.dec_hook,
)

return _MsgSpecSerializer(
name=name,
options=options,
dec_hook=self.dec_hook,
)

@staticmethod
Expand All @@ -66,6 +74,7 @@ class _MsgSpecSerializer(Serializer):
"name",
"options",
"response_option",
"dec_hook",
)

def __init__(
Expand All @@ -74,15 +83,11 @@ def __init__(
name: str,
options: list[OptionItem],
response_type: Any = None,
dec_hook: Callable[[type[T], Any], T] | None = None,
):
model_options: list[str | tuple[str, type] | tuple[str, type, Any]] = []
aliases = {}
for i in options:
if isinstance(
msgspec.inspect.type_info(i.field_type), msgspec.inspect.CustomType
):
continue

default_value = i.default_value

if isinstance(default_value, msgspec._core.Field) and default_value.name:
Expand All @@ -108,6 +113,7 @@ def __init__(

self.aliases = aliases
self.model = msgspec.defstruct(name, model_options, kw_only=True)
self.dec_hook = dec_hook
super().__init__(name=name, options=options, response_type=response_type)

def get_aliases(self) -> tuple[str, ...]:
Expand All @@ -119,6 +125,7 @@ def __call__(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:
type=self.model,
strict=False,
str_keys=True,
dec_hook=self.dec_hook,
)

return {
Expand All @@ -134,12 +141,23 @@ def __init__(
name: str,
options: list[OptionItem],
response_type: Any,
dec_hook: Callable[[type[T], Any], T] | None = None,
):
super().__init__(name=name, options=options, response_type=response_type)
super().__init__(
name=name,
options=options,
response_type=response_type,
dec_hook=dec_hook,
)
self.response_type = response_type

def response(self, value: Any) -> Any:
return msgspec.convert(value, type=self.response_type, strict=False)
return msgspec.convert(
value,
type=self.response_type,
strict=False,
dec_hook=self.dec_hook,
)


class _MsgSpecWrappedSerializer(_MsgSpecSerializer):
Expand All @@ -150,6 +168,7 @@ def __call__(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:
type=self.model,
strict=False,
str_keys=True,
dec_hook=self.dec_hook,
)

return {
Expand Down Expand Up @@ -182,10 +201,21 @@ def __init__(
name: str,
options: list[OptionItem],
response_type: Any,
dec_hook: Callable[[type[T], Any], T] | None = None,
):
super().__init__(name=name, options=options, response_type=response_type)
super().__init__(
name=name,
options=options,
response_type=response_type,
dec_hook=dec_hook,
)
self.response_type = response_type

def response(self, value: Any) -> Any:
with self._try_msgspec(value, self.response_option, ("return",)):
return msgspec.convert(value, type=self.response_type, strict=False)
return msgspec.convert(
value,
type=self.response_type,
strict=False,
dec_hook=self.dec_hook,
)
2 changes: 1 addition & 1 deletion fast_depends/pydantic/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class PydanticSerializer(SerializerProto):
__slots__ = (
"pydantic_config",
"config",
"use_fastdepends_errors",
)

Expand Down
4 changes: 2 additions & 2 deletions tests/marks.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
pydantic = pytest.mark.skipif(not HAS_PYDANTIC, reason="requires Pydantic") # noqa: N816

pydanticV1 = pytest.mark.skipif(
not HAS_PYDANTIC or PYDANTIC_V2, reason="requires PydanticV2"
not HAS_PYDANTIC or PYDANTIC_V2, reason="requires PydanticV1"
) # noqa: N816

pydanticV2 = pytest.mark.skipif(
not HAS_PYDANTIC or not PYDANTIC_V2, reason="requires PydanticV1"
not HAS_PYDANTIC or not PYDANTIC_V2, reason="requires PydanticV2"
) # noqa: N816
50 changes: 50 additions & 0 deletions tests/serializers/msgspec/test_custom_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import Any, TypeVar

import pytest

from fast_depends import Depends, Provider, inject
from fast_depends.exceptions import ValidationError
from fast_depends.msgspec import MsgSpecSerializer

T = TypeVar("T")


class CustomType:
def __init__(self, value):
self.value = value


def msgspec_custom_type_decoder(t: type[T], obj: Any) -> T:
if not isinstance(obj, t):
return t(obj)
return obj


def dep(a: CustomType) -> str:
return a.value


@inject(
serializer_cls=MsgSpecSerializer(use_fastdepends_errors=True),
dependency_provider=Provider(),
)
def custom_type_without_decoder(a: CustomType = Depends(dep)): ...


@inject(
serializer_cls=MsgSpecSerializer(
use_fastdepends_errors=True,
dec_hook=msgspec_custom_type_decoder,
),
dependency_provider=Provider(),
)
def custom_type_with_decoder(a: CustomType = Depends(dep)) -> str:
assert isinstance(a, CustomType)
return a.value


def test_custom_type_cast():
custom_type_with_decoder("123")

with pytest.raises(ValidationError):
custom_type_without_decoder("123")