Skip to content

Commit c41cf90

Browse files
authored
upath.registry: narrow get_upath_class types on protocol (#429)
* upath.registry: overload get_upath_class annotations to return correct subclass * typesafety: add tests for get_upath_class * typesafety: split test into py39 and py310+ due to Union type * tests: xfail when hitting github rate limit
1 parent dbdae66 commit c41cf90

File tree

4 files changed

+165
-4
lines changed

4 files changed

+165
-4
lines changed

typesafety/test_upath_types.yml

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,74 @@
1919
b: JoinablePathLike = PurePath()
2020
c: JoinablePathLike = Path()
2121
d: JoinablePathLike = UPath()
22+
23+
- case: get_upath_class_fsspec_protocols
24+
disable_cache: false
25+
parametrized:
26+
- cls_fqn: upath.implementations.cached.SimpleCachePath
27+
protocol: simplecache
28+
- cls_fqn: upath.implementations.cloud.S3Path
29+
protocol: s3
30+
- cls_fqn: upath.implementations.cloud.GCSPath
31+
protocol: gcs
32+
- cls_fqn: upath.implementations.cloud.AzurePath
33+
protocol: abfs
34+
- cls_fqn: upath.implementations.data.DataPath
35+
protocol: data
36+
- cls_fqn: upath.implementations.github.GitHubPath
37+
protocol: github
38+
- cls_fqn: upath.implementations.hdfs.HDFSPath
39+
protocol: hdfs
40+
- cls_fqn: upath.implementations.http.HTTPPath
41+
protocol: http
42+
- cls_fqn: upath.implementations.local.FilePath
43+
protocol: file
44+
- cls_fqn: upath.implementations.memory.MemoryPath
45+
protocol: memory
46+
- cls_fqn: upath.implementations.sftp.SFTPPath
47+
protocol: sftp
48+
- cls_fqn: upath.implementations.smb.SMBPath
49+
protocol: smb
50+
- cls_fqn: upath.implementations.webdav.WebdavPath
51+
protocol: webdav
52+
main: |
53+
from upath.registry import get_upath_class
54+
55+
path_cls = get_upath_class("{{ protocol }}")
56+
reveal_type(path_cls) # N: Revealed type is "type[{{ cls_fqn }}]"
57+
58+
- case: get_upath_class_pathlib_unix
59+
disable_cache: false
60+
skip: sys.platform == "win32"
61+
main: |
62+
from upath.registry import get_upath_class
63+
64+
path_cls = get_upath_class("")
65+
reveal_type(path_cls) # N: Revealed type is "type[upath.implementations.local.PosixUPath]"
66+
67+
- case: get_upath_class_pathlib_win
68+
disable_cache: false
69+
skip: sys.platform != "win32"
70+
main: |
71+
from upath.registry import get_upath_class
72+
73+
path_cls = get_upath_class("")
74+
reveal_type(path_cls) # N: Revealed type is "type[upath.implementations.local.WindowsUPath]"
75+
76+
- case: get_upath_class_unknown_protocol_py39
77+
disable_cache: false
78+
skip: sys.version_info >= (3, 10)
79+
main: |
80+
from upath.registry import get_upath_class
81+
82+
path_cls = get_upath_class("some-unknown-protocol")
83+
reveal_type(path_cls) # N: Revealed type is "Union[type[upath.core.UPath], None]"
84+
85+
- case: get_upath_class_unknown_protocol_py310plus
86+
disable_cache: false
87+
skip: sys.version_info < (3, 10)
88+
main: |
89+
from upath.registry import get_upath_class
90+
91+
path_cls = get_upath_class("some-unknown-protocol")
92+
reveal_type(path_cls) # N: Revealed type is "type[upath.core.UPath] | None"

upath/implementations/_experimental.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,17 @@
1010

1111
def __getattr__(name: str) -> type[UPath]:
1212
if name.startswith("_") and name.endswith("Path"):
13+
from upath import UPath
14+
1315
protocol = name[1:-4].lower()
14-
cls = get_upath_class(protocol, fallback=False)
15-
assert cls is not None
16+
cls = get_upath_class(protocol)
17+
if cls is None:
18+
raise RuntimeError(
19+
f"Could not find fsspec implementation for protocol {protocol!r}"
20+
)
21+
elif not issubclass(cls, UPath):
22+
raise RuntimeError(
23+
"UPath implementation not a subclass of upath.UPath, {cls!r}"
24+
)
1625
return cls
1726
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

upath/registry.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,27 @@
4848

4949
import upath
5050

51+
if TYPE_CHECKING:
52+
from typing import Literal
53+
from typing import overload
54+
55+
from upath.implementations.cached import SimpleCachePath as _SimpleCachePath
56+
from upath.implementations.cloud import AzurePath as _AzurePath
57+
from upath.implementations.cloud import GCSPath as _GCSPath
58+
from upath.implementations.cloud import S3Path as _S3Path
59+
from upath.implementations.data import DataPath as _DataPath
60+
from upath.implementations.github import GitHubPath as _GitHubPath
61+
from upath.implementations.hdfs import HDFSPath as _HDFSPath
62+
from upath.implementations.http import HTTPPath as _HTTPPath
63+
from upath.implementations.local import FilePath as _FilePath
64+
from upath.implementations.local import PosixUPath as _PosixUPath
65+
from upath.implementations.local import WindowsUPath as _WindowsUPath
66+
from upath.implementations.memory import MemoryPath as _MemoryPath
67+
from upath.implementations.sftp import SFTPPath as _SFTPPath
68+
from upath.implementations.smb import SMBPath as _SMBPath
69+
from upath.implementations.webdav import WebdavPath as _WebdavPath
70+
71+
5172
__all__ = [
5273
"get_upath_class",
5374
"available_implementations",
@@ -125,7 +146,7 @@ def __setitem__(self, item: str, value: type[upath.UPath] | str) -> None:
125146
f"expected UPath subclass or FQN-string, got: {type(value).__name__!r}"
126147
)
127148
if not item or item in self._m:
128-
get_upath_class.cache_clear()
149+
get_upath_class.cache_clear() # type: ignore[attr-defined]
129150
self._m[item] = value
130151

131152
def __delitem__(self, __v: str) -> None:
@@ -182,7 +203,57 @@ def register_implementation(
182203
_registry[protocol] = cls
183204

184205

185-
@lru_cache
206+
# --- get_upath_class type overloads ------------------------------------------
207+
208+
if TYPE_CHECKING: # noqa: C901
209+
210+
@overload
211+
def get_upath_class(protocol: Literal["simplecache"]) -> type[_SimpleCachePath]: ...
212+
@overload
213+
def get_upath_class(protocol: Literal["s3", "s3a"]) -> type[_S3Path]: ...
214+
@overload
215+
def get_upath_class(protocol: Literal["gcs", "gs"]) -> type[_GCSPath]: ...
216+
217+
@overload
218+
def get_upath_class(
219+
protocol: Literal["abfs", "abfss", "adl", "az"],
220+
) -> type[_AzurePath]: ...
221+
@overload
222+
def get_upath_class(protocol: Literal["data"]) -> type[_DataPath]: ...
223+
@overload
224+
def get_upath_class(protocol: Literal["github"]) -> type[_GitHubPath]: ...
225+
@overload
226+
def get_upath_class(protocol: Literal["hdfs"]) -> type[_HDFSPath]: ...
227+
@overload
228+
def get_upath_class(protocol: Literal["http", "https"]) -> type[_HTTPPath]: ...
229+
@overload
230+
def get_upath_class(protocol: Literal["file", "local"]) -> type[_FilePath]: ...
231+
@overload
232+
def get_upath_class(protocol: Literal["memory"]) -> type[_MemoryPath]: ...
233+
@overload
234+
def get_upath_class(protocol: Literal["sftp", "ssh"]) -> type[_SFTPPath]: ...
235+
@overload
236+
def get_upath_class(protocol: Literal["smb"]) -> type[_SMBPath]: ...
237+
@overload
238+
def get_upath_class(protocol: Literal["webdav"]) -> type[_WebdavPath]: ...
239+
240+
if sys.platform == "win32":
241+
242+
@overload
243+
def get_upath_class(protocol: Literal[""]) -> type[_WindowsUPath]: ...
244+
245+
else:
246+
247+
@overload
248+
def get_upath_class(protocol: Literal[""]) -> type[_PosixUPath]: ... # type: ignore[overload-overlap] # noqa: E501
249+
250+
@overload
251+
def get_upath_class(
252+
protocol: str, *, fallback: bool = ...
253+
) -> type[upath.UPath] | None: ...
254+
255+
256+
@lru_cache # type: ignore[misc] # see: https://github.com/python/typeshed/issues/11280
186257
def get_upath_class(
187258
protocol: str,
188259
*,

upath/tests/implementations/test_github.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,16 @@ def path(self):
3030
path = "github://ap--:universal_pathlib@test_data/data"
3131
self.path = UPath(path)
3232

33+
@pytest.fixture(autouse=True)
34+
def _xfail_on_rate_limit_errors(self):
35+
try:
36+
yield
37+
except Exception as e:
38+
if "rate limit exceeded" in str(e):
39+
pytest.xfail("GitHub API rate limit exceeded")
40+
else:
41+
raise
42+
3343
def test_is_GitHubPath(self):
3444
"""
3545
Test that the path is a GitHubPath instance.

0 commit comments

Comments
 (0)