
Commit 6e4a9a5

Fix eval thread fork bomb (#39717)
1 parent 98a3c49 commit 6e4a9a5
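
Context (not part of the commit page itself): with dataloader_persistent_workers=True, every DataLoader that gets constructed keeps its worker processes alive for as long as the loader object exists, so building or re-preparing a fresh loader on every evaluation pass accumulates worker processes instead of reusing one pool. The sketch below is illustrative only, not the Trainer code path; ToyDataset and the loop are made-up stand-ins.

# Illustrative only: repeatedly constructing persistent-worker DataLoaders
# accumulates live worker processes (the "fork bomb" pattern described in the
# commit title). This is not the actual Trainer code path.
import multiprocessing as mp

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):  # hypothetical stand-in dataset
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return idx


if __name__ == "__main__":
    kept_alive = []
    for step in range(3):  # pretend each iteration is one evaluation pass
        loader = DataLoader(ToyDataset(), num_workers=2, persistent_workers=True)
        for _ in loader:  # first iteration spawns the persistent workers
            pass
        kept_alive.append(loader)  # holding a reference keeps its workers running
        print(f"after eval #{step}: {len(mp.active_children())} live child processes")
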

File tree

1 file changed: +4, -5 lines

src/transformers/trainer.py

Lines changed: 4 additions & 5 deletions
@@ -1034,17 +1034,16 @@ def _get_dataloader(
                 seed_worker, num_workers=self.args.dataloader_num_workers, rank=self.args.process_index
             )

-        dataloader = DataLoader(dataset, **dataloader_params)
+        dataloader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))

-        # Accelerator.free_memory() will destroy the references, so
-        # we need to store the non-prepared version for eval dataloaders.
+        # Store the prepared dataloader for subsequent evaluations if using persistent workers.
         if dataloader_key is not None and self.args.dataloader_persistent_workers:
             if hasattr(self, "_eval_dataloaders"):
                 self._eval_dataloaders[dataloader_key] = dataloader
             else:
                 self._eval_dataloaders = {dataloader_key: dataloader}

-        return self.accelerator.prepare(dataloader)
+        return dataloader

     def get_train_dataloader(self) -> DataLoader:
         """
@@ -1132,7 +1131,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Union[str, Dataset]] = None
             and dataloader_key in self._eval_dataloaders
             and self.args.dataloader_persistent_workers
         ):
-            return self.accelerator.prepare(self._eval_dataloaders[dataloader_key])
+            return self._eval_dataloaders[dataloader_key]

         eval_dataset = (
             self.eval_dataset[eval_dataset]
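
In short, the change prepares the evaluation DataLoader once, caches the prepared object, and returns that cached object on later evaluations instead of calling accelerator.prepare again each time. A minimal sketch of that caching pattern, assuming a hypothetical EvalLoaderCache helper and an identity prepare_fn standing in for accelerator.prepare:

# Minimal sketch of the fixed pattern, not the Trainer implementation:
# prepare the eval DataLoader once, cache it, and hand back the same object
# (and the same persistent worker pool) on every subsequent evaluation.
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):  # hypothetical stand-in dataset
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx


class EvalLoaderCache:  # hypothetical helper mirroring the commit's strategy
    def __init__(self, prepare_fn):
        self._prepare_fn = prepare_fn  # e.g. accelerator.prepare in the real code
        self._loaders = {}

    def get(self, key, dataset, **dataloader_params):
        if key in self._loaders:  # reuse the prepared loader and its workers
            return self._loaders[key]
        loader = self._prepare_fn(DataLoader(dataset, **dataloader_params))
        self._loaders[key] = loader
        return loader


if __name__ == "__main__":
    cache = EvalLoaderCache(prepare_fn=lambda dl: dl)  # identity stands in for prepare
    dataset = ToyDataset()
    for _ in range(3):  # three evaluation passes reuse one prepared loader
        loader = cache.get("eval", dataset, num_workers=2, persistent_workers=True)
        for _batch in loader:
            pass

Because the same prepared loader, and therefore the same persistent worker pool, is handed back each time, repeated evaluations no longer spawn additional workers.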
