diff --git a/upath/_chain.py b/upath/_chain.py index 83d3db5a..23a25741 100644 --- a/upath/_chain.py +++ b/upath/_chain.py @@ -231,14 +231,14 @@ def unchain( proto0 is None or bit == proto0 ): # exact match a fsspec protocol proto = bit - path_bit = "" + path_bit = None extra_so = {} elif bit in (m := set(available_implementations(fallback=True))) and ( proto0 is None or bit == proto0 ): self.known_protocols = m proto = bit - path_bit = "" + path_bit = None extra_so = {} else: proto = get_upath_protocol(bit, protocol=proto0) @@ -246,8 +246,8 @@ def unchain( path_bit = flavour.strip_protocol(bit) extra_so = flavour.get_kwargs_from_url(bit) if proto in {"blockcache", "filecache", "simplecache"}: - if path_bit: - next_path_overwrite = path_bit + if path_bit is not None: + next_path_overwrite = path_bit or "/" path_bit = None elif next_path_overwrite is not None: path_bit = next_path_overwrite diff --git a/upath/_flavour.py b/upath/_flavour.py index 3274bcb2..842a8350 100644 --- a/upath/_flavour.py +++ b/upath/_flavour.py @@ -132,6 +132,7 @@ class WrappedFileSystemFlavour(UPathParser): # (pathlib_abc.FlavourBase) "https", }, "root_marker_override": { + "smb": "/", "ssh": "/", "sftp": "/", }, @@ -253,7 +254,7 @@ def stringify_path(pth: JoinablePathLike) -> str: def strip_protocol(self, pth: JoinablePathLike) -> str: pth = self.stringify_path(pth) - return self._spec._strip_protocol(pth) + return self._spec._strip_protocol(pth) or self.root_marker def get_kwargs_from_url(self, url: JoinablePathLike) -> dict[str, Any]: # NOTE: the public variant is _from_url not _from_urls @@ -317,6 +318,9 @@ def split(self, path: JoinablePathLike) -> tuple[str, str]: tail = stripped_path[1:] elif head: tail = stripped_path[len(head) + 1 :] + elif self.netloc_is_anchor: # and not head + head = stripped_path + tail = "" else: tail = stripped_path if ( diff --git a/upath/core.py b/upath/core.py index 90ac7fef..7c28ba00 100644 --- a/upath/core.py +++ b/upath/core.py @@ -573,7 +573,7 @@ def __init__( # FIXME: normalization needs to happen in unchain already... chain = Chain.from_list(Chain.from_list(segments).to_list()) if len(args) > 1: - flavour = WrappedFileSystemFlavour.from_protocol(protocol) + flavour = WrappedFileSystemFlavour.from_protocol(chain.active_path_protocol) joined = flavour.join(chain.active_path, *args[1:]) stripped = flavour.strip_protocol(joined) chain = chain.replace(path=stripped) @@ -963,7 +963,7 @@ def with_segments(self, *pathsegments: JoinablePathLike) -> Self: new_instance = type(self)( *pathsegments, protocol=self._protocol, - **self._storage_options, + **self.storage_options, ) if hasattr(self, "_fs_cached"): new_instance._fs_cached = self._fs_cached @@ -1090,7 +1090,7 @@ def parent(self) -> Self: self._relative_base, str(self), protocol=self._protocol, - **self._storage_options, + **self.storage_options, ) parent = pth.parent parent._relative_base = self._relative_base @@ -1121,7 +1121,7 @@ def parents(self) -> Sequence[Self]: break parent = parent.parent parents.append(parent) - return parents + return tuple(parents) return super().parents def joinpath(self, *pathsegments: JoinablePathLike) -> Self: @@ -1945,14 +1945,14 @@ def __reduce__(self): args = (self.__vfspath__(),) kwargs = { "protocol": self._protocol, - **self._storage_options, + **self.storage_options, } else: args = (self._relative_base, self.__vfspath__()) # Include _relative_base in the state if it's set kwargs = { "protocol": self._protocol, - **self._storage_options, + **self.storage_options, "_relative_base": self._relative_base, } return _make_instance, (type(self), args, kwargs) diff --git a/upath/implementations/cloud.py b/upath/implementations/cloud.py index 02c89db3..7814a456 100644 --- a/upath/implementations/cloud.py +++ b/upath/implementations/cloud.py @@ -2,6 +2,7 @@ import sys from collections.abc import Iterator +from collections.abc import Sequence from typing import TYPE_CHECKING from typing import Any @@ -82,6 +83,13 @@ def path(self) -> str: return self_path + self.root return self_path + @property + def parts(self) -> Sequence[str]: + parts = super().parts + if self._relative_base is None and len(parts) == 2 and not parts[1]: + return parts[:1] + return parts + def mkdir( self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False ) -> None: @@ -171,6 +179,10 @@ def __init__( *args, protocol=protocol, chain_parser=chain_parser, **storage_options ) + @property + def root(self) -> str: + return "" + def iterdir(self) -> Iterator[Self]: try: yield from super().iterdir() diff --git a/upath/implementations/smb.py b/upath/implementations/smb.py index c967e370..956fdbad 100644 --- a/upath/implementations/smb.py +++ b/upath/implementations/smb.py @@ -42,6 +42,9 @@ def path(self) -> str: path = super().path if len(path) > 1: return path.removesuffix("/") + # At root level, return "/" to match anchor + if not path and self._relative_base is None: + return self.anchor return path def __str__(self) -> str: diff --git a/upath/tests/cases.py b/upath/tests/cases.py index 24485e69..dddb9852 100644 --- a/upath/tests/cases.py +++ b/upath/tests/cases.py @@ -17,6 +17,7 @@ from upath import UPath from upath._protocol import get_upath_protocol from upath._stat import UPathStatResult +from upath.tests.utils import posixify from upath.types import StatResultType @@ -248,6 +249,16 @@ def test_parents_are_absolute(self): is_absolute = [p.is_absolute() for p in self.path.parents] assert all(is_absolute) + def test_parents_end_at_anchor(self): + p = self.path.joinpath("folder1", "file1.txt") + assert p.parents[-1].path == posixify(p.anchor) + + def test_anchor_is_its_own_parent(self): + p = self.path.joinpath("folder1", "file1.txt") + p0 = p.parents[-1] + assert p0.path == posixify(p.anchor) + assert p0.parent.path == posixify(p.anchor) + def test_private_url_attr_in_sync(self): p = self.path p1 = self.path.joinpath("c") diff --git a/upath/tests/implementations/test_data.py b/upath/tests/implementations/test_data.py index b539e30e..a9e2f4b8 100644 --- a/upath/tests/implementations/test_data.py +++ b/upath/tests/implementations/test_data.py @@ -129,6 +129,17 @@ def test_trailing_slash_is_stripped(self): with pytest.raises(UnsupportedOperation): super().test_trailing_slash_is_stripped() + @overrides_base + def test_parents_end_at_anchor(self): + # DataPath does not support joins + with pytest.raises(UnsupportedOperation): + super().test_parents_end_at_anchor() + + @overrides_base + def test_anchor_is_its_own_parent(self): + # DataPath does not support joins + assert self.path.path == self.path.parent.path + @overrides_base def test_private_url_attr_in_sync(self): # DataPath does not support joins, so we check on self.path diff --git a/upath/tests/test_core.py b/upath/tests/test_core.py index 822cc8c7..9d6d9317 100644 --- a/upath/tests/test_core.py +++ b/upath/tests/test_core.py @@ -72,6 +72,22 @@ def test_is_correct_class(self): def test_parents_are_absolute(self): super().test_parents_are_absolute() + @overrides_base + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="mock fs is not well defined on windows", + ) + def test_anchor_is_its_own_parent(self): + super().test_anchor_is_its_own_parent() + + @overrides_base + @pytest.mark.skipif( + sys.platform.startswith("win"), + reason="mock fs is not well defined on windows", + ) + def test_parents_end_at_anchor(self): + super().test_parents_end_at_anchor() + def test_multiple_backend_paths(local_testdir): path = "s3://bucket/"