Skip to content

Commit cd3ce34

Browse files
authored
Fix when map function modifies input inplace (#4174)
1 parent e4bfe27 commit cd3ce34

File tree

2 files changed

+22
-1
lines changed

2 files changed

+22
-1
lines changed

src/datasets/arrow_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2234,7 +2234,9 @@ def apply_function_on_filtered_inputs(inputs, indices, check_same_num_examples=F
22342234
)
22352235
if remove_columns is not None:
22362236
for column in remove_columns:
2237-
inputs.pop(column)
2237+
# `function` can modify input in-place causing column to be already removed.
2238+
if column in inputs:
2239+
inputs.pop(column)
22382240
if check_same_num_examples:
22392241
input_num_examples = len(inputs[next(iter(inputs.keys()))])
22402242
processed_inputs_num_examples = len(processed_inputs[next(iter(processed_inputs.keys()))])

tests/test_arrow_dataset.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,25 @@ def map_batched_with_indices(example, idx):
11391139
)
11401140
assert_arrow_metadata_are_synced_with_dataset_features(dset_test_with_indices_batched)
11411141

1142+
# check remove columns for even if the function modifies input in-place
1143+
def map_batched_modifying_inputs_inplace(example):
1144+
result = {"filename_new": [x + "_extension" for x in example["filename"]]}
1145+
del example["filename"]
1146+
return result
1147+
1148+
with tempfile.TemporaryDirectory() as tmp_dir:
1149+
with self._create_dummy_dataset(in_memory, tmp_dir) as dset:
1150+
with dset.map(
1151+
map_batched_modifying_inputs_inplace, batched=True, remove_columns="filename"
1152+
) as dset_test_modifying_inputs_inplace:
1153+
self.assertEqual(len(dset_test_modifying_inputs_inplace), 30)
1154+
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
1155+
self.assertDictEqual(
1156+
dset_test_modifying_inputs_inplace.features,
1157+
Features({"filename_new": Value("string")}),
1158+
)
1159+
assert_arrow_metadata_are_synced_with_dataset_features(dset_test_modifying_inputs_inplace)
1160+
11421161
def test_map_nested(self, in_memory):
11431162
with tempfile.TemporaryDirectory() as tmp_dir:
11441163
with Dataset.from_dict({"field": ["a", "b"]}) as dset:

0 commit comments

Comments
 (0)