Skip to content

Commit 8a2746c

Browse files
authored
add optional support for pydantic (fsspec#395)
* add optional support for pydantic
* fix type checking
* try abspath for windows
* more mypy ignores
* Revert "try abspath for windows" (This reverts commit c0090c5.)
* try fix windows
* refactor imports
* parametrize tests
1 parent eb08231 commit 8a2746c

File tree

3 files changed

+194
-0
lines changed

3 files changed

+194
-0
lines changed

pyproject.toml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ tests = [
4141
"pytest-mock >=3.12.0",
4242
"pylint >=2.17.4",
4343
"mypy >=1.10.0",
44+
"pydantic >=2",
4445
"pytest-mypy-plugins >=3.1.2",
4546
"packaging",
4647
]
@@ -168,6 +169,22 @@ ignore_missing_imports = true
168169
module = "smbprotocol.*"
169170
ignore_missing_imports = true
170171

172+
[[tool.mypy.overrides]]
173+
module = "pydantic.*"
174+
ignore_errors = true
175+
176+
[[tool.mypy.overrides]]
177+
module = "pydantic_core.*"
178+
ignore_errors = true
179+
180+
[[tool.mypy.overrides]]
181+
module = "typing_inspection.*"
182+
ignore_errors = true
183+
184+
[[tool.mypy.overrides]]
185+
module = "annotated_types.*"
186+
ignore_errors = true
187+
171188
[tool.pylint.format]
172189
max-line-length = 88
173190

upath/core.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444
else:
4545
from typing_extensions import Self
4646

47+
from pydantic import GetCoreSchemaHandler
48+
from pydantic_core.core_schema import CoreSchema
49+
4750

4851
__all__ = ["UPath"]
4952

@@ -934,3 +937,64 @@ def match(self, pattern: str) -> bool:
934937
if not pattern:
935938
raise ValueError("pattern cannot be empty")
936939
return self.full_match(pattern.replace("**", "*"))
940+
941+
@classmethod
def __get_pydantic_core_schema__(
    cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
    """Build the pydantic-core schema used to validate and serialize paths.

    Validation accepts either a plain string (used as the path) or a
    mapping with the keys ``path`` (required), ``protocol`` (optional,
    defaults to ``""``) and ``storage_options`` (optional, defaults to an
    empty dict). Unknown keys are rejected. In Python mode an object that
    is already an instance of ``cls`` is accepted as-is.

    Serialization produces a dict with the keys ``path``, ``protocol`` and
    ``storage_options``.

    ``_source_type`` and ``_handler`` belong to the hook signature and are
    not used here.
    """
    # Function-level import: pydantic support is optional, so pydantic_core
    # is only needed once this hook is actually called.
    from pydantic_core import core_schema

    # Shape of the mapping form; defaults are filled in by the schema, so
    # the constructor step below can rely on all three keys being present.
    fields_schema = core_schema.typed_dict_schema(
        {
            "path": core_schema.typed_dict_field(
                core_schema.str_schema(), required=True
            ),
            "protocol": core_schema.typed_dict_field(
                core_schema.with_default_schema(
                    core_schema.str_schema(), default=""
                ),
                required=False,
            ),
            "storage_options": core_schema.typed_dict_field(
                core_schema.with_default_schema(
                    core_schema.dict_schema(
                        core_schema.str_schema(),
                        core_schema.any_schema(),
                    ),
                    default_factory=dict,
                ),
                required=False,
            ),
        },
        extra_behavior="forbid",
    )

    deserialization_schema = core_schema.chain_schema(
        [
            # Step 1: a bare string is shorthand for {"path": <string>};
            # any other input is passed on unchanged.
            core_schema.no_info_plain_validator_function(
                lambda v: {"path": v} if isinstance(v, str) else v,
            ),
            # Step 2: validate the mapping and apply defaults.
            fields_schema,
            # Step 3: construct the path object. Keys are read, not popped,
            # so the validated mapping is left untouched.
            core_schema.no_info_plain_validator_function(
                lambda dct: cls(
                    dct["path"],
                    protocol=dct["protocol"],
                    **dct["storage_options"],
                )
            ),
        ]
    )

    serialization_schema = core_schema.plain_serializer_function_ser_schema(
        lambda u: {
            "path": u.path,
            "protocol": u.protocol,
            # Copy into a plain dict so the output does not alias the
            # instance's own storage options.
            "storage_options": dict(u.storage_options),
        }
    )

    return core_schema.json_or_python_schema(
        json_schema=deserialization_schema,
        python_schema=core_schema.union_schema(
            # Check against ``cls`` (not the base class) so that a field
            # annotated with a subclass does not pass through instances of
            # unrelated path classes unchanged.
            [core_schema.is_instance_schema(cls), deserialization_schema]
        ),
        serialization=serialization_schema,
    )

upath/tests/test_pydantic.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import json
2+
from os.path import abspath
3+
4+
import pydantic
5+
import pydantic_core
6+
import pytest
7+
from fsspec.implementations.http import get_client
8+
9+
from upath import UPath
10+
11+
12+
@pytest.mark.parametrize(
    "path",
    [
        "/abc",
        "file:///abc",
        "memory://abc",
        "s3://bucket/key",
        "https://www.example.com",
    ],
)
@pytest.mark.parametrize("source", ["json", "python"])
def test_validate_from_str(path, source):
    """A plain path string validates to the same path as ``UPath(path)``."""
    adapter = pydantic.TypeAdapter(UPath)
    reference = UPath(path)

    # Exercise either the JSON or the Python validation entry point.
    if source != "json":
        validated = adapter.validate_python(path)
    else:
        validated = adapter.validate_json(json.dumps(path))

    # Compare through abspath so both sides are normalized the same way.
    assert abspath(validated.path) == abspath(reference.path)
    assert validated.protocol == reference.protocol
34+
35+
36+
@pytest.mark.parametrize(
    "dct",
    [
        {
            "path": "/my/path",
            "protocol": "file",
            "storage_options": {"foo": "bar", "baz": 3},
        }
    ],
)
@pytest.mark.parametrize("source", ["json", "python"])
def test_validate_from_dict(dct, source):
    """A mapping with path, protocol and storage_options validates to a UPath."""
    adapter = pydantic.TypeAdapter(UPath)

    # Exercise either the JSON or the Python validation entry point.
    if source != "json":
        validated = adapter.validate_python(dct)
    else:
        validated = adapter.validate_json(json.dumps(dct))

    # Every field of the input mapping must be reflected on the result.
    assert abspath(validated.path) == abspath(dct["path"])
    assert validated.protocol == dct["protocol"]
    assert validated.storage_options == dct["storage_options"]
57+
58+
59+
@pytest.mark.parametrize(
    "path",
    [
        "/abc",
        "file:///abc",
        "memory://abc",
        "s3://bucket/key",
        "https://www.example.com",
    ],
)
def test_validate_from_instance(path):
    """Validating an existing UPath in Python mode returns the same object."""
    original = UPath(path)
    adapter = pydantic.TypeAdapter(UPath)

    validated = adapter.validate_python(original)

    # Identity, not equality: the instance must be passed through untouched.
    assert validated is original
75+
76+
77+
@pytest.mark.parametrize(
    ("args", "kwargs"),
    [
        (
            ("/my/path",),
            {
                "protocol": "file",
                "foo": "bar",
                "baz": 3,
            },
        )
    ],
)
@pytest.mark.parametrize("mode", ["json", "python"])
def test_dump(args, kwargs, mode):
    """Dumping a UPath yields its path, protocol and storage options."""
    upath_obj = UPath(*args, **kwargs)
    adapter = pydantic.TypeAdapter(UPath)

    dumped = adapter.dump_python(upath_obj, mode=mode)

    # The dumped mapping mirrors the attributes of the source object.
    assert dumped["path"] == upath_obj.path
    assert dumped["protocol"] == upath_obj.protocol
    assert dumped["storage_options"] == upath_obj.storage_options
99+
100+
101+
def test_dump_non_serializable_python():
    """In Python mode, non-JSON-serializable storage options are kept as-is."""
    upath_obj = UPath("https://www.example.com", get_client=get_client)
    adapter = pydantic.TypeAdapter(UPath)

    dumped = adapter.dump_python(upath_obj, mode="python")

    # The callable is carried through by identity rather than converted.
    assert dumped["storage_options"]["get_client"] is get_client
107+
108+
109+
def test_dump_non_serializable_json():
    """In JSON mode, a non-serializable storage option raises an error."""
    upath_obj = UPath("https://www.example.com", get_client=get_client)
    adapter = pydantic.TypeAdapter(UPath)

    # Only the dump call itself is expected to raise.
    with pytest.raises(pydantic_core.PydanticSerializationError, match="unknown type"):
        adapter.dump_python(upath_obj, mode="json")

0 commit comments

Comments
 (0)