Skip to content

Commit 3d4ec00

Browse files
authored
UPath.joinpath raise error on protocol mismatch (#264)
* tests: add test defining protocol mismatch behavior
* upath: fix UPath to raise ValueError on mismatch instead of TypeError
* upath.implementations.cloud: raise early if bucket/container missing
* upath: fix protocol matching on <=3.11
1 parent e2451e9 commit 3d4ec00

File tree

5 files changed

+75
-17
lines changed

5 files changed

+75
-17
lines changed

upath/_protocol.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33
import os
44
import re
55
from pathlib import PurePath
6+
from typing import TYPE_CHECKING
67
from typing import Any
78

9+
if TYPE_CHECKING:
10+
from upath.core import UPath
11+
812
__all__ = [
913
"get_upath_protocol",
1014
"normalize_empty_netloc",
15+
"compatible_protocol",
1116
]
1217

1318
# Regular expression to match fsspec style protocols.
@@ -59,3 +64,15 @@ def normalize_empty_netloc(pth: str) -> str:
5964
path = m.group("path")
6065
pth = f"{protocol}:///{path}"
6166
return pth
67+
68+
69+
def compatible_protocol(protocol: str, *args: str | os.PathLike[str] | UPath) -> bool:
    """Check whether every argument can be combined with ``protocol``.

    The protocol of each argument is determined with ``get_upath_protocol``
    and truncated at its first ``"+"`` before being compared. ``protocol``
    itself is compared exactly as passed in.

    Parameters
    ----------
    protocol:
        The protocol the arguments are checked against.
    *args:
        Path-like objects whose protocols are inspected.

    Returns
    -------
    bool
        ``False`` as soon as one argument carries a non-empty protocol that
        differs from ``protocol``; ``True`` otherwise (including when no
        arguments are given).
    """
    for candidate in args:
        # protocols are considered equivalent if they match up to the first "+"
        base_protocol, _, _ = get_upath_protocol(candidate).partition("+")
        if not base_protocol:
            # an empty "" protocol can always be combined
            continue
        if base_protocol != protocol:
            # only identical protocols can combine
            return False
    return True

upath/core.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from upath._flavour import LazyFlavourDescriptor
3636
from upath._flavour import upath_get_kwargs_from_url
3737
from upath._flavour import upath_urijoin
38+
from upath._protocol import compatible_protocol
3839
from upath._protocol import get_upath_protocol
3940
from upath._stat import UPathStatResult
4041
from upath.registry import get_upath_class
@@ -251,23 +252,12 @@ def __init__(
251252
self._storage_options = storage_options.copy()
252253

253254
# check that UPath subclasses in args are compatible
254-
# --> ensures items in _raw_paths are compatible
255-
for arg in args:
256-
if not isinstance(arg, UPath):
257-
continue
258-
# protocols: only identical (or empty "") protocols can combine
259-
if arg.protocol and arg.protocol != self._protocol:
260-
raise TypeError("can't combine different UPath protocols as parts")
261-
# storage_options: args may not define other storage_options
262-
if any(
263-
self._storage_options.get(key) != value
264-
for key, value in arg.storage_options.items()
265-
):
266-
# TODO:
267-
# Future versions of UPath could verify that storage_options
268-
# can be combined between UPath instances. Not sure if this
269-
# is really necessary though. A warning might be enough...
270-
pass
255+
# TODO:
256+
# Future versions of UPath could verify that storage_options
257+
# can be combined between UPath instances. Not sure if this
258+
# is really necessary though. A warning might be enough...
259+
if not compatible_protocol(self._protocol, *args):
260+
raise ValueError("can't combine incompatible UPath protocols")
271261

272262
# fill ._raw_paths
273263
if hasattr(self, "_raw_paths"):

upath/implementations/cloud.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222
class CloudPath(UPath):
2323
__slots__ = ()
2424

25+
def __init__(
    self, *args, protocol: str | None = None, **storage_options: Any
) -> None:
    """Initialise the path and reject paths lacking a bucket/container.

    All arguments are forwarded unchanged to the parent ``__init__``.

    Raises
    ------
    ValueError
        If the resulting path has no ``drive`` but more than one entry
        in ``parts``.
    """
    super().__init__(*args, protocol=protocol, **storage_options)
    if self.drive:
        # a drive is present, nothing further to validate
        return
    if len(self.parts) > 1:
        raise ValueError("non key-like path provided (bucket/container missing)")
31+
2532
@classmethod
2633
def _transform_init_args(
2734
cls,

upath/implementations/local.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import MutableMapping
1313
from urllib.parse import SplitResult
1414

15+
from upath._protocol import compatible_protocol
1516
from upath.core import UPath
1617

1718
__all__ = [
@@ -141,6 +142,8 @@ def __new__(
141142
raise NotImplementedError(
142143
f"cannot instantiate {cls.__name__} on your system"
143144
)
145+
if not compatible_protocol("", *args):
146+
raise ValueError("can't combine incompatible UPath protocols")
144147
obj = super().__new__(cls, *args)
145148
obj._protocol = ""
146149
return obj # type: ignore[return-value]
@@ -152,6 +155,11 @@ def __init__(
152155
self._drv, self._root, self._parts = type(self)._parse_args(args)
153156
_upath_init(self)
154157

158+
def _make_child(self, args):
    """Build a child path after verifying the protocols of ``args``.

    Delegates to the parent ``_make_child`` when every entry in ``args``
    is compatible with ``self._protocol``.

    Raises
    ------
    ValueError
        If ``compatible_protocol`` reports a mismatch.
    """
    if compatible_protocol(self._protocol, *args):
        return super()._make_child(args)
    raise ValueError("can't combine incompatible UPath protocols")
162+
155163
@classmethod
156164
def _from_parts(cls, *args, **kwargs):
157165
obj = super(Path, cls)._from_parts(*args, **kwargs)
@@ -205,6 +213,8 @@ def __new__(
205213
raise NotImplementedError(
206214
f"cannot instantiate {cls.__name__} on your system"
207215
)
216+
if not compatible_protocol("", *args):
217+
raise ValueError("can't combine incompatible UPath protocols")
208218
obj = super().__new__(cls, *args)
209219
obj._protocol = ""
210220
return obj # type: ignore[return-value]
@@ -216,6 +226,11 @@ def __init__(
216226
self._drv, self._root, self._parts = self._parse_args(args)
217227
_upath_init(self)
218228

229+
def _make_child(self, args):
    """Build a child path after verifying the protocols of ``args``.

    Delegates to the parent ``_make_child`` when every entry in ``args``
    is compatible with ``self._protocol``.

    Raises
    ------
    ValueError
        If ``compatible_protocol`` reports a mismatch.
    """
    if compatible_protocol(self._protocol, *args):
        return super()._make_child(args)
    raise ValueError("can't combine incompatible UPath protocols")
233+
219234
@classmethod
220235
def _from_parts(cls, *args, **kwargs):
221236
obj = super(Path, cls)._from_parts(*args, **kwargs)

upath/tests/test_core.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,3 +410,32 @@ def test_query_string(uri, query_str):
410410
p = UPath(uri)
411411
assert str(p).endswith(query_str)
412412
assert p.path.endswith(query_str)
413+
414+
415+
@pytest.mark.parametrize(
    ("base", "join"),
    [
        ("/a", "s3://bucket/b"),
        ("s3://bucket/a", "gs://b/c"),
        ("gs://bucket/a", "memory://b/c"),
        ("memory://bucket/a", "s3://b/c"),
    ],
)
def test_joinpath_on_protocol_mismatch(base, join):
    """Joining two paths with different protocols must raise ValueError."""
    # joining through the joinpath() method
    with pytest.raises(ValueError):
        UPath(base).joinpath(UPath(join))
    # joining through the "/" operator
    with pytest.raises(ValueError):
        UPath(base) / UPath(join)
429+
430+
431+
@pytest.mark.parametrize(
    ("base", "join"),
    [
        ("/a", "s3://bucket/b"),
        ("s3://bucket/a", "gs://b/c"),
        ("gs://bucket/a", "memory://b/c"),
        ("memory://bucket/a", "s3://b/c"),
    ],
)
def test_joinuri_on_protocol_mismatch(base, join):
    """joinuri() with a path of another protocol must equal that path."""
    joined = UPath(base).joinuri(UPath(join))
    expected = UPath(join)
    assert joined == expected

0 commit comments

Comments
 (0)