5 changes: 4 additions & 1 deletion finetrainers/data/dataloader.py
@@ -18,8 +18,11 @@ def __init__(
         batch_size: int = 1,
         num_workers: int = 0,
         collate_fn=None,
+        drop_last: bool = False,
     ) -> None:
-        super().__init__(dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)
+        super().__init__(
+            dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn, drop_last=drop_last
+        )

         self._dp_rank = rank
         self._rank_id = f"dp_rank_{rank}"
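For context, the new `drop_last` flag is simply forwarded to `torch.utils.data.DataLoader`, which discards a trailing partial batch. A minimal standalone sketch (plain `DataLoader`, not the `DPDataLoader` above) of the effect:

```python
import torch

# 10 samples with batch_size=4: the final batch would only contain 2 samples.
data = list(range(10))

keep_partial = torch.utils.data.DataLoader(data, batch_size=4, drop_last=False)
drop_partial = torch.utils.data.DataLoader(data, batch_size=4, drop_last=True)

print([len(b) for b in keep_partial])  # [4, 4, 2]
print([len(b) for b in drop_partial])  # [4, 4] -- the partial batch is discarded
```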
8 changes: 7 additions & 1 deletion finetrainers/data/dataset.py
@@ -1,3 +1,4 @@
+import itertools
 import pathlib
 import random
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -715,7 +716,12 @@ def __init__(

     def __iter__(self):
         logger.info("Starting IterableDatasetPreprocessingWrapper for the dataset")
-        for sample in iter(self.dataset):
+        worker_info = torch.utils.data.get_worker_info()
+        iterator = iter(self.dataset)
+        if worker_info is not None:
+            # Inside a DataLoader worker process, take an interleaved slice so each worker yields distinct samples.
+            iterator = itertools.islice(iterator, worker_info.id, None, worker_info.num_workers)
+        for sample in iterator:
             for column in self.drop_columns:
                 sample.pop(column, None)

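To illustrate the sharding pattern used above: `itertools.islice(iterator, worker_id, None, num_workers)` gives each worker an interleaved, disjoint stride of the stream. A small standalone sketch (hypothetical numbers, independent of finetrainers):

```python
import itertools

def worker_shard(samples, worker_id, num_workers):
    # Start at worker_id and step by num_workers, mirroring the wrapper's islice call.
    return list(itertools.islice(iter(samples), worker_id, None, num_workers))

samples = list(range(10))
print([worker_shard(samples, w, 3) for w in range(3)])
# [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]] -- disjoint, covers every sample, but shard sizes differ
```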
13 changes: 12 additions & 1 deletion finetrainers/parallel/accelerate.py
@@ -136,8 +136,19 @@ def prepare_dataloader(
         num_workers: int = 0,
         pin_memory: bool = False,
     ) -> DataLoader:
+        drop_last = False
+        if isinstance(dataset, torch.utils.data.IterableDataset) and num_workers > 1:
+            drop_last = True
+            logger.info(
+                "Using `drop_last=True` for IterableDataset with multiple workers to ensure consistent batch sizes."
+            )
+
         dataloader = torch.utils.data.DataLoader(
-            dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory
+            dataset,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
         )
         dataloader = self._accelerator.prepare_data_loader(dataloader)
         logger.debug("AccelerateParallelBackend::prepare_dataloader completed!")
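Why `drop_last` is enabled only for an `IterableDataset` with `num_workers > 1`: the interleaved shards generally have different lengths, so each worker's last batch can be smaller than `batch_size`. A rough back-of-the-envelope check (hypothetical numbers, not taken from the PR):

```python
# 101 samples split across 4 workers with an interleaved slice -> shard sizes 26, 25, 25, 25.
num_samples, num_workers, batch_size = 101, 4, 4
shard_sizes = [len(range(w, num_samples, num_workers)) for w in range(num_workers)]
print(shard_sizes)                             # [26, 25, 25, 25]
print([s % batch_size for s in shard_sizes])   # [2, 1, 1, 1] -- ragged tails that drop_last removes
```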
12 changes: 11 additions & 1 deletion finetrainers/parallel/ptd.py
@@ -150,7 +150,17 @@ def prepare_dataloader(
         else:
             dp_mesh = self.get_mesh()["dp_replicate"]
             dp_local_rank = dp_mesh.get_local_rank()
-        dataloader = DPDataLoader(dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers)
+
+        drop_last = False
+        if isinstance(dataset, torch.utils.data.IterableDataset) and num_workers > 1:
+            drop_last = True
+            logger.info(
+                "Using `drop_last=True` for IterableDataset with multiple workers to ensure consistent batch sizes."
+            )
+
+        dataloader = DPDataLoader(
+            dp_local_rank, dataset, batch_size=batch_size, num_workers=num_workers, drop_last=drop_last
+        )
         logger.debug("PytorchDTensorParallelBackend::prepare_dataloader completed!")
         return dataloader

71 changes: 71 additions & 0 deletions tests/data/test_iterable_dataset_multi_workers.py
@@ -0,0 +1,71 @@
import unittest

import torch

from finetrainers.data.dataset import IterableDatasetPreprocessingWrapper


class DummyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, num_samples=100):
        super().__init__()
        self.num_samples = num_samples

    def __iter__(self):
        for i in range(self.num_samples):
            yield {"caption": f"caption_{i}", "image": i}


class TestIterableDatasetMultiWorker(unittest.TestCase):
    def test_no_duplication_with_multiple_workers(self):
        """
        Tests that IterableDatasetPreprocessingWrapper correctly shards data and
        handles the drop_last logic, by directly comparing the loaded items
        to a manually simulated expected set of items.
        """
        num_samples = 101  # Not perfectly divisible by batch_size to test drop_last
        batch_size = 4

        for num_workers in range(1, 9):
            with self.subTest(num_workers=num_workers):
                drop_last = num_workers > 1

                original_dataset = DummyIterableDataset(num_samples)
                original_items = [item["image"] for item in original_dataset]

                wrapped_dataset = IterableDatasetPreprocessingWrapper(
                    dataset=original_dataset,
                    dataset_type="image",
                )

                dataloader = torch.utils.data.DataLoader(
                    wrapped_dataset,
                    batch_size=batch_size,
                    num_workers=num_workers,
                    drop_last=drop_last,
                )

                loaded_items = [item for batch in dataloader for item in batch["image"].tolist()]

                # Manually simulate the sharding and drop_last logic to get the exact expected set of items.
                expected_items = []
                if drop_last:
                    for worker_id in range(num_workers):
                        # 1. Simulate the interleaved sharding from itertools.islice
                        worker_items = original_items[worker_id::num_workers]
                        # 2. Simulate the drop_last logic for this worker's items
                        num_full_batches = len(worker_items) // batch_size
                        items_to_keep_for_worker = worker_items[: num_full_batches * batch_size]
                        expected_items.extend(items_to_keep_for_worker)
                else:  # This case is for num_workers == 1
                    expected_items = original_items

                # Compare loaded vs. expected items: catches duplicates, missing samples, and drop_last errors.
                self.assertCountEqual(
                    loaded_items,
                    expected_items,
                    f"Loaded items do not match the expected items for {num_workers} workers.",
                )


if __name__ == "__main__":
    unittest.main()