-
Notifications
You must be signed in to change notification settings - Fork 169
[DataPipe] key renamer #402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
MapKeyZipper, | ||
MaxTokenBucketizer, | ||
ParagraphAggregator, | ||
RenameKeys, | ||
Rows2Columnar, | ||
SampleMultiplexer, | ||
UnZipper, | ||
|
@@ -902,6 +903,18 @@ def test_mux_longest_iterdatapipe(self): | |
with self.assertRaises(TypeError): | ||
len(output_dp) | ||
|
||
def test_renamer(self): | ||
|
||
# Functional Test: verify that renaming by patterns yields correct output | ||
stage1 = IterableWrapper([ | ||
{"1.txt": "1", "1.bin": "1b"}, | ||
{"2.txt": "2", "2.bin": "2b"}, | ||
]) | ||
stage2 = RenameKeys(stage1, t="*.txt", b="*.bin") | ||
output = list(iter(stage2)) | ||
assert len(output) == 2 | ||
assert set(output[0].keys()) == set(["t", "b"]) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please test other boolean flags. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
def test_zip_longest_iterdatapipe(self): | ||
|
||
# Functional Test: raises TypeError when an input is not of type `IterDataPipe` | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,79 @@ | ||||||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||||||
# All rights reserved. | ||||||
# | ||||||
# This source code is licensed under the BSD-style license found in the | ||||||
# LICENSE file in the root directory of this source tree. | ||||||
|
||||||
import re | ||||||
from fnmatch import fnmatch | ||||||
from typing import Dict, Iterator, List, Union | ||||||
|
||||||
from torchdata.datapipes import functional_datapipe | ||||||
from torchdata.datapipes.iter import IterDataPipe | ||||||
|
||||||
|
||||||
@functional_datapipe("rename_keys") | ||||||
class RenameKeysIterDataPipe(IterDataPipe[Dict]): | ||||||
r""" | ||||||
Given a stream of dictionaries, rename keys using glob patterns. | ||||||
Args: | ||||||
source_datapipe: a DataPipe yielding a stream of dictionaries. | ||||||
keep_unselected: keep keys/value pairs even if they don't match any pattern (False) | ||||||
must_match: all key value pairs must match (True) | ||||||
duplicate_is_error: it is an error if two renamings yield the same key (True) | ||||||
Comment on lines
+27
to
+29
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Should we move these after There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
*args: `(renamed, pattern)` pairs | ||||||
**kw: `renamed=pattern` pairs | ||||||
Returns: | ||||||
a DataPipe yielding a stream of dictionaries. | ||||||
Examples: | ||||||
>>> dp = IterableWrapper([{"/a/b.jpg": b"data"}]).rename_keys(image="*.jpg") | ||||||
""" | ||||||
tmbdev marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
|
||||||
def __init__( | ||||||
self, | ||||||
source_datapipe: IterDataPipe[List[Union[Dict, List]]], | ||||||
|
||||||
*args, | ||||||
keep_unselected=False, | ||||||
must_match=True, | ||||||
duplicate_is_error=True, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
nit: might be a better name but feel free to ignore |
||||||
**kw, | ||||||
) -> None: | ||||||
super().__init__() | ||||||
self.source_datapipe: IterDataPipe[List[Union[Dict, List]]] = source_datapipe | ||||||
tmbdev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
self.must_match = must_match | ||||||
self.keep_unselected = keep_unselected | ||||||
self.duplicate_is_error = duplicate_is_error | ||||||
self.renamings = [(pattern, output) for output, pattern in args] | ||||||
self.renamings += [(pattern, output) for output, pattern in kw.items()] | ||||||
|
||||||
def __iter__(self) -> Iterator[Dict]: | ||||||
for sample in self.source_datapipe: | ||||||
new_sample = {} | ||||||
matched = {k: False for k, _ in self.renamings} | ||||||
for path, value in sample.items(): | ||||||
fname = re.sub(r".*/", "", path) | ||||||
tmbdev marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
new_name = None | ||||||
for pattern, name in self.renamings[::-1]: | ||||||
if fnmatch(fname.lower(), pattern): | ||||||
matched[pattern] = True | ||||||
new_name = name | ||||||
break | ||||||
if new_name is None: | ||||||
if self.keep_unselected: | ||||||
new_sample[path] = value | ||||||
continue | ||||||
if new_name in new_sample: | ||||||
if self.duplicate_is_error: | ||||||
raise ValueError(f"Duplicate value in sample {sample.keys()} after rename.") | ||||||
continue | ||||||
new_sample[new_name] = value | ||||||
if self.must_match and not all(matched.values()): | ||||||
raise ValueError(f"Not all patterns ({matched}) matched sample keys ({sample.keys()}).") | ||||||
|
||||||
yield new_sample | ||||||
|
||||||
def __len__(self) -> int: | ||||||
return len(self.source_datapipe) |
Uh oh!
There was an error while loading. Please reload this page.