Skip to content

Commit 009f5cf

Browse files
authored
Fix sftp join issue for non-root prefixed paths (#294)
* tests: add test reproducing join issue for non-root prefixed paths * tests: improve ssh container fixture robustness * tests: add additional test cases for constructing ssh paths * upath._flavour: fix ssh path parsing (enforce absolute paths) * typing: fix error in kwargs
1 parent 3fc72fb commit 009f5cf

File tree

3 files changed

+63
-14
lines changed

3 files changed

+63
-14
lines changed

upath/_flavour.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class ProtocolConfig(TypedDict):
7878
netloc_is_anchor: set[str]
7979
supports_empty_parts: set[str]
8080
meaningful_trailing_slash: set[str]
81+
root_marker_override: dict[str, str]
8182

8283

8384
class WrappedFileSystemFlavour: # (pathlib_abc.FlavourBase)
@@ -109,8 +110,6 @@ class WrappedFileSystemFlavour: # (pathlib_abc.FlavourBase)
109110
"https",
110111
"s3",
111112
"s3a",
112-
"sftp",
113-
"ssh",
114113
"smb",
115114
"gs",
116115
"gcs",
@@ -135,6 +134,10 @@ class WrappedFileSystemFlavour: # (pathlib_abc.FlavourBase)
135134
"http",
136135
"https",
137136
},
137+
"root_marker_override": {
138+
"ssh": "/",
139+
"sftp": "/",
140+
},
138141
}
139142

140143
def __init__(
@@ -144,6 +147,7 @@ def __init__(
144147
netloc_is_anchor: bool = False,
145148
supports_empty_parts: bool = False,
146149
meaningful_trailing_slash: bool = False,
150+
root_marker_override: str | None = None,
147151
) -> None:
148152
"""initialize the flavour with the given fsspec"""
149153
self._spec = spec
@@ -163,6 +167,12 @@ def __init__(
163167
# - UPath._parse_path
164168
self.has_meaningful_trailing_slash = bool(meaningful_trailing_slash)
165169

170+
# some filesystems require UPath to enforce a specific root marker
171+
if root_marker_override is None:
172+
self.root_marker_override = None
173+
else:
174+
self.root_marker_override = str(root_marker_override)
175+
166176
@classmethod
167177
@lru_cache(maxsize=None)
168178
def from_protocol(
@@ -172,10 +182,11 @@ def from_protocol(
172182
"""return the fsspec flavour for the given protocol"""
173183

174184
_c = cls.protocol_config
175-
config = {
185+
config: dict[str, Any] = {
176186
"netloc_is_anchor": protocol in _c["netloc_is_anchor"],
177187
"supports_empty_parts": protocol in _c["supports_empty_parts"],
178188
"meaningful_trailing_slash": protocol in _c["meaningful_trailing_slash"],
189+
"root_marker_override": _c["root_marker_override"].get(protocol),
179190
}
180191

181192
# first try to get an already imported fsspec filesystem class
@@ -217,7 +228,10 @@ def protocol(self) -> tuple[str, ...]:
217228

218229
@property
219230
def root_marker(self) -> str:
220-
return self._spec.root_marker
231+
if self.root_marker_override is not None:
232+
return self.root_marker_override
233+
else:
234+
return self._spec.root_marker
221235

222236
@property
223237
def local_file(self) -> bool:
@@ -278,7 +292,7 @@ def join(self, path: PathOrStr, *paths: PathOrStr) -> str:
278292
drv, p0 = self.splitdrive(path)
279293
p0 = p0 or self.sep
280294
else:
281-
p0 = str(self.strip_protocol(path))
295+
p0 = str(self.strip_protocol(path)) or self.root_marker
282296
pN = list(map(self.stringify_path, paths))
283297
drv = ""
284298
if self.supports_empty_parts:

upath/tests/conftest.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,6 @@ def ssh_container():
486486
)
487487
try:
488488
subprocess.run(shlex.split(cmd))
489-
time.sleep(1)
490489
yield {
491490
"host": "localhost",
492491
"port": 2222,
@@ -499,7 +498,7 @@ def ssh_container():
499498

500499
@pytest.fixture
501500
def ssh_fixture(ssh_container, local_testdir, monkeypatch):
502-
pytest.importorskip("paramiko", reason="sftp tests require paramiko")
501+
paramiko = pytest.importorskip("paramiko", reason="sftp tests require paramiko")
503502

504503
cls = fsspec.get_filesystem_class("ssh")
505504
if cls.put != fsspec.AbstractFileSystem.put:
@@ -509,13 +508,28 @@ def ssh_fixture(ssh_container, local_testdir, monkeypatch):
509508

510509
monkeypatch.setattr(_DEFAULT_CALLBACK, "relative_update", lambda *args: None)
511510

512-
fs = fsspec.filesystem(
513-
"ssh",
514-
host=ssh_container["host"],
515-
port=ssh_container["port"],
516-
username=ssh_container["username"],
517-
password=ssh_container["password"],
518-
)
511+
for _ in range(100):
512+
try:
513+
fs = fsspec.filesystem(
514+
"ssh",
515+
host=ssh_container["host"],
516+
port=ssh_container["port"],
517+
username=ssh_container["username"],
518+
password=ssh_container["password"],
519+
timeout=10.0,
520+
banner_timeout=30.0,
521+
skip_instance_cache=True,
522+
)
523+
except (
524+
paramiko.ssh_exception.NoValidConnectionsError,
525+
paramiko.ssh_exception.SSHException,
526+
):
527+
time.sleep(0.1)
528+
continue
529+
break
530+
else:
531+
raise RuntimeError("issue with openssh-container startup")
532+
519533
fs.put(local_testdir, "/app/testdir", recursive=True)
520534
try:
521535
yield "ssh://{username}:{password}@{host}:{port}/app/testdir/".format(

upath/tests/implementations/test_sftp.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,24 @@ def test_mkdir_parents_true_exists_ok_false(self):
3838
@_xfail_old_fsspec
3939
def test_mkdir_parents_true_exists_ok_true(self):
4040
super().test_mkdir_parents_true_exists_ok_true()
41+
42+
43+
@pytest.mark.parametrize(
44+
"args,parts",
45+
[
46+
(("sftp://user@host",), ("/",)),
47+
(("sftp://user@host/",), ("/",)),
48+
(("sftp://user@host", ""), ("/",)),
49+
(("sftp://user@host/", ""), ("/",)),
50+
(("sftp://user@host", "/"), ("/",)),
51+
(("sftp://user@host/", "/"), ("/",)),
52+
(("sftp://user@host/abc",), ("/", "abc")),
53+
(("sftp://user@host", "abc"), ("/", "abc")),
54+
(("sftp://user@host", "/abc"), ("/", "abc")),
55+
(("sftp://user@host/", "/abc"), ("/", "abc")),
56+
],
57+
)
58+
def test_join_produces_correct_parts(args, parts):
59+
pth = UPath(*args)
60+
assert pth.storage_options == {"host": "host", "username": "user"}
61+
assert pth.parts == parts

0 commit comments

Comments
 (0)