diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py
index 2a941315b..12667d273 100644
--- a/test/test_iterdatapipe.py
+++ b/test/test_iterdatapipe.py
@@ -21,6 +21,7 @@ from torchdata.datapipes.iter import (
     BucketBatcher,
     Cycler,
+    ExtractKeys,
     Header,
     IndexAdder,
     InMemoryCacheHolder,
@@ -951,6 +952,30 @@ def test_mux_longest_iterdatapipe(self):
         with self.assertRaises(TypeError):
             len(output_dp)
 
+    def test_extractor(self):
+
+        # Functional Test: verify that extracting by patterns yields correct output
+        stage1 = IterableWrapper([
+            {"1.txt": "1", "1.bin": "1b"},
+            {"2.txt": "2", "2.bin": "2b"},
+        ])
+        stage2 = ExtractKeys(stage1, "*.txt", "*.bin", as_tuple=True)
+        output = list(iter(stage2))
+        self.assertEqual(output, [("1", "1b"), ("2", "2b")])
+        stage2 = ExtractKeys(stage1, "*.txt", "*.bin")
+        output = list(iter(stage2))
+        self.assertEqual(output, [
+            {"1.txt": "1", "1.bin": "1b"},
+            {"2.txt": "2", "2.bin": "2b"},
+        ])
+        with self.assertRaisesRegex(ValueError, r"(?i)multiple sample keys"):
+            stage2 = ExtractKeys(stage1, "*")
+            output = list(iter(stage2))
+        with self.assertRaisesRegex(ValueError, r"selected twice"):
+            stage2 = ExtractKeys(stage1, "*.txt", "*t")
+            output = list(iter(stage2))
+
+
     def test_zip_longest_iterdatapipe(self):
 
         # Functional Test: raises TypeError when an input is not of type `IterDataPipe`
diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py
index 4a2265d65..477a6ed3b 100644
--- a/torchdata/datapipes/iter/__init__.py
+++ b/torchdata/datapipes/iter/__init__.py
@@ -121,7 +121,10 @@
     TFRecordLoaderIterDataPipe as TFRecordLoader,
 )
 from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper
+from torchdata.datapipes.iter.util.extractkeys import (
+    ExtractKeysIterDataPipe as ExtractKeys,
+)
 from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset
 from torchdata.datapipes.iter.util.xzfileloader import (
     XzFileLoaderIterDataPipe as XzFileLoader,
     XzFileReaderIterDataPipe as XzFileReader,
@@ -151,6 +154,7 @@
     "Dropper",
     "EndOnDiskCacheHolder",
     "Enumerator",
+    "ExtractKeys",
     "Extractor",
     "FSSpecFileLister",
     "FSSpecFileOpener",
diff --git a/torchdata/datapipes/iter/util/extractkeys.py b/torchdata/datapipes/iter/util/extractkeys.py
new file mode 100644
index 000000000..db4f93092
--- /dev/null
+++ b/torchdata/datapipes/iter/util/extractkeys.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fnmatch import fnmatch
+from typing import Dict, Iterator, Tuple, Union
+
+from torchdata.datapipes import functional_datapipe
+from torchdata.datapipes.iter import IterDataPipe
+
+
+@functional_datapipe("extract_keys")
+class ExtractKeysIterDataPipe(IterDataPipe[Dict]):
+    r"""
+    Given a stream of dictionaries, select keys using glob patterns and yield
+    either a reduced dictionary or a tuple of the selected values.
+
+    Args:
+        source_datapipe: a DataPipe yielding a stream of dictionaries
+        *args: glob patterns selecting the keys to extract; each positional
+            argument may also be a list/tuple of alternative patterns
+        duplicate_is_error: raise ``ValueError`` if a pattern matches more than
+            one key or if the same key is selected by two patterns (default: True)
+        ignore_missing: skip patterns that match no key instead of raising
+            (default: False)
+        as_tuple: yield a tuple of values (in pattern order) instead of a
+            dictionary (default: False)
+
+    Returns:
+        a DataPipe yielding a stream of tuples or dictionaries
+
+    Example:
+        >>> dp = FileLister(...).load_from_tar().webdataset().decode(...).extract_keys(["*.jpg", "*.png"], "*.gt.txt")
+    """
+
+    def __init__(
+        self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error=True, ignore_missing=False, as_tuple=False
+    ) -> None:
+        super().__init__()
+        self.source_datapipe: IterDataPipe[Dict] = source_datapipe
+        self.duplicate_is_error = duplicate_is_error
+        self.patterns = args
+        self.ignore_missing = ignore_missing
+        self.as_tuple = as_tuple
+
+    def __iter__(self) -> Union[Iterator[Tuple], Iterator[Dict]]:  # type: ignore
+        for sample in self.source_datapipe:
+            result = []
+            used = set()
+            for pattern in self.patterns:
+                # Each positional argument may be a single pattern or a
+                # list/tuple of alternative patterns for the same slot.
+                alternatives = pattern if isinstance(pattern, (list, tuple)) else [pattern]
+                matches = [k for k in sample.keys() if any(fnmatch(k, p) for p in alternatives)]
+                if not matches:
+                    if self.ignore_missing:
+                        continue
+                    raise ValueError(f"extract_keys: cannot find {pattern} in sample keys {sample.keys()}.")
+                if len(matches) > 1 and self.duplicate_is_error:
+                    raise ValueError(f"extract_keys: multiple sample keys {sample.keys()} match {pattern}.")
+                key = matches[0]
+                if key in used and self.duplicate_is_error:
+                    raise ValueError(f"extract_keys: key {key} is selected twice.")
+                used.add(key)
+                result.append((key, sample[key]))
+            # In tuple mode only the values are yielded, in pattern order.
+            yield tuple(v for _, v in result) if self.as_tuple else dict(result)
+
+    def __len__(self) -> int:
+        return len(self.source_datapipe)