Skip to content

Commit 36b964e

Browse files
authored
Fix File path handling across storage backends (#1604)
* File path handling cleanup
* fix tests
* escape listing dataset names properly
* fix tests
* fix(local): make is_path_in cross-platform and add drive-letter check

  Replace Path.resolve(strict=False) with os.path.abspath + os.path.normcase in is_path_in() for deterministic, filesystem-independent path comparison. resolve() varies across Python versions (3.11 uses abspath, 3.12+ uses realpath) and can produce different results depending on which path components already exist on disk.

  Also use startswith(output + os.sep) instead of is_relative_to() to prevent false positives where /foo/bar2 would match /foo/bar.

  Add drive-letter absolute path check to validate_local_relpath() so Windows paths like C:/secret/file are rejected.
* review and refactor export, fix tests again
* normalize only local windows paths
* normalize only local windows paths
* revert unrelated changes
* review, cleanup
* cleanup / fix listings handling
* remove get_full_path from the client
* remove rel path
* simplify, unify validation
* more cleanup, refactor to address review
* address PR review
* prefetch skip and log bad files
* add more test for func on path
* review, cleanup, add tests for save / upload
* more review fixes
* address one more review
1 parent a8a23e6 commit 36b964e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+3770
-809
lines changed

src/datachain/cache.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,16 +78,22 @@ async def download(
7878
) -> None:
7979
from dvc_objects.fs.utils import tmp_fname
8080

81-
from_path = file.get_uri()
81+
from_path = file.get_fs_path()
8282
odb_fs = self.odb.fs
8383
tmp_info = odb_fs.join(self.odb.tmp_dir, tmp_fname()) # type: ignore[arg-type]
8484
size = file.size
8585
if size < 0:
86-
size = await client.get_size(from_path, version_id=file.version)
86+
size = await client.get_size(file)
8787
from datachain.progress import tqdm
8888

8989
cb = callback or TqdmCallback(
90-
tqdm_kwargs={"desc": odb_fs.name(from_path), "bytes": True, "leave": False},
90+
tqdm_kwargs={
91+
"desc": odb_fs.name(from_path),
92+
"unit": "B",
93+
"unit_scale": True,
94+
"unit_divisor": 1024,
95+
"leave": False,
96+
},
9197
tqdm_cls=tqdm,
9298
size=size,
9399
)

src/datachain/catalog/catalog.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ def find_column_to_str( # noqa: PLR0911
479479
full_path = path + "/"
480480
else:
481481
full_path = path
482-
return src.get_node_full_path_from_path(full_path)
482+
return src.get_node_uri_from_path(full_path)
483483
if column == "size":
484484
return str(row[field_lookup["size"]])
485485
if column == "type":
@@ -2114,7 +2114,7 @@ def du_dirs(src, node, subdepth):
21142114
for sd in subdirs:
21152115
yield from du_dirs(src, sd, subdepth - 1)
21162116
yield (
2117-
src.get_node_full_path(node),
2117+
src.get_node_uri(node),
21182118
src.listing.du(node)[0],
21192119
)
21202120

src/datachain/catalog/datasource.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ def __init__(self, listing, client, node, as_container=False):
1212
as_container # Indicates whether a .tar file is handled as a container
1313
)
1414

15-
def get_node_full_path(self, node):
16-
return self.client.get_full_path(node.full_path)
15+
def get_node_uri(self, node):
16+
return self.client.get_uri(node.full_path)
1717

18-
def get_node_full_path_from_path(self, full_path):
19-
return self.client.get_full_path(full_path)
18+
def get_node_uri_from_path(self, full_path):
19+
return self.client.get_uri(full_path)
2020

2121
def is_single_object(self):
2222
return self.node.dir_type == DirType.FILE or (

src/datachain/client/azure.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import Any
2-
from urllib.parse import parse_qs, urlsplit, urlunsplit
32

43
from adlfs import AzureBlobFileSystem
54

@@ -26,19 +25,48 @@ def info_to_file(self, v: dict[str, Any], path: str) -> File:
2625
size=v.get("size", ""),
2726
)
2827

29-
def url(self, path: str, expires: int = 3600, **kwargs) -> str:
28+
def url(
29+
self,
30+
path: str,
31+
expires: int = 3600,
32+
version_id: str | None = None,
33+
**kwargs,
34+
) -> str:
3035
"""
3136
Generate a signed URL for the given path.
3237
"""
33-
version_id = kwargs.pop("version_id", None)
3438
content_disposition = kwargs.pop("content_disposition", None)
39+
full_path = self.get_uri(path)
40+
if version_id:
41+
# adlfs.split_path() reads version_id from the urlpath.
42+
full_path = f"{full_path}?versionid={version_id}"
43+
3544
result = self.fs.sign(
36-
self.get_full_path(path, version_id),
45+
full_path,
3746
expiration=expires,
3847
content_disposition=content_disposition,
3948
**kwargs,
4049
)
41-
return result + (f"&versionid={version_id}" if version_id else "")
50+
51+
if version_id:
52+
# The Azure SDK does not embed versionid in the SAS token, so we
53+
# append it explicitly to route the request to the correct version.
54+
result += f"&versionid={version_id}"
55+
return result
56+
57+
async def get_file(
58+
self,
59+
lpath: str,
60+
rpath: str,
61+
callback,
62+
version_id: str | None = None,
63+
) -> None:
64+
if version_id:
65+
# adlfs._get_file() only reads version_id from split_path(rpath);
66+
# it does not accept version_id as a kwarg. Embed it in the path
67+
# so split_path can recover it on the adlfs side.
68+
lpath = f"{lpath}?versionid={version_id}"
69+
await self.fs._get_file(lpath, rpath, callback=callback)
4270

4371
async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> None:
4472
prefix = start_prefix
@@ -72,13 +100,4 @@ async def _fetch_flat(self, start_prefix: str, result_queue: ResultQueue) -> Non
72100
finally:
73101
result_queue.put_nowait(None)
74102

75-
@classmethod
76-
def version_path(cls, path: str, version_id: str | None) -> str:
77-
parts = list(urlsplit(path))
78-
query = parse_qs(parts[3])
79-
if "versionid" in query:
80-
raise ValueError("path already includes a version query")
81-
parts[3] = f"versionid={version_id}" if version_id else ""
82-
return urlunsplit(parts)
83-
84103
_fetch_default = _fetch_flat

src/datachain/client/fsspec.py

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import os
66
import posixpath
77
import re
8-
import sys
98
from abc import ABC, abstractmethod
109
from collections.abc import AsyncIterator, Iterator, Sequence
1110
from datetime import datetime
@@ -19,6 +18,7 @@
1918

2019
from datachain.cache import Cache
2120
from datachain.client.fileslice import FileWrapper
21+
from datachain.fs.utils import is_win_local_path
2222
from datachain.nodes_fetcher import NodesFetcher
2323
from datachain.nodes_thread_pool import NodeChunk
2424
from datachain.progress import tqdm
@@ -41,20 +41,6 @@
4141
ResultQueue = asyncio.Queue[Sequence["File"] | None]
4242

4343

44-
def is_win_local_path(uri: str) -> bool:
45-
if sys.platform == "win32":
46-
if len(uri) >= 1 and uri[0] == "\\":
47-
return True
48-
if (
49-
len(uri) >= 3
50-
and uri[1] == ":"
51-
and (uri[2] == "/" or uri[2] == "\\")
52-
and uri[0].isalpha()
53-
):
54-
return True
55-
return False
56-
57-
5844
def is_cloud_uri(uri: str) -> bool:
5945
protocol = urlparse(uri).scheme
6046
return protocol in CLOUD_STORAGE_PROTOCOLS
@@ -82,7 +68,7 @@ def __init__(self, name: str, fs_kwargs: dict[str, Any], cache: Cache) -> None:
8268
self.fs_kwargs = fs_kwargs
8369
self._fs: AbstractFileSystem | None = None
8470
self.cache = cache
85-
self.uri = self.get_uri(self.name)
71+
self.uri = self.storage_uri(self.name)
8672

8773
@staticmethod
8874
def get_implementation(url: str | os.PathLike[str]) -> type["Client"]: # noqa: PLR0911
@@ -114,29 +100,44 @@ def get_implementation(url: str | os.PathLike[str]) -> type["Client"]: # noqa:
114100

115101
raise NotImplementedError(f"Unsupported protocol: {protocol}")
116102

117-
@classmethod
118-
def path_to_uri(cls, path: str) -> str:
119-
"""Convert a path-like object to a URI. Default: identity."""
120-
return path
121-
122103
@staticmethod
123104
def is_data_source_uri(name: str) -> bool:
124105
# Returns True if name is one of supported data sources URIs, e.g s3 bucket
125106
return DATA_SOURCE_URI_PATTERN.match(name) is not None
126107

108+
@staticmethod
109+
def validate_file_path(path: str) -> None:
110+
"""Validate a relative object path for this backend.
111+
112+
Raises ``ValueError`` for paths that are empty, end with ``/``, or
113+
contain ``.`` / ``..`` segments. Subclasses extend this with
114+
backend-specific rules (e.g. local-filesystem restrictions).
115+
"""
116+
if not path:
117+
raise ValueError("path must not be empty")
118+
if path.endswith("/"):
119+
raise ValueError("path must not be a directory")
120+
parts = path.split("/")
121+
if any(part in (".", "..") for part in parts):
122+
raise ValueError("path must not contain '.' or '..'")
123+
124+
@classmethod # noqa: B027
125+
def validate_source(cls, source: str) -> None:
126+
"""Validate the source URI for this backend."""
127+
127128
@staticmethod
128129
def parse_url(source: str) -> tuple["StorageURI", str]:
129130
cls = Client.get_implementation(source)
130131
storage_name, rel_path = cls.split_url(source)
131-
return cls.get_uri(storage_name), rel_path
132+
return cls.storage_uri(storage_name), rel_path
132133

133134
@staticmethod
134135
def get_client(source: str | os.PathLike[str], cache: Cache, **kwargs) -> "Client":
135136
cls = Client.get_implementation(source)
136137
storage_url, _ = cls.split_url(os.fspath(source))
137138
if os.name == "nt":
138139
storage_url = storage_url.removeprefix("/")
139-
140+
cls.validate_source(os.fspath(source))
140141
return cls.from_name(storage_url, cache, kwargs)
141142

142143
@classmethod
@@ -146,10 +147,6 @@ def create_fs(cls, **kwargs) -> "AbstractFileSystem":
146147
fs.invalidate_cache()
147148
return fs
148149

149-
@classmethod
150-
def version_path(cls, path: str, version_id: str | None) -> str:
151-
return path
152-
153150
@classmethod
154151
def from_name(
155152
cls,
@@ -185,10 +182,11 @@ def is_root_url(cls, url) -> bool:
185182
return url == cls.PREFIX
186183

187184
@classmethod
188-
def get_uri(cls, name: str) -> "StorageURI":
185+
def storage_uri(cls, storage_name: str) -> "StorageURI":
186+
"""Build a :class:`StorageURI` by prepending the protocol to *storage_name*."""
189187
from datachain.dataset import StorageURI
190188

191-
return StorageURI(f"{cls.PREFIX}{name}")
189+
return StorageURI(f"{cls.PREFIX}{storage_name}")
192190

193191
@classmethod
194192
def split_url(cls, url: str) -> tuple[str, str]:
@@ -210,38 +208,47 @@ def fs(self) -> "AbstractFileSystem":
210208
self._fs = self.create_fs(**self.fs_kwargs)
211209
return self._fs
212210

213-
def url(self, path: str, expires: int = 3600, **kwargs) -> str:
214-
return self.fs.sign(
215-
self.get_full_path(path, kwargs.pop("version_id", None)),
216-
expiration=expires,
217-
**kwargs,
218-
)
211+
def url(
212+
self,
213+
path: str,
214+
expires: int = 3600,
215+
version_id: str | None = None,
216+
**kwargs,
217+
) -> str:
218+
self.validate_file_path(path)
219+
kwargs.update(self._version_kwargs(version_id))
220+
return self.fs.sign(self.get_uri(path), expiration=expires, **kwargs)
219221

220222
async def get_current_etag(self, file: "File") -> str:
221-
file_path = file.get_path_normalized()
222-
kwargs = {}
223-
if self._is_version_aware():
224-
kwargs["version_id"] = file.version
225-
info = await self.fs._info(
226-
self.get_full_path(file_path, file.version), **kwargs
227-
)
228-
return self.info_to_file(info, file_path).etag
223+
full_path = file.get_fs_path()
224+
info = await self.fs._info(full_path, **self._version_kwargs(file.version))
225+
return self.info_to_file(info, file.path).etag
229226

230227
def get_file_info(self, path: str, version_id: str | None = None) -> "File":
231-
info = self.fs.info(self.get_full_path(path, version_id), version_id=version_id)
228+
self.validate_file_path(path)
229+
full_path = self.get_uri(path)
230+
info = sync(
231+
get_loop(),
232+
self.fs._info,
233+
full_path,
234+
**self._version_kwargs(version_id),
235+
)
232236
return self.info_to_file(info, path)
233237

234-
async def get_size(self, path: str, version_id: str | None = None) -> int:
235-
return await self.fs._size(
236-
self.version_path(path, version_id), version_id=version_id
237-
)
238+
async def get_size(self, file: "File") -> int:
239+
full_path = file.get_fs_path()
240+
info = await self.fs._info(full_path, **self._version_kwargs(file.version))
241+
size = info.get("size")
242+
if size is None:
243+
raise FileNotFoundError(full_path)
244+
return int(size)
238245

239246
async def get_file(self, lpath, rpath, callback, version_id: str | None = None):
240247
return await self.fs._get_file(
241-
self.version_path(lpath, version_id),
248+
lpath,
242249
rpath,
243250
callback=callback,
244-
version_id=version_id,
251+
**self._version_kwargs(version_id),
245252
)
246253

247254
async def scandir(
@@ -341,6 +348,11 @@ def _is_valid_key(key: str) -> bool:
341348
def _is_version_aware(self) -> bool:
342349
return getattr(self.fs, "version_aware", False)
343350

351+
def _version_kwargs(self, version_id: str | None) -> dict[str, Any]:
352+
if version_id:
353+
return {"version_id": version_id}
354+
return {}
355+
344356
async def ls_dir(self, path):
345357
kwargs = {}
346358
if self._is_version_aware():
@@ -350,8 +362,9 @@ async def ls_dir(self, path):
350362
def rel_path(self, path: str) -> str:
351363
return self.fs.split_path(path)[1]
352364

353-
def get_full_path(self, rel_path: str, version_id: str | None = None) -> str:
354-
return self.version_path(f"{self.PREFIX}{self.name}/{rel_path}", version_id)
365+
def get_uri(self, rel_path: str) -> str:
366+
"""Build a full URI for the given relative path within this client's storage."""
367+
return f"{self.PREFIX}{self.name}/{rel_path}"
355368

356369
@abstractmethod
357370
def info_to_file(self, v: dict[str, Any], path: str) -> "File": ...
@@ -397,20 +410,29 @@ def open_object(
397410
if use_cache and (cache_path := self.cache.get_path(file)):
398411
return open(cache_path, mode="rb")
399412
assert not file.location
413+
kwargs = self._version_kwargs(file.version)
414+
full_path = file.get_fs_path()
400415
return FileWrapper(
401-
self.fs.open(self.get_full_path(file.get_path_normalized(), file.version)),
416+
self.fs.open(full_path, **kwargs),
402417
cb,
403418
) # type: ignore[return-value]
404419

405420
def upload(self, data: bytes, path: str) -> "File":
406-
full_path = path if path.startswith(self.PREFIX) else self.get_full_path(path)
421+
if path.startswith(self.PREFIX):
422+
full_path = path
423+
_, rel_path = self.split_url(path)
424+
else:
425+
rel_path = path
426+
full_path = self.get_uri(path)
427+
428+
self.validate_file_path(rel_path)
407429

408430
parent = posixpath.dirname(full_path)
409431
self.fs.makedirs(parent, exist_ok=True)
410432

411433
self.fs.pipe_file(full_path, data)
412434
file_info = self.fs.info(full_path)
413-
return self.info_to_file(file_info, path)
435+
return self.info_to_file(file_info, rel_path)
414436

415437
def download(self, file: "File", *, callback: Callback = DEFAULT_CALLBACK) -> None:
416438
sync(get_loop(), functools.partial(self._download, file, callback=callback))

0 commit comments

Comments
 (0)