Commit 0e46a48

lhoestq authored and Wauplin committed
[HfFileSystem] Optimize maxdepth: do less /tree calls in glob() (#3389)
* optimize maxdepth
* raise an error on invalid maxdepth
* mypy
* test
* fix test
1 parent f46eb28 commit 0e46a48
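
For context, a rough usage sketch of the API surface this commit touches (not part of the commit; the repo id below is hypothetical):

from huggingface_hub import HfFileSystem

fs = HfFileSystem()

# "datasets/someuser/some-dataset" is a hypothetical repo id used for illustration.
repo_path = "datasets/someuser/some-dataset"

# find() now forwards maxdepth to _ls_tree() instead of falling back to the
# generic fsspec implementation, and rejects values below 1 with a ValueError.
top_level = fs.find(repo_path, maxdepth=1, detail=False)

# glob() with a pattern of bounded depth resolves through find(), so it should
# also need fewer /tree calls after this change.
csv_files = fs.glob(repo_path + "/*.csv")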

File tree

2 files changed: +50 -22 lines changed


src/huggingface_hub/hf_file_system.py

Lines changed: 38 additions & 21 deletions
@@ -386,6 +386,7 @@ def _ls_tree(
         refresh: bool = False,
         revision: Optional[str] = None,
         expand_info: bool = False,
+        maxdepth: Optional[int] = None,
     ):
         resolved_path = self.resolve_path(path, revision=revision)
         path = resolved_path.unresolve()
@@ -405,19 +406,25 @@ def _ls_tree(
             if recursive:
                 # Use BFS to traverse the cache and build the "recursive" output
                 # (The Hub uses a so-called "tree first" strategy for the tree endpoint but we sort the output to follow the spec so the result is (eventually) the same)
+                depth = 2
                 dirs_to_visit = deque(
-                    [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
+                    [(depth, path_info) for path_info in cached_path_infos if path_info["type"] == "directory"]
                 )
                 while dirs_to_visit:
-                    dir_info = dirs_to_visit.popleft()
-                    if dir_info["name"] not in self.dircache:
-                        dirs_not_in_dircache.append(dir_info["name"])
-                    else:
-                        cached_path_infos = self.dircache[dir_info["name"]]
-                        out.extend(cached_path_infos)
-                        dirs_to_visit.extend(
-                            [path_info for path_info in cached_path_infos if path_info["type"] == "directory"]
-                        )
+                    depth, dir_info = dirs_to_visit.popleft()
+                    if maxdepth is None or depth <= maxdepth:
+                        if dir_info["name"] not in self.dircache:
+                            dirs_not_in_dircache.append(dir_info["name"])
+                        else:
+                            cached_path_infos = self.dircache[dir_info["name"]]
+                            out.extend(cached_path_infos)
+                            dirs_to_visit.extend(
+                                [
+                                    (depth + 1, path_info)
+                                    for path_info in cached_path_infos
+                                    if path_info["type"] == "directory"
+                                ]
+                            )

             dirs_not_expanded = []
             if expand_info:
@@ -436,6 +443,9 @@ def _ls_tree(
                     or common_prefix in chain(dirs_not_in_dircache, dirs_not_expanded)
                     else self._parent(common_prefix)
                 )
+                if maxdepth is not None:
+                    common_path_depth = common_path[len(path) :].count("/")
+                    maxdepth -= common_path_depth
                 out = [o for o in out if not o["name"].startswith(common_path + "/")]
                 for cached_path in self.dircache:
                     if cached_path.startswith(common_path + "/"):
@@ -448,6 +458,7 @@ def _ls_tree(
                         refresh=True,
                         revision=revision,
                         expand_info=expand_info,
+                        maxdepth=maxdepth,
                     )
                 )
         else:
@@ -460,9 +471,10 @@ def _ls_tree(
                 repo_type=resolved_path.repo_type,
             )
             for path_info in tree:
+                cache_path = root_path + "/" + path_info.path
                 if isinstance(path_info, RepoFile):
                     cache_path_info = {
-                        "name": root_path + "/" + path_info.path,
+                        "name": cache_path,
                         "size": path_info.size,
                         "type": "file",
                         "blob_id": path_info.blob_id,
@@ -472,15 +484,17 @@ def _ls_tree(
                     }
                 else:
                     cache_path_info = {
-                        "name": root_path + "/" + path_info.path,
+                        "name": cache_path,
                         "size": 0,
                         "type": "directory",
                         "tree_id": path_info.tree_id,
                         "last_commit": path_info.last_commit,
                     }
                 parent_path = self._parent(cache_path_info["name"])
                 self.dircache.setdefault(parent_path, []).append(cache_path_info)
-                out.append(cache_path_info)
+                depth = cache_path[len(path) :].count("/")
+                if maxdepth is None or depth <= maxdepth:
+                    out.append(cache_path_info)
         return out

     def walk(self, path: str, *args, **kwargs) -> Iterator[Tuple[str, List[str], List[str]]]:
@@ -547,19 +561,22 @@ def find(
         Returns:
             `Union[List[str], Dict[str, Dict[str, Any]]]`: List of paths or dict of file information.
         """
-        if maxdepth:
-            return super().find(
-                path, maxdepth=maxdepth, withdirs=withdirs, detail=detail, refresh=refresh, revision=revision, **kwargs
-            )
+        if maxdepth is not None and maxdepth < 1:
+            raise ValueError("maxdepth must be at least 1")
         resolved_path = self.resolve_path(path, revision=revision)
         path = resolved_path.unresolve()
         try:
-            out = self._ls_tree(path, recursive=True, refresh=refresh, revision=resolved_path.revision, **kwargs)
+            out = self._ls_tree(
+                path, recursive=True, refresh=refresh, revision=resolved_path.revision, maxdepth=maxdepth, **kwargs
+            )
         except EntryNotFoundError:
             # Path could be a file
-            if self.info(path, revision=revision, **kwargs)["type"] == "file":
-                out = {path: {}}
-            else:
+            try:
+                if self.info(path, revision=revision, **kwargs)["type"] == "file":
+                    out = {path: {}}
+                else:
+                    out = {}
+            except FileNotFoundError:
                 out = {}
         else:
             if not withdirs:
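
For reference, a minimal sketch of the depth convention used in _ls_tree above (the paths below are hypothetical, not from the commit):

# Depth is counted in path segments below the searched path: entries directly
# under `path` have depth 1, their children depth 2, and so on.
path = "datasets/someuser/some-dataset"           # hypothetical search root
cache_path = path + "/data/train/part-0.txt"      # hypothetical tree entry

# cache_path[len(path):] == "/data/train/part-0.txt" contains three "/",
# so this entry sits at depth 3 and is skipped whenever maxdepth < 3.
depth = cache_path[len(path):].count("/")
assert depth == 3

# In the cached BFS, each directory is enqueued with the depth of its children;
# directories directly under `path` contain depth-2 entries, hence `depth = 2`.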

tests/test_hf_file_system.py

Lines changed: 12 additions & 1 deletion
@@ -60,7 +60,8 @@ def setUp(self):
             repo_type="dataset",
         )

-        self.text_file = self.hf_path + "/data/text_data.txt"
+        self.text_file_path_in_repo = "data/text_data.txt"
+        self.text_file = self.hf_path + "/" + self.text_file_path_in_repo

     def tearDown(self):
         self.api.delete_repo(self.repo_id, repo_type="dataset")
@@ -421,6 +422,16 @@ def test_find_data_file_no_revision(self):
         files = self.hffs.find(self.text_file, detail=False)
         self.assertEqual(files, [self.text_file])

+    def test_find_maxdepth(self):
+        text_file_depth = self.text_file_path_in_repo.count("/") + 1
+        files = self.hffs.find(self.hf_path, detail=False, maxdepth=text_file_depth - 1)
+        self.assertNotIn(self.text_file, files)
+        files = self.hffs.find(self.hf_path, detail=False, maxdepth=text_file_depth)
+        self.assertIn(self.text_file, files)
+        # we do it again once the cache is updated
+        files = self.hffs.find(self.hf_path, detail=False, maxdepth=text_file_depth - 1)
+        self.assertNotIn(self.text_file, files)
+
     def test_read_bytes(self):
         data = self.hffs.read_bytes(self.text_file)
         self.assertEqual(data, b"dummy text data")
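
For reference, how the test's maxdepth values relate to the depth of the uploaded text file, assuming the depth convention sketched above (plain Python, not part of the commit):

# The file lives at "data/text_data.txt" in the repo, so its depth below the
# repo root is 2: one level for "data", one for the file itself.
text_file_path_in_repo = "data/text_data.txt"
text_file_depth = text_file_path_in_repo.count("/") + 1
assert text_file_depth == 2

# find(..., maxdepth=1) should therefore not return it, while maxdepth=2 should.
# The third call repeats maxdepth=1 after the full listing has been cached, so it
# should exercise the dircache/BFS branch of _ls_tree rather than a fresh /tree fetch.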
