From 751da998030e727c726511ed9aeacc812a912b68 Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 13 May 2022 12:27:04 -0700 Subject: [PATCH 1/5] merged --- test/test_iterdatapipe.py | 14 ++++++++++++++ torchdata/datapipes/iter/__init__.py | 8 +++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 8fc4cb16a..47e85a738 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -32,6 +32,7 @@ MapKeyZipper, MaxTokenBucketizer, ParagraphAggregator, + ExtractKeys, Rows2Columnar, SampleMultiplexer, UnZipper, @@ -902,6 +903,19 @@ def test_mux_longest_iterdatapipe(self): with self.assertRaises(TypeError): len(output_dp) + def test_extractor(self): + + # Functional Test: verify that extracting by patterns yields correct output + stage1 = IterableWrapper([ + {"1.txt": "1", "1.bin": "1b"}, + {"2.txt": "2", "2.bin": "2b"}, + ]) + stage2 = ExtractKeys(stage1, "*.txt", "*.bin") + output = list(iter(stage2)) + assert len(output) == 2 + assert output[0][0] == "1" + assert output[0][1] == "1b" + def test_zip_longest_iterdatapipe(self): # Functional Test: raises TypeError when an input is not of type `IterDataPipe` diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index a4109cced..19d6a4b7b 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -109,7 +109,12 @@ TFRecordLoaderIterDataPipe as TFRecordLoader, ) from torchdata.datapipes.iter.util.unzipper import UnZipperIterDataPipe as UnZipper -from torchdata.datapipes.iter.util.webdataset import WebDatasetIterDataPipe as WebDataset +from torchdata.datapipes.iter.util.webdataset import ( + WebDatasetIterDataPipe as WebDataset, +) +from torchdata.datapipes.iter.util.extractkeys import ( + ExtractKeysIterDataPipe as ExtractKeys, +) from torchdata.datapipes.iter.util.xzfileloader import ( XzFileLoaderIterDataPipe as XzFileLoader, XzFileReaderIterDataPipe as XzFileReader, @@ -136,6 +141,7 @@ "Demultiplexer", "EndOnDiskCacheHolder", "Enumerator", + "ExtractKeys", "Extractor", "FSSpecFileLister", "FSSpecFileOpener", From 5dc2a8902628a5c9399b25b6f76bf0f710b6a04f Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 13 May 2022 12:37:24 -0700 Subject: [PATCH 2/5] added extractkeys --- torchdata/datapipes/iter/util/extractkeys.py | 59 ++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 torchdata/datapipes/iter/util/extractkeys.py diff --git a/torchdata/datapipes/iter/util/extractkeys.py b/torchdata/datapipes/iter/util/extractkeys.py new file mode 100644 index 000000000..bbf0cafe5 --- /dev/null +++ b/torchdata/datapipes/iter/util/extractkeys.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from fnmatch import fnmatch +from typing import Dict, Iterator, Tuple + +from torchdata.datapipes import functional_datapipe +from torchdata.datapipes.iter import IterDataPipe + + +@functional_datapipe("extract_keys") +class ExtractKeysIterDataPipe(IterDataPipe[Dict]): + r""" + Given a stream of dictionaries, return a stream of tuples by selecting keys using glob patterns. + + Args: + source_datapipe: a DataPipe yielding a stream of dictionaries. + duplicate_is_error: it is an error if the same key is selected twice (True) + ignore_missing: skip any dictionaries where one or more patterns don't match (False) + *args: list of glob patterns or list of glob patterns + + Returns: + a DataPipe yielding a stream of tuples + + Examples: + >>> dp = FileLister(...).load_from_tar().webdataset().decode(...).extract_keys(["*.jpg", "*.png"], "*.gt.txt") + """ + + def __init__( + self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error=True, ignore_missing=False + ) -> None: + super().__init__() + self.source_datapipe: IterDataPipe[Dict] = source_datapipe + self.duplicate_is_error = duplicate_is_error + self.patterns = args + self.ignore_missing = ignore_missing + + def __iter__(self) -> Iterator[Tuple]: + for sample in self.source_datapipe: + result = [] + for pattern in self.patterns: + pattern = [pattern] if not isinstance(pattern, (list, tuple)) else pattern + matches = [x for x in sample.keys() if any(fnmatch(x, p) for p in pattern)] + if len(matches) == 0: + if self.ignore_missing: + continue + else: + raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.") + if len(matches) > 1 and self.duplicate_is_error: + raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.") + value = sample[matches[0]] + result.append(value) + yield tuple(result) + + def __len__(self) -> int: + return len(self.source_datapipe) From 59298b78519ed5e6ced0d158e6afb6e702271b9b Mon Sep 17 00:00:00 2001 From: Tom Date: Wed, 31 Aug 2022 12:51:16 -0700 Subject: [PATCH 3/5] added as_tuple option, better testing, duplicate detection --- test/test_iterdatapipe.py | 17 ++++++++++--- torchdata/datapipes/iter/util/extractkeys.py | 26 ++++++++++++++------ 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 47e85a738..563791d4e 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -910,11 +910,22 @@ def test_extractor(self): {"1.txt": "1", "1.bin": "1b"}, {"2.txt": "2", "2.bin": "2b"}, ]) + stage2 = ExtractKeys(stage1, "*.txt", "*.bin", as_tuple=True) + output = list(iter(stage2)) + self.assertEqual(output, [("1", "1b"), ("2", "2b")]) stage2 = ExtractKeys(stage1, "*.txt", "*.bin") output = list(iter(stage2)) - assert len(output) == 2 - assert output[0][0] == "1" - assert output[0][1] == "1b" + self.assertEqual(output, [ + {"1.txt": "1", "1.bin": "1b"}, + {"2.txt": "2", "2.bin": "2b"}, + ]) + with self.assertRaisesRegex(ValueError, r"(?i)multiple sample keys"): + stage2 = ExtractKeys(stage1, "*") + output = list(iter(stage2)) + with self.assertRaisesRegex(ValueError, r"selected twice"): + stage2 = ExtractKeys(stage1, "*.txt", "*t") + output = list(iter(stage2)) + def test_zip_longest_iterdatapipe(self): diff --git a/torchdata/datapipes/iter/util/extractkeys.py b/torchdata/datapipes/iter/util/extractkeys.py index bbf0cafe5..c5d0f4f87 100644 --- a/torchdata/datapipes/iter/util/extractkeys.py +++ b/torchdata/datapipes/iter/util/extractkeys.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. from fnmatch import fnmatch -from typing import Dict, Iterator, Tuple +from typing import Dict, Iterator, Tuple, Union from torchdata.datapipes import functional_datapipe from torchdata.datapipes.iter import IterDataPipe @@ -30,17 +30,19 @@ class ExtractKeysIterDataPipe(IterDataPipe[Dict]): """ def __init__( - self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error=True, ignore_missing=False + self, source_datapipe: IterDataPipe[Dict], *args, duplicate_is_error=True, ignore_missing=False, as_tuple=False ) -> None: super().__init__() self.source_datapipe: IterDataPipe[Dict] = source_datapipe self.duplicate_is_error = duplicate_is_error self.patterns = args self.ignore_missing = ignore_missing + self.as_tuple = as_tuple - def __iter__(self) -> Iterator[Tuple]: + def __iter__(self) -> Union[Iterator[Tuple], Iterator[Dict]]: for sample in self.source_datapipe: result = [] + used = set() for pattern in self.patterns: pattern = [pattern] if not isinstance(pattern, (list, tuple)) else pattern matches = [x for x in sample.keys() if any(fnmatch(x, p) for p in pattern)] @@ -48,12 +50,22 @@ def __iter__(self) -> Iterator[Tuple]: if self.ignore_missing: continue else: - raise ValueError(f"Cannot find {pattern} in sample keys {sample.keys()}.") + raise ValueError(f"extract_keys: cannot find {pattern} in sample keys {sample.keys()}.") if len(matches) > 1 and self.duplicate_is_error: - raise ValueError(f"Multiple sample keys {sample.keys()} match {pattern}.") + raise ValueError(f"extract_keys: multiple sample keys {sample.keys()} match {pattern}.") + if matches[0] in used and self.duplicate_is_error: + raise ValueError(f"extract_keys: key {matches[0]} is selected twice.") + used.add(matches[0]) value = sample[matches[0]] - result.append(value) - yield tuple(result) + if self.as_tuple: + result.append(value) + else: + result.append((matches[0], value)) + if self.as_tuple: + result = tuple(result) + else: + result = {k: v for k, v in result} + yield result def __len__(self) -> int: return len(self.source_datapipe) From b31d72123aa86062486e964947227dfca32ce12e Mon Sep 17 00:00:00 2001 From: Tom Date: Wed, 31 Aug 2022 15:28:21 -0700 Subject: [PATCH 4/5] fixed type errors --- torchdata/datapipes/iter/util/extractkeys.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torchdata/datapipes/iter/util/extractkeys.py b/torchdata/datapipes/iter/util/extractkeys.py index c5d0f4f87..b675f1ec6 100644 --- a/torchdata/datapipes/iter/util/extractkeys.py +++ b/torchdata/datapipes/iter/util/extractkeys.py @@ -39,7 +39,7 @@ def __init__( self.ignore_missing = ignore_missing self.as_tuple = as_tuple - def __iter__(self) -> Union[Iterator[Tuple], Iterator[Dict]]: + def __iter__(self) -> Union[Iterator[Tuple], Iterator[Dict]]: # type: ignore for sample in self.source_datapipe: result = [] used = set() @@ -62,10 +62,9 @@ def __iter__(self) -> Union[Iterator[Tuple], Iterator[Dict]]: else: result.append((matches[0], value)) if self.as_tuple: - result = tuple(result) + yield tuple(result) else: - result = {k: v for k, v in result} - yield result + yield {k: v for k, v in result} def __len__(self) -> int: return len(self.source_datapipe) From ba9b5a47cd24d4a8e235925350a0b706f0ba50c1 Mon Sep 17 00:00:00 2001 From: Tom Date: Wed, 31 Aug 2022 17:00:02 -0700 Subject: [PATCH 5/5] improved documentation in extract_keys --- torchdata/datapipes/iter/util/extractkeys.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torchdata/datapipes/iter/util/extractkeys.py b/torchdata/datapipes/iter/util/extractkeys.py index b675f1ec6..db4f93092 100644 --- a/torchdata/datapipes/iter/util/extractkeys.py +++ b/torchdata/datapipes/iter/util/extractkeys.py @@ -21,6 +21,9 @@ class ExtractKeysIterDataPipe(IterDataPipe[Dict]): duplicate_is_error: it is an error if the same key is selected twice (True) ignore_missing: skip any dictionaries where one or more patterns don't match (False) *args: list of glob patterns or list of glob patterns + duplicate_is_error: it is an error if the same key is selected twice (True) + ignore_missing: allow patterns not to match (i.e., incomplete outputs) + as_tuple: return a tuple instead of a dictionary Returns: a DataPipe yielding a stream of tuples