From 1aec2f138b6472302621d0bdac11e2e9f5998fa4 Mon Sep 17 00:00:00 2001 From: Kari Noriy Date: Tue, 6 Jun 2023 15:44:41 +0100 Subject: [PATCH 1/2] mod: added option to recursively walk tree to find all files --- torchdata/datapipes/iter/load/fsspec.py | 42 ++++++++++++++----------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/torchdata/datapipes/iter/load/fsspec.py b/torchdata/datapipes/iter/load/fsspec.py index 39a875a94..36069e73c 100644 --- a/torchdata/datapipes/iter/load/fsspec.py +++ b/torchdata/datapipes/iter/load/fsspec.py @@ -43,6 +43,7 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]): Args: root: The root `fsspec` path directory or list of path directories to list files from masks: Unix style filter string or string list for filtering file name(s) + recursive: If True, recursively traverse the directory tree. If False, list files only in the root directory. kwargs: Extra options that make sense to a particular storage connection, e.g. host, port, username, password, etc. @@ -62,6 +63,7 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]): def __init__( self, root: Union[str, Sequence[str], IterDataPipe], + recursive: bool = False, masks: Union[str, List[str]] = "", **kwargs, ) -> None: @@ -77,6 +79,7 @@ def __init__( self.datapipe = root self.masks = masks self.kwargs_for_connection = kwargs + self.recursive = recursive def __iter__(self) -> Iterator[str]: for root in self.datapipe: @@ -92,33 +95,34 @@ def __iter__(self) -> Iterator[str]: protocol_list.append("az") is_local = fs.protocol == "file" or not any(root.startswith(protocol) for protocol in protocol_list) - if fs.isfile(path): - yield root + + if self.recursive: + for current_path, dirs, files in fs.walk(path): + for file_name in files: + if not match_masks(file_name, self.masks): + continue + + abs_path = os.path.join(current_path, file_name) if is_local else posixpath.join(current_path, file_name) + + if any(file_name.startswith(protocol) for protocol in protocol_list): + yield file_name + elif root.startswith(tuple(protocol_list)): + yield protocol_list[0] + "://" + abs_path + else: + yield abs_path else: for file_name in fs.ls(path, detail=False): # Ensure it returns List[str], not List[Dict] if not match_masks(file_name, self.masks): continue - # ensure the file name has the full fsspec protocol path + abs_path = os.path.join(path, file_name) if is_local else posixpath.join(path, file_name) + if any(file_name.startswith(protocol) for protocol in protocol_list): yield file_name + elif root.startswith(tuple(protocol_list)): + yield protocol_list[0] + "://" + abs_path else: - if is_local: - abs_path = os.path.join(path, file_name) - elif not file_name.startswith(path): - abs_path = posixpath.join(path, file_name) - else: - abs_path = file_name - - starts_with = False - for protocol in protocol_list: - if root.startswith(protocol): - starts_with = True - yield protocol + "://" + abs_path - break - - if not starts_with: - yield abs_path + yield abs_path @functional_datapipe("open_files_by_fsspec") From 80c728a0726d664bf8408b9bbfc45ac4d0ae81c8 Mon Sep 17 00:00:00 2001 From: Kari Noriy Date: Tue, 6 Jun 2023 16:25:50 +0100 Subject: [PATCH 2/2] mod: changed match_mask to also consider folder structure --- torchdata/datapipes/iter/load/fsspec.py | 22 +++++++--------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/torchdata/datapipes/iter/load/fsspec.py b/torchdata/datapipes/iter/load/fsspec.py index 36069e73c..8747b519b 100644 --- a/torchdata/datapipes/iter/load/fsspec.py +++ b/torchdata/datapipes/iter/load/fsspec.py @@ -48,16 +48,8 @@ class FSSpecFileListerIterDataPipe(IterDataPipe[str]): e.g. host, port, username, password, etc. Example: - - .. testsetup:: - - dir_path = "path" - - .. testcode:: - - from torchdata.datapipes.iter import FSSpecFileLister - - datapipe = FSSpecFileLister(root=dir_path) + >>> from torchdata.datapipes.iter import FSSpecFileLister + >>> datapipe = FSSpecFileLister(root=dir_path, recursive=True) """ def __init__( @@ -99,10 +91,10 @@ def __iter__(self) -> Iterator[str]: if self.recursive: for current_path, dirs, files in fs.walk(path): for file_name in files: - if not match_masks(file_name, self.masks): - continue abs_path = os.path.join(current_path, file_name) if is_local else posixpath.join(current_path, file_name) + if not match_masks(abs_path, self.masks): + continue if any(file_name.startswith(protocol) for protocol in protocol_list): yield file_name @@ -112,11 +104,11 @@ def __iter__(self) -> Iterator[str]: yield abs_path else: for file_name in fs.ls(path, detail=False): # Ensure it returns List[str], not List[Dict] - if not match_masks(file_name, self.masks): - continue - abs_path = os.path.join(path, file_name) if is_local else posixpath.join(path, file_name) + if not match_masks(abs_path, self.masks): + continue + if any(file_name.startswith(protocol) for protocol in protocol_list): yield file_name elif root.startswith(tuple(protocol_list)):