diff --git a/pyproject.toml b/pyproject.toml index ec243d78..6a5d6cac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ tests = [ "pytest-mock >=3.12.0", "pylint >=2.17.4", "mypy >=1.10.0", - "pydantic >=2", + "pydantic >=2", # <3 required for testing pydantic v1 support, not for actual use "pytest-mypy-plugins >=3.1.2", "packaging", ] diff --git a/upath/core.py b/upath/core.py index 4b120614..ef441ebc 100644 --- a/upath/core.py +++ b/upath/core.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import pathlib import sys import warnings from abc import ABCMeta @@ -14,8 +15,10 @@ from typing import TYPE_CHECKING from typing import Any from typing import BinaryIO +from typing import Callable from typing import Literal from typing import TextIO +from typing import TypedDict from typing import overload from urllib.parse import SplitResult from urllib.parse import urlsplit @@ -40,8 +43,10 @@ if TYPE_CHECKING: if sys.version_info >= (3, 11): + from typing import NotRequired from typing import Self else: + from typing_extensions import NotRequired from typing_extensions import Self from pydantic import GetCoreSchemaHandler @@ -107,6 +112,14 @@ def __getitem__(cls, key): return cls +class SerializedUPath(TypedDict): + """Serialized format for a UPath object""" + + path: str + protocol: NotRequired[str] + storage_options: NotRequired[dict[str, Any]] + + class _UPathMixin(metaclass=_UPathMeta): __slots__ = () @@ -179,6 +192,13 @@ def path(self) -> str: """The path that a fsspec filesystem can use.""" return self.parser.strip_protocol(self.__str__()) + def to_dict(self) -> SerializedUPath: + return { + "path": self.path, + "protocol": self.protocol, + "storage_options": dict(self.storage_options), + } + def joinuri(self, uri: JoinablePathLike) -> UPath: """Join with urljoin behavior for UPath instances""" # short circuit if the new uri uses a different protocol @@ -946,9 +966,7 @@ def __get_pydantic_core_schema__( deserialization_schema = core_schema.chain_schema( [ - core_schema.no_info_plain_validator_function( - lambda v: {"path": v} if isinstance(v, str) else v, - ), + core_schema.no_info_plain_validator_function(cls._to_serialized_format), core_schema.typed_dict_schema( { "path": core_schema.typed_dict_field( @@ -973,13 +991,7 @@ def __get_pydantic_core_schema__( }, extra_behavior="forbid", ), - core_schema.no_info_plain_validator_function( - lambda dct: cls( - dct.pop("path"), - protocol=dct.pop("protocol"), - **dct["storage_options"], - ) - ), + core_schema.no_info_plain_validator_function(cls._validate), ] ) @@ -998,3 +1010,39 @@ def __get_pydantic_core_schema__( ), serialization=serialization_schema, ) + + @classmethod + def __get_validators__(cls) -> Iterator[Callable]: + yield cls._validate + + @staticmethod + def _to_serialized_format( + v: str | pathlib.Path | _UPathMixin | dict[str, Any], + ) -> SerializedUPath: + if isinstance(v, _UPathMixin): + return v.to_dict() + if isinstance(v, dict): + return { + "path": v["path"], + "protocol": v.get("protocol", ""), + "storage_options": v.get("storage_options", {}), + } + if isinstance(v, pathlib.Path): + return {"path": v.as_posix(), "protocol": ""} + if isinstance(v, str): + return { + "path": v, + } + raise TypeError(f"Invalid path: {v!r}") + + @classmethod + def _validate(cls, v: Any) -> UPath: + if not isinstance(v, UPath): + v = cls._to_serialized_format(v) + + return cls( + v["path"], + protocol=v.get("protocol"), + **v.get("storage_options", {}), # type: ignore[arg-type] + ) + return v diff --git a/upath/tests/test_pydantic.py b/upath/tests/test_pydantic.py index 383ded5c..e7928a22 100644 --- a/upath/tests/test_pydantic.py +++ b/upath/tests/test_pydantic.py @@ -2,6 +2,7 @@ from os.path import abspath import pydantic +import pydantic.v1 as pydantic_v1 import pydantic_core import pytest from fsspec.implementations.http import get_client @@ -9,6 +10,31 @@ from upath import UPath +@pytest.fixture(params=["v1", "v2"]) +def pydantic_version(request): + return request.param + + +@pytest.fixture(params=["json", "python"]) +def source(request): + return request.param + + +@pytest.fixture +def parser(pydantic_version, source): + if pydantic_version == "v1": + if source == "json": + return lambda x: pydantic_v1.tools.parse_raw_as(UPath, x) + else: + return lambda x: pydantic_v1.tools.parse_obj_as(UPath, x) + else: + ta = pydantic.TypeAdapter(UPath) + if source == "json": + return ta.validate_json + else: + return ta.validate_python + + @pytest.mark.parametrize( "path", [ @@ -19,15 +45,13 @@ "https://www.example.com", ], ) -@pytest.mark.parametrize("source", ["json", "python"]) -def test_validate_from_str(path, source): +def test_validate_from_str(path, source, parser): expected = UPath(path) - ta = pydantic.TypeAdapter(UPath) if source == "json": - actual = ta.validate_json(json.dumps(path)) - else: # source == "python" - actual = ta.validate_python(path) + path = json.dumps(path) + + actual = parser(path) assert abspath(actual.path) == abspath(expected.path) assert actual.protocol == expected.protocol @@ -43,13 +67,13 @@ def test_validate_from_str(path, source): } ], ) -@pytest.mark.parametrize("source", ["json", "python"]) -def test_validate_from_dict(dct, source): - ta = pydantic.TypeAdapter(UPath) +def test_validate_from_dict(dct, source, parser): if source == "json": - output = ta.validate_json(json.dumps(dct)) - else: # source == "python" - output = ta.validate_python(dct) + data = json.dumps(dct) + else: + data = dct + + output = parser(data) assert abspath(output.path) == abspath(dct["path"]) assert output.protocol == dct["protocol"] @@ -66,10 +90,13 @@ def test_validate_from_dict(dct, source): "https://www.example.com", ], ) -def test_validate_from_instance(path): +def test_validate_from_instance(path, pydantic_version): input = UPath(path) - output = pydantic.TypeAdapter(UPath).validate_python(input) + if pydantic_version == "v1": + output = pydantic_v1.tools.parse_obj_as(UPath, input) + else: + output = pydantic.TypeAdapter(UPath).validate_python(input) assert output is input @@ -88,26 +115,38 @@ def test_validate_from_instance(path): ], ) @pytest.mark.parametrize("mode", ["json", "python"]) -def test_dump(args, kwargs, mode): +def test_dump(args, kwargs, mode, pydantic_version): u = UPath(*args, **kwargs) - output = pydantic.TypeAdapter(UPath).dump_python(u, mode=mode) + if pydantic_version == "v1": + output = u.to_dict() + else: + output = pydantic.TypeAdapter(UPath).dump_python(u, mode=mode) assert output["path"] == u.path assert output["protocol"] == u.protocol assert output["storage_options"] == u.storage_options -def test_dump_non_serializable_python(): - output = pydantic.TypeAdapter(UPath).dump_python( - UPath("https://www.example.com", get_client=get_client), mode="python" - ) +def test_dump_non_serializable_python(pydantic_version): + upath = UPath("https://www.example.com", get_client=get_client) + + if pydantic_version == "v1": + output = upath.to_dict() + else: + output = pydantic.TypeAdapter(UPath).dump_python(upath, mode="python") assert output["storage_options"]["get_client"] is get_client -def test_dump_non_serializable_json(): - with pytest.raises(pydantic_core.PydanticSerializationError, match="unknown type"): - pydantic.TypeAdapter(UPath).dump_python( - UPath("https://www.example.com", get_client=get_client), mode="json" - ) +def test_dump_non_serializable_json(pydantic_version): + upath = UPath("https://www.example.com", get_client=get_client) + + if pydantic_version == "v1": + with pytest.raises(TypeError, match="not JSON serializable"): + json.dumps(upath.to_dict()) + else: + with pytest.raises( + pydantic_core.PydanticSerializationError, match="unknown type" + ): + pydantic.TypeAdapter(UPath).dump_python(upath, mode="json")