Skip to content

Commit e44a8c7

Browse files
authored
Raise if using subclass directly with wrong protocol (fsspec#541)
* add failing test for direct subclass init * adjust subclass tests: require registering * raise error when instantiating subclass with incorrect protocol * upath.registry: add _get_implementation_protocols * upath.core: fix protocol handling when instantiating subclasses directly * tests: add protocol incompatibility tests * upath._protocol: fix incompatible protocol for partially loaded impls and fallbacks * upath: fix typing
1 parent 640ca44 commit e44a8c7

File tree

5 files changed

+231
-38
lines changed

5 files changed

+231
-38
lines changed

upath/_protocol.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,24 @@ def _fsspec_protocol_equals(p0: str, p1: str) -> bool:
5555
except KeyError:
5656
raise ValueError(f"Protocol not known: {p1!r}")
5757

58+
if o0 == o1:
59+
return True
60+
61+
if isinstance(o0, dict):
62+
o0 = o0.get("class")
63+
elif isinstance(o0, type):
64+
if o0.__module__:
65+
o0 = o0.__module__ + "." + o0.__name__
66+
else:
67+
o0 = o0.__name__
68+
if isinstance(o1, dict):
69+
o1 = o1.get("class")
70+
elif isinstance(o1, type):
71+
if o1.__module__:
72+
o1 = o1.__module__ + "." + o1.__name__
73+
else:
74+
o1 = o1.__name__
75+
5876
return o0 == o1
5977

6078

upath/core.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from upath._protocol import compatible_protocol
3939
from upath._protocol import get_upath_protocol
4040
from upath._stat import UPathStatResult
41+
from upath.registry import _get_implementation_protocols
42+
from upath.registry import available_implementations
4143
from upath.registry import get_upath_class
4244
from upath.types import UNSET_DEFAULT
4345
from upath.types import JoinablePathLike
@@ -414,7 +416,7 @@ def _fs_factory(
414416

415417
_protocol_dispatch: bool | None = None
416418

417-
def __new__(
419+
def __new__( # noqa C901
418420
cls,
419421
*args: JoinablePathLike,
420422
protocol: str | None = None,
@@ -445,6 +447,27 @@ def __new__(
445447
if "incompatible with" in str(e):
446448
raise _IncompatibleProtocolError(str(e)) from e
447449
raise
450+
451+
# subclasses should default to their own protocol
452+
if protocol is None and cls is not UPath:
453+
impl_protocols = _get_implementation_protocols(cls)
454+
if not pth_protocol and impl_protocols:
455+
pth_protocol = impl_protocols[0]
456+
elif pth_protocol and pth_protocol not in impl_protocols:
457+
msg_protocol = pth_protocol
458+
if not pth_protocol:
459+
msg_protocol = "'' (empty string)"
460+
msg = (
461+
f"{cls.__name__!s}(...) detected protocol {msg_protocol!s}"
462+
f" which is incompatible with {cls.__name__}."
463+
)
464+
if not pth_protocol or pth_protocol not in available_implementations():
465+
msg += (
466+
" Did you forget to register the subclass for this protocol"
467+
" with upath.registry.register_implementation()?"
468+
)
469+
raise _IncompatibleProtocolError(msg)
470+
448471
# determine which UPath subclass to dispatch to
449472
upath_cls: type[UPath] | None
450473
if cls._protocol_dispatch or cls._protocol_dispatch is None:
@@ -478,26 +501,24 @@ def __new__(
478501
raise RuntimeError("UPath.__new__ expected cls to be subclass of UPath")
479502

480503
else:
481-
msg_protocol = repr(pth_protocol)
504+
msg_protocol = pth_protocol
482505
if not pth_protocol:
483-
msg_protocol += " (empty string)"
506+
msg_protocol = "'' (empty string)"
484507
msg = (
485-
f"{cls.__name__!s}(...) detected protocol {msg_protocol!s} and"
486-
f" returns a {upath_cls.__name__} instance that isn't a direct"
487-
f" subclass of {cls.__name__}. This will raise an exception in"
488-
" future universal_pathlib versions. To prevent the issue, use"
489-
" UPath(...) to create instances of unrelated protocols or you"
490-
f" can instead derive your subclass {cls.__name__!s}(...) from"
491-
f" {upath_cls.__name__} or alternatively override behavior via"
492-
f" registering the {cls.__name__} implementation with protocol"
493-
f" {msg_protocol!s} replacing the default implementation."
494-
)
495-
warnings.warn(
496-
msg,
497-
DeprecationWarning,
498-
stacklevel=2,
508+
f"{cls.__name__!s}(...) detected protocol {msg_protocol!s}"
509+
f" which is incompatible with {cls.__name__}."
499510
)
500-
upath_cls = cls
511+
if (
512+
# find a better way
513+
(not pth_protocol and cls.__name__ not in ["CloudPath", "LocalPath"])
514+
or pth_protocol
515+
and pth_protocol not in available_implementations()
516+
):
517+
msg += (
518+
" Did you forget to register the subclass for this protocol"
519+
" with upath.registry.register_implementation()?"
520+
)
521+
raise _IncompatibleProtocolError(msg)
501522

502523
return object.__new__(upath_cls)
503524

@@ -530,7 +551,6 @@ def __init__(
530551
Additional storage options for the path.
531552
532553
"""
533-
534554
# todo: avoid duplicating this call from __new__
535555
protocol = get_upath_protocol(
536556
args[0] if args else "",
@@ -549,6 +569,12 @@ def __init__(
549569
if not compatible_protocol(protocol, *args):
550570
raise ValueError("can't combine incompatible UPath protocols")
551571

572+
# subclasses should default to their own protocol
573+
if not protocol:
574+
impl_protocols = _get_implementation_protocols(type(self))
575+
if impl_protocols:
576+
protocol = impl_protocols[0]
577+
552578
if args:
553579
args0 = args[0]
554580
if isinstance(args0, UPath):

upath/registry.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def __setitem__(self, item: str, value: type[upath.UPath] | str) -> None:
155155
)
156156
if not item or item in self._m:
157157
get_upath_class.cache_clear() # type: ignore[attr-defined]
158+
_get_implementation_protocols.cache_clear() # type: ignore[attr-defined]
158159
self._m[item] = value
159160

160161
def __delitem__(self, __v: str) -> None:
@@ -211,6 +212,32 @@ def register_implementation(
211212
_registry[protocol] = cls
212213

213214

215+
@lru_cache # type: ignore[misc]
216+
def _get_implementation_protocols(cls: type[upath.UPath]) -> list[str]:
217+
"""return protocols registered for a given UPath class without triggering imports"""
218+
if not issubclass(cls, upath.UPath):
219+
raise ValueError(f"{cls!r} is not a UPath subclass")
220+
if cls.__module__ == "upath.implementations._experimental":
221+
# experimental fallback implementations have no registry entry
222+
return [cls.__name__[1:-4].lower()]
223+
loaded = (
224+
p
225+
for p, c in _registry._m.maps[0].items() # type: ignore[attr-defined]
226+
if c is cls
227+
)
228+
known = (
229+
p
230+
for p, fqn in _registry.known_implementations.items()
231+
if fqn == f"{cls.__module__}.{cls.__name__}"
232+
)
233+
eps = (
234+
p
235+
for p, ep in _registry._entries.items()
236+
if ep.module == cls.__module__ and ep.attr == cls.__name__
237+
)
238+
return list(dict.fromkeys((*loaded, *known, *eps)))
239+
240+
214241
# --- get_upath_class type overloads ------------------------------------------
215242

216243
if TYPE_CHECKING: # noqa: C901

upath/tests/test_core.py

Lines changed: 110 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from upath import UPath
1212
from upath.implementations.cloud import GCSPath
1313
from upath.implementations.cloud import S3Path
14+
from upath.registry import get_upath_class
15+
from upath.registry import register_implementation
1416
from upath.types import ReadablePath
1517
from upath.types import WritablePath
1618

@@ -112,12 +114,35 @@ def test_subclass(local_testdir):
112114
class MyPath(UPath):
113115
pass
114116

115-
with pytest.warns(
116-
DeprecationWarning, match=r"MyPath\(...\) detected protocol '' .*"
117-
):
118-
path = MyPath(local_testdir)
119-
assert str(path) == pathlib.Path(local_testdir).as_posix()
117+
with pytest.raises(ValueError, match=r".*incompatible with"):
118+
MyPath(local_testdir)
119+
120+
121+
@pytest.fixture(scope="function")
122+
def upath_registry_snapshot():
123+
"""Save and restore the upath registry state around a test."""
124+
from upath.registry import _registry
125+
126+
# Save the current state of the registry's mutable mapping
127+
saved_m = _registry._m.maps[0].copy()
128+
try:
129+
yield
130+
finally:
131+
# Restore the registry state
132+
_registry._m.maps[0].clear()
133+
_registry._m.maps[0].update(saved_m)
134+
get_upath_class.cache_clear()
135+
136+
137+
def test_subclass_registered(upath_registry_snapshot):
138+
class MyPath(UPath):
139+
pass
140+
141+
register_implementation("memory", MyPath, clobber=True)
142+
path = MyPath("memory:///test_path")
143+
assert str(path) == "memory:///test_path"
120144
assert issubclass(MyPath, UPath)
145+
assert isinstance(path, MyPath)
121146
assert isinstance(path, pathlib_abc.ReadablePath)
122147
assert isinstance(path, pathlib_abc.WritablePath)
123148
assert not isinstance(path, pathlib.Path)
@@ -453,33 +478,99 @@ def test_open_a_local_upath(tmp_path, protocol):
453478
@pytest.mark.parametrize(
454479
"uri,protocol",
455480
[
481+
# s3 compatible protocols
456482
("s3://bucket/folder", "s3"),
457-
("gs://bucket/folder", "gs"),
483+
("s3a://bucket/folder", "s3a"),
458484
("bucket/folder", "s3"),
485+
# gcs compatible
486+
("gs://bucket/folder", "gs"),
487+
("gcs://bucket/folder", "gcs"),
488+
("bucket/folder", "gs"),
489+
# azure compatible
490+
("az://container/blob", "az"),
491+
("abfs://container/blob", "abfs"),
492+
("abfss://container/blob", "abfss"),
493+
("adl://container/blob", "adl"),
494+
# memory
459495
("memory://folder", "memory"),
496+
("/folder", "memory"),
497+
# file/local
460498
("file:/tmp/folder", "file"),
461499
("/tmp/folder", "file"),
500+
("file:/tmp/folder", "local"),
501+
("/tmp/folder", "local"),
462502
("/tmp/folder", ""),
463503
("a/b/c", ""),
504+
# http/https
505+
("http://example.com/path", "http"),
506+
("https://example.com/path", "https"),
507+
# ftp
508+
("ftp://example.com/path", "ftp"),
509+
# sftp/ssh
510+
("sftp://example.com/path", "sftp"),
511+
("ssh://example.com/path", "ssh"),
512+
# smb
513+
("smb://server/share/path", "smb"),
514+
# hdfs
515+
("hdfs://namenode/path", "hdfs"),
516+
# webdav - requires base_url, skip for now
517+
# github
518+
("github://owner:repo@branch/path", "github"),
519+
# data
520+
("data:text/plain;base64,SGVsbG8=", "data"),
521+
# huggingface
522+
("hf://datasets/user/repo/path", "hf"),
464523
],
465524
)
466525
def test_constructor_compatible_protocol_uri(uri, protocol):
467526
p = UPath(uri, protocol=protocol)
468527
assert p.protocol == protocol
469528

470529

471-
@pytest.mark.parametrize(
472-
"uri,protocol",
473-
[
474-
("s3://bucket/folder", "gs"),
475-
("gs://bucket/folder", "s3"),
476-
("memory://folder", "s3"),
477-
("file:/tmp/folder", "s3"),
478-
("s3://bucket/folder", ""),
479-
("memory://folder", ""),
480-
("file:/tmp/folder", ""),
481-
],
482-
)
530+
# Protocol to sample URI mapping
531+
_PROTOCOL_URIS = {
532+
"s3": "s3://bucket/folder",
533+
"gs": "gs://bucket/folder",
534+
"az": "az://container/blob",
535+
"memory": "memory://folder",
536+
"file": "file:/tmp/folder",
537+
"http": "http://example.com/path",
538+
"ftp": "ftp://example.com/path",
539+
"sftp": "sftp://example.com/path",
540+
"smb": "smb://server/share/path",
541+
"hdfs": "hdfs://namenode/path",
542+
}
543+
544+
# Generate incompatible combinations: each protocol with URIs from other protocols
545+
_INCOMPATIBLE_CASES = [
546+
(_PROTOCOL_URIS[uri_protocol], target_protocol)
547+
for target_protocol in _PROTOCOL_URIS
548+
for uri_protocol in _PROTOCOL_URIS
549+
if target_protocol != uri_protocol
550+
]
551+
552+
# Also test explicit empty protocol with protocol-prefixed URIs
553+
_INCOMPATIBLE_CASES.extend([(uri, "") for uri in _PROTOCOL_URIS.values()])
554+
555+
556+
@pytest.mark.parametrize("uri,protocol", _INCOMPATIBLE_CASES)
483557
def test_constructor_incompatible_protocol_uri(uri, protocol):
484-
with pytest.raises(ValueError, match=r".*incompatible with"):
558+
with pytest.raises(TypeError, match=r".*incompatible with"):
485559
UPath(uri, protocol=protocol)
560+
561+
562+
# Test subclass instantiation with incompatible URIs
563+
# Use protocols that have registered implementations we can get via get_upath_class
564+
_SUBCLASS_INCOMPATIBLE_CASES = [
565+
(_PROTOCOL_URIS[uri_protocol], target_protocol)
566+
for target_protocol in _PROTOCOL_URIS
567+
for uri_protocol in _PROTOCOL_URIS
568+
if target_protocol != uri_protocol
569+
]
570+
571+
572+
@pytest.mark.parametrize("uri,protocol", _SUBCLASS_INCOMPATIBLE_CASES)
573+
def test_subclass_constructor_incompatible_protocol_uri(uri, protocol):
574+
cls = get_upath_class(protocol)
575+
with pytest.raises(TypeError, match=r".*incompatible with"):
576+
cls(uri)

upath/tests/test_extensions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,34 @@ class MyPath(UPath):
214214
a = MyPath(".", protocol="memory")
215215

216216
assert isinstance(a, MyPath)
217+
218+
219+
# Protocol to sample URI mapping for compatibility tests
220+
_PROTOCOL_URIS = {
221+
"s3": "s3://bucket/folder",
222+
"gs": "gs://bucket/folder",
223+
"memory": "memory://folder",
224+
"file": "file:/tmp/folder",
225+
"http": "http://example.com/path",
226+
"": "/tmp/folder",
227+
}
228+
229+
# Generate incompatible combinations
230+
_PROXY_INCOMPATIBLE_CASES = [
231+
(_PROTOCOL_URIS[uri_protocol], target_protocol)
232+
for target_protocol in _PROTOCOL_URIS
233+
for uri_protocol in _PROTOCOL_URIS
234+
if target_protocol != uri_protocol and uri_protocol != ""
235+
]
236+
237+
238+
@pytest.mark.parametrize("uri,protocol", _PROXY_INCOMPATIBLE_CASES)
239+
def test_proxy_subclass_incompatible_protocol_uri(uri, protocol):
240+
"""Test that ProxyUPath subclasses raise TypeError for incompatible protocols."""
241+
242+
class MyProxyPath(ProxyUPath):
243+
pass
244+
245+
# ProxyUPath wraps the underlying path, so it should also raise TypeError
246+
with pytest.raises(TypeError, match=r".*incompatible with"):
247+
MyProxyPath(uri, protocol=protocol)

0 commit comments

Comments
 (0)