Skip to content

Commit 9e88687

Browse files
vttrifonovlhoestq
andauthored
changes to MappedExamplesIterable to resolve #7345 (#7353)
* changes to MappedExamplesIterable to resolve #7345 * changed MappedExamplesIterable and added test_iterable_dataset_vs_dataset * test_iterable_dataset * test_iterable_dataset * Update src/datasets/iterable_dataset.py Co-authored-by: Quentin Lhoest <[email protected]> * Update src/datasets/iterable_dataset.py Co-authored-by: Quentin Lhoest <[email protected]> * ruff happy --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 6457be6 commit 9e88687

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

src/datasets/iterable_dataset.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _examples_to_batch(examples: List[Dict[str, Any]]) -> Dict[str, list]:
7979

8080
def _batch_to_examples(batch: Dict[str, list]) -> Iterator[Dict[str, Any]]:
8181
"""Convert a batch (dict of examples) to examples list"""
82-
n_examples = len(batch[next(iter(batch))])
82+
n_examples = 0 if len(batch) == 0 else len(batch[next(iter(batch))])
8383
for i in range(n_examples):
8484
yield {col: array[i] for col, array in batch.items()}
8585

@@ -1044,12 +1044,16 @@ def _iter(self):
10441044
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
10451045
if self.with_indices:
10461046
function_args.append([current_idx + i for i in range(len(key_examples_list))])
1047-
transformed_batch = dict(batch) # this will be updated with the function output
1048-
transformed_batch.update(self.function(*function_args, **self.fn_kwargs))
1049-
# then remove the unwanted columns
1047+
inputs_to_merge = dict(batch)
1048+
processed_inputs = self.function(*function_args, **self.fn_kwargs)
1049+
# this logic mimics the one in Dataset.map
10501050
if self.remove_columns:
10511051
for c in self.remove_columns:
1052-
del transformed_batch[c]
1052+
if c in inputs_to_merge:
1053+
del inputs_to_merge[c]
1054+
if processed_inputs is inputs and c in processed_inputs:
1055+
del processed_inputs[c]
1056+
transformed_batch = {**inputs_to_merge, **processed_inputs}
10531057
if transformed_batch:
10541058
first_col = next(iter(transformed_batch))
10551059
bad_cols = [
@@ -1087,12 +1091,16 @@ def _iter(self):
10871091
function_args = [inputs] if self.input_columns is None else [inputs[col] for col in self.input_columns]
10881092
if self.with_indices:
10891093
function_args.append(current_idx)
1090-
transformed_example = dict(example) # this will be updated with the function output
1091-
transformed_example.update(self.function(*function_args, **self.fn_kwargs))
1092-
# then we remove the unwanted columns
1094+
processed_inputs = self.function(*function_args, **self.fn_kwargs)
1095+
inputs_to_merge = dict(example)
1096+
# this logic mimics the one in Dataset.map
10931097
if self.remove_columns:
10941098
for c in self.remove_columns:
1095-
del transformed_example[c]
1099+
if c in inputs_to_merge:
1100+
del inputs_to_merge[c]
1101+
if processed_inputs is inputs and c in processed_inputs:
1102+
del processed_inputs[c]
1103+
transformed_example = {**inputs_to_merge, **processed_inputs}
10961104
current_idx += 1
10971105
if self._state_dict:
10981106
self._state_dict["previous_state_example_idx"] += 1

tests/test_iterable_dataset.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,59 @@ def test_mapped_examples_iterable_remove_columns(n, func, batched, batch_size, r
585585
assert_load_state_dict_resumes_iteration(ex_iterable)
586586

587587

588+
# issue #7345 and PR #7353
589+
@pytest.mark.parametrize("batched", [False, True])
590+
@pytest.mark.parametrize("batch_size", [None, 2])
591+
@pytest.mark.parametrize("input_columns", [None, ["i"]])
592+
@pytest.mark.parametrize("remove_columns", [None, ["i"]])
593+
@pytest.mark.parametrize("new_output", [False, True])
594+
def test_iterable_dataset_vs_dataset_map(batched, batch_size, input_columns, remove_columns, new_output):
595+
if input_columns is not None and not new_output:
596+
return
597+
598+
ds1 = Dataset.from_list([{"i": i} for i in range(4)])
599+
600+
if batched:
601+
602+
def f1(i):
603+
return {"i": [j + 1 for j in i]}
604+
else:
605+
606+
def f1(i):
607+
return {"i": i + 1}
608+
609+
if input_columns is None:
610+
611+
def f2(x):
612+
return f1(x["i"])
613+
else:
614+
f2 = f1
615+
616+
if new_output:
617+
f = f2
618+
else:
619+
620+
def f(x):
621+
x["i"] = f2(x)["i"]
622+
return x
623+
624+
r = [
625+
list(
626+
ds2.map(
627+
f,
628+
batch_size=batch_size,
629+
batched=batched,
630+
remove_columns=remove_columns,
631+
input_columns=input_columns,
632+
)
633+
)
634+
for ds2 in [ds1, ds1.to_iterable_dataset()]
635+
]
636+
r[1] = [x for x in r[1] if len(x) > 0]
637+
assert len(r[0]) == len(r[1])
638+
assert all(x == y for x, y in zip(*r))
639+
640+
588641
@pytest.mark.parametrize(
589642
"n, func, batched, batch_size, fn_kwargs",
590643
[

0 commit comments

Comments
 (0)