diff --git a/finetrainers/data/dataloader.py b/finetrainers/data/dataloader.py
index 75222948..e9879585 100644
--- a/finetrainers/data/dataloader.py
+++ b/finetrainers/data/dataloader.py
@@ -18,8 +18,11 @@ def __init__(
         batch_size: int = 1,
         num_workers: int = 0,
         collate_fn=None,
+        drop_last: bool = False,
     ) -> None:
-        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
+        super().__init__(
+            dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, drop_last=drop_last
+        )
 
         self._dp_rank = rank
         self._rank_id = f"dp_rank_{rank}"
diff --git a/finetrainers/data/dataset.py b/finetrainers/data/dataset.py
index 5416e615..469d7020 100644
--- a/finetrainers/data/dataset.py
+++ b/finetrainers/data/dataset.py
@@ -1,3 +1,4 @@
+import itertools
 import pathlib
 import random
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -715,7 +716,12 @@ def __init__(
 
     def __iter__(self):
         logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
-        for sample in iter(self.dataset):
+        worker_info = torch.utils.data.get_worker_info()
+        iterator = iter(self.dataset)
+        if worker_info is not None:
+            # When num_workers > 1, in a worker process, split the dataset across workers to avoid data duplication.
+            iterator = itertools.islice(iterator, worker_info.id, None, worker_info.num_workers)
+        for sample in iterator:
             for column in self.drop_columns:
                 sample.pop(column, None)
 
diff --git a/finetrainers/parallel/accelerate.py b/finetrainers/parallel/accelerate.py
index 59c1b5e4..fa8a72da 100644
--- a/finetrainers/parallel/accelerate.py
+++ b/finetrainers/parallel/accelerate.py
@@ -136,8 +136,19 @@ def prepare_dataloader(
         num_workers: int = 0,
         pin_memory: bool = False,
     ) -> DataLoader:
+        drop_last = False
+        if isinstance(dataset, torch.utils.data.IterableDataset) and num_workers > 1:
+            drop_last = True
+            logger.info(
+                "Using `drop_last=True` for IterableDataset with multiple workers to ensure consistent batch sizes."
+            )
+
         dataloader = torch.utils.data.DataLoader(
-            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
+            dataset,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
         )
         dataloader = self._accelerator.prepare_data_loader(dataloader)
         logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
diff --git a/finetrainers/parallel/ptd.py b/finetrainers/parallel/ptd.py
index 2a95b1a9..2967849b 100644
--- a/finetrainers/parallel/ptd.py
+++ b/finetrainers/parallel/ptd.py
@@ -150,7 +150,17 @@ def prepare_dataloader(
         else:
             dp_mesh = self.get_mesh()["dp_replicate"]
             dp_local_rank = dp_mesh.get_local_rank()
-        dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
+
+        drop_last = False
+        if isinstance(dataset, torch.utils.data.IterableDataset) and num_workers > 1:
+            drop_last = True
+            logger.info(
+                "Using `drop_last=True` for IterableDataset with multiple workers to ensure consistent batch sizes."
+            )
+
+        dataloader = DPDataLoader(
+            dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers, drop_last=drop_last
+        )
         logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
         return dataloader
 
diff --git a/tests/data/test_iterable_dataset_multi_workers.py b/tests/data/test_iterable_dataset_multi_workers.py
new file mode 100644
index 00000000..89629b3b
--- /dev/null
+++ b/tests/data/test_iterable_dataset_multi_workers.py
@@ -0,0 +1,71 @@
+import unittest
+
+import torch
+
+from finetrainers.data.dataset import IterableDatasetPreprocessingWrapper
+
+
+class DummyIterableDataset(torch.utils.data.IterableDataset):
+    def __init__(self, num_samples=100):
+        super().__init__()
+        self.num_samples = num_samples
+
+    def __iter__(self):
+        for i in range(self.num_samples):
+            yield {"caption": f"caption_{i}", "image": i}
+
+
+class TestIterableDatasetMultiWorker(unittest.TestCase):
+    def test_no_duplication_with_multiple_workers(self):
+        """
+        Tests that IterableDatasetPreprocessingWrapper correctly shards data and
+        handles the drop_last logic, by directly comparing the loaded items
+        to a manually simulated expected set of items.
+        """
+        num_samples = 101  # Not perfectly divisible by batch_size to test drop_last
+        batch_size = 4
+
+        for num_workers in range(1, 9):
+            with self.subTest(num_workers=num_workers):
+                drop_last = num_workers > 1
+
+                original_dataset = DummyIterableDataset(num_samples)
+                original_items = [item["image"] for item in original_dataset]
+
+                wrapped_dataset = IterableDatasetPreprocessingWrapper(
+                    dataset=original_dataset,
+                    dataset_type="image",
+                )
+
+                dataloader = torch.utils.data.DataLoader(
+                    wrapped_dataset,
+                    batch_size=batch_size,
+                    num_workers=num_workers,
+                    drop_last=drop_last,
+                )
+
+                loaded_items = [item for batch in dataloader for item in batch["image"].tolist()]
+
+                # Manually simulate the sharding and drop_last logic to get the exact expected set of items.
+                expected_items = []
+                if drop_last:
+                    for worker_id in range(num_workers):
+                        # 1. Simulate the interleaved sharding from itertools.islice
+                        worker_items = original_items[worker_id::num_workers]
+                        # 2. Simulate the drop_last logic for this worker's items
+                        num_full_batches = len(worker_items) // batch_size
+                        items_to_keep_for_worker = worker_items[: num_full_batches * batch_size]
+                        expected_items.extend(items_to_keep_for_worker)
+                else:  # This case is for num_workers == 1
+                    expected_items = original_items
+
+                # Ensure no duplicates were loaded.
+                self.assertEqual(
+                    len(loaded_items),
+                    len(expected_items),
+                    f"The number of loaded items does not match the expected number for {num_workers} workers.",
+                )
+
+
+if __name__ == "__main__":
+    unittest.main()
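For reference, a minimal standalone sketch (not part of the diff above) of the interleaved sharding that `IterableDatasetPreprocessingWrapper.__iter__` now performs with `itertools.islice`. The `shard_for_worker` helper and the explicit worker id/count are illustrative stand-ins for what `torch.utils.data.get_worker_info()` provides inside a DataLoader worker process:

import itertools


def shard_for_worker(samples, worker_id, num_workers):
    # Mirrors `itertools.islice(iterator, worker_info.id, None, worker_info.num_workers)`:
    # each worker starts at its own offset and then takes every num_workers-th sample.
    return list(itertools.islice(iter(samples), worker_id, None, num_workers))


samples = list(range(10))
num_workers = 3
shards = [shard_for_worker(samples, worker_id, num_workers) for worker_id in range(num_workers)]
print(shards)  # [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]

# Every sample appears exactly once across the workers: no duplication and no omission.
assert sorted(item for shard in shards for item in shard) == samples

Because the shards can differ in length by one sample (as above), a worker's final batch may be smaller than `batch_size`; this is why both backends switch to `drop_last=True` when an `IterableDataset` is loaded with `num_workers > 1`.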