Skip to content

Commit 86e8dbd

Browse files
authored
Merge pull request #379 from gerritholl/shallow-find
Test for #378 to verify find respects maxdepth.
2 parents f94478e + 07d51a4 commit 86e8dbd

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

s3fs/core.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import os
55
import socket
6-
import time
76
from typing import Tuple, Optional
87
import weakref
98

@@ -204,9 +203,7 @@ async def _call_s3(self, method, *akwarglist, **kwargs):
204203
**kwargs)
205204
for i in range(self.retries):
206205
try:
207-
out = await method(**additional_kwargs)
208-
locals().pop("err", None) # break cycle following retry
209-
return out
206+
return await method(**additional_kwargs)
210207
except S3_RETRYABLE_ERRORS as e:
211208
logger.debug("Retryable error: %s" % e)
212209
err = e
@@ -446,6 +443,9 @@ async def _find(self, path, maxdepth=None, withdirs=None, detail=False):
446443
bucket, key, _ = self.split_path(path)
447444
if not bucket:
448445
raise ValueError("Cannot traverse all of S3")
446+
if maxdepth:
447+
return super().find(bucket + "/" + key, maxdepth=maxdepth, withdirs=withdirs,
448+
detail=detail)
449449
# TODO: implement find from dircache, if all listings are present
450450
# if refresh is False:
451451
# out = incomplete_tree_dirs(self.dircache, path)
@@ -855,7 +855,7 @@ def isdir(self, path):
855855
return False
856856

857857
# This only returns things within the path and NOT the path object itself
858-
return bool(sync(self.loop, self._lsdir, path))
858+
return bool(maybe_sync(self._lsdir, self, path))
859859

860860
def ls(self, path, detail=False, refresh=False, **kwargs):
861861
""" List single "directory" with or without details
@@ -873,9 +873,9 @@ def ls(self, path, detail=False, refresh=False, **kwargs):
873873
additional arguments passed on
874874
"""
875875
path = self._strip_protocol(path).rstrip('/')
876-
files = sync(self.loop, self._ls, path, refresh=refresh)
876+
files = maybe_sync(self._ls, self, path, refresh=refresh)
877877
if not files:
878-
files = sync(self.loop, self._ls, self._parent(path), refresh=refresh)
878+
files = maybe_sync(self._ls, self, self._parent(path), refresh=refresh)
879879
files = [o for o in files if o['name'].rstrip('/') == path
880880
and o['type'] != 'directory']
881881
if detail:
@@ -1080,7 +1080,7 @@ def url(self, path, expires=3600, **kwargs):
10801080
the number of seconds this signature will be good for.
10811081
"""
10821082
bucket, key, version_id = self.split_path(path)
1083-
return sync(self.loop, self.s3.generate_presigned_url,
1083+
return maybe_sync(self.s3.generate_presigned_url, self,
10841084
ClientMethod='get_object',
10851085
Params=dict(Bucket=bucket, Key=key, **version_id_kw(version_id), **kwargs),
10861086
ExpiresIn=expires)
@@ -1274,7 +1274,7 @@ def rm(self, path, recursive=False, **kwargs):
12741274
bucket, key, _ = self.split_path(path)
12751275
if not key and self.is_bucket_versioned(bucket):
12761276
# special path to completely remove versioned bucket
1277-
sync(self.loop, self._rm_versioned_bucket_contents, bucket)
1277+
maybe_sync(self._rm_versioned_bucket_contents, self, bucket)
12781278
super().rm(path, recursive=recursive, **kwargs)
12791279

12801280
def invalidate_cache(self, path=None):

s3fs/tests/test_s3fs.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1679,3 +1679,15 @@ async def async_wrapper():
16791679
await s3._s3.close()
16801680

16811681
asyncio.run(_())
1682+
1683+
1684+
def test_shallow_find(s3):
    """Test that find method respects maxdepth.

    Verify that the ``find`` method respects the ``maxdepth`` parameter. With
    ``maxdepth=1``, the results of ``find`` should be the same as those of
    ``ls``, without returning subdirectories. See also issue 378.
    """
    # A depth-1 find with directories included must match a plain listing.
    shallow = s3.find(test_bucket_name, maxdepth=1, withdirs=True)
    assert s3.ls(test_bucket_name) == shallow
    # A single-level glob should agree with the same listing as well.
    assert s3.ls(test_bucket_name) == s3.glob(test_bucket_name + "/*")

0 commit comments

Comments
 (0)