
Commit fa2020e

feat(streaming): enable per-dataset batch-sizes in CombinedStreamingDataset (Lightning-AI#635)
* feat(streaming): per-dataset batch-size support in CombinedStreamingDataset (Lightning-AI#327)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* feat(streaming): add per-dataset batch-size support and fix mypy issues
* fix(streaming): always switch dataset once per-stream quota is met
* chore(typing): align batch_size annotation with Union[int, Sequence[int]]
* fix(typing): ensure int batch_size passed to get_len for mypy
* chore(typing): remove redundant casts flagged by mypy
* style(ruff): replace typing.List/Dict with built-in generics
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: fc59c8a

4 files changed: +147 additions, -17 deletions
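For orientation before the diffs, here is a minimal usage sketch of what this commit enables. The dataset locations, weights, and the batching_method argument are illustrative assumptions rather than values taken from this commit; the new behaviour is set_batch_size accepting either a single int or one batch size per wrapped dataset.

from litdata import CombinedStreamingDataset, StreamingDataset

# Hypothetical dataset locations; replace with real optimized datasets.
dataset_a = StreamingDataset(input_dir="s3://my-bucket/dataset-a")
dataset_b = StreamingDataset(input_dir="s3://my-bucket/dataset-b")

combined = CombinedStreamingDataset(
    datasets=[dataset_a, dataset_b],
    weights=[0.5, 0.5],
    batching_method="per_stream",  # stay on one stream until its quota is met
)

# New in this commit: a sequence gives each wrapped dataset its own batch size;
# a plain int keeps the previous behaviour of one value for all datasets.
combined.set_batch_size([2, 6])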

src/litdata/streaming/combined.py

Lines changed: 39 additions & 7 deletions
@@ -15,7 +15,7 @@
 import random
 from collections.abc import Iterator, Sequence
 from copy import deepcopy
-from typing import Any, Literal, Optional
+from typing import Any, Literal, Optional, Union

 from litdata.debugger import ChromeTraceColors, _get_log_msg
 from litdata.streaming.dataset import StreamingDataset

@@ -170,7 +170,7 @@ def __init__(
         weights: Sequence[Optional[float]],
         use_streaming_dataloader: bool,
         num_samples_yielded: Any,
-        batch_size: int,
+        batch_size: Union[int, Sequence[int]],
         batching_method: BatchingMethodType,
         iterate_over_all: bool = False,
     ) -> None:

@@ -183,7 +183,14 @@ def __init__(
         self._rng = random.Random(seed)  # noqa: S311
         self._iterate_over_all = iterate_over_all
         self._batching_method = batching_method
+        # Batch size can be an int (applied to all datasets) or a sequence providing
+        # a specific batch size per dataset.
         self._batch_size = batch_size
+        from collections.abc import Sequence as _Sequence
+
+        # Validate when a sequence is provided
+        if isinstance(batch_size, _Sequence) and len(batch_size) != len(datasets):
+            raise ValueError("When providing a sequence of batch sizes, its length must match the number of datasets.")
         self._is_done = False

         if num_samples_yielded is not None:

@@ -196,9 +203,10 @@ def __init__(
         self._use_streaming_dataloader = use_streaming_dataloader
         self._is_done = False

-        # Used to track the number of samples yielded in the current batch
-        # and the current dataset index
-        # This is used only when batching_method is set to "per_stream"
+        # Track the number of samples yielded in the current (DataLoader) batch
+        # and the active dataset index. This is used only when batching_method is
+        # set to "per_stream". With per-dataset batch sizes we look up the limit
+        # dynamically based on ``self._batch_size`` if it is a sequence.
         self._samples_yielded_in_batch = 0
         self._cur_dataset_index = -1

@@ -240,11 +248,35 @@ def _get_dataset_index(self) -> int:
             # For every sample, randomly select a dataset (weighted)
             dataset_idx = self._set_new_dataset_index()
         elif self._batching_method == BatchingMethod.PER_STREAM:
-            # For each batch, pick a dataset and stick with it for the whole batch
-            if self._cur_dataset_index == -1 or self._samples_yielded_in_batch >= self._batch_size:
+            # For each batch, pick a dataset and stick with it until the
+            # desired number of samples for that dataset have been yielded.
+
+            from collections.abc import Sequence as _Sequence
+
+            if self._cur_dataset_index == -1:
+                # Start of iteration or after switching dataset
                 self._cur_dataset_index = self._set_new_dataset_index()
                 self._samples_yielded_in_batch = 0
+
             dataset_idx = self._cur_dataset_index
+
+            # Determine the batch-size limit for the current dataset
+            limit = self._batch_size[dataset_idx] if isinstance(self._batch_size, _Sequence) else self._batch_size
+
+            if self._samples_yielded_in_batch >= limit:
+                # Current dataset reached its quota; pick a *different* dataset if possible
+                candidate_idx = self._cur_dataset_index
+                if len([i for i in self._dataset_indexes if i is not None]) > 1:
+                    while candidate_idx == self._cur_dataset_index:
+                        candidate_idx = self._set_new_dataset_index()
+                # Update tracking
+                self._cur_dataset_index = candidate_idx
+                self._samples_yielded_in_batch = 0
+                dataset_idx = self._cur_dataset_index
+                # Re-compute limit for the new dataset
+                if isinstance(self._batch_size, _Sequence):
+                    limit = self._batch_size[dataset_idx]
+
             self._samples_yielded_in_batch += 1
         else:
             raise ValueError(f"Invalid batching method: {self._batching_method}")
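To make the new per_stream control flow easier to follow outside the diff, the sketch below re-implements the same quota logic in isolation, with a uniform random pick standing in for the weighted _set_new_dataset_index. It is a toy illustration under those assumptions, not library code.

import random


def simulate_per_stream(batch_sizes: list[int], num_samples: int, seed: int = 0) -> list[int]:
    """Return dataset indexes, staying on one dataset until its own limit is reached."""
    rng = random.Random(seed)
    cur = -1
    yielded_in_run = 0
    order: list[int] = []
    for _ in range(num_samples):
        if cur == -1:
            # Start of iteration: pick an initial dataset.
            cur = rng.randrange(len(batch_sizes))
            yielded_in_run = 0
        limit = batch_sizes[cur]
        if yielded_in_run >= limit:
            # Quota reached: switch to a different dataset when more than one exists.
            candidate = cur
            if len(batch_sizes) > 1:
                while candidate == cur:
                    candidate = rng.randrange(len(batch_sizes))
            cur = candidate
            yielded_in_run = 0
        order.append(cur)
        yielded_in_run += 1
    return order


# With batch_sizes=[2, 3], dataset 0 never appears more than 2 times in a row
# and dataset 1 never more than 3 times in a row.
print(simulate_per_stream([2, 3], 12))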

src/litdata/streaming/parallel.py

Lines changed: 7 additions & 1 deletion
@@ -250,7 +250,13 @@ def __iter__(self) -> Iterator[Any]:
         return self._iterator

     def __len__(self) -> Optional[int]:
-        return self.get_len(self.num_workers, self.batch_size if self.batch_size else 1)
+        # ``batch_size`` may be a sequence when per-dataset values were set on
+        # the wrapper. For length estimation we only need a scalar; we take
+        # the first element if a sequence is provided.
+        from collections.abc import Sequence
+
+        bs_int: int = int(self.batch_size[0]) if isinstance(self.batch_size, Sequence) else int(self.batch_size)
+        return self.get_len(self.num_workers, bs_int if bs_int else 1)

     def get_num_samples_yielded(
         self,
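The scalar fallback that __len__ now uses can be summarised in isolation; this is a standalone sketch mirroring the diff above, not a copy of the library code.

from collections.abc import Sequence
from typing import Union


def scalar_batch_size(batch_size: Union[int, Sequence[int]]) -> int:
    """First element of a per-dataset sequence, otherwise the int itself; 1 when falsy."""
    bs = int(batch_size[0]) if isinstance(batch_size, Sequence) else int(batch_size)
    return bs if bs else 1


assert scalar_batch_size([2, 6]) == 2
assert scalar_batch_size(8) == 8
assert scalar_batch_size(0) == 1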

src/litdata/utilities/base.py

Lines changed: 41 additions & 9 deletions
@@ -12,8 +12,8 @@
 # limitations under the License.

 from abc import ABC, abstractmethod
-from collections.abc import Iterator
-from typing import Any, Optional
+from collections.abc import Iterator, Sequence
+from typing import Any, Optional, Union

 from torch.utils.data import IterableDataset

@@ -30,7 +30,7 @@ class _BaseStreamingDatasetWrapper(IterableDataset, ABC):

     _datasets: list[StreamingDataset]
     _current_epoch: int
-    batch_size: int
+    batch_size: Union[int, Sequence[int]]
     num_workers: int
     _force_override_state_dict: bool
     _use_streaming_dataloader: bool

@@ -41,11 +41,31 @@ def set_shuffle(self, shuffle: bool) -> None:
         for dataset in self._datasets:
             dataset.set_shuffle(shuffle)

-    def set_batch_size(self, batch_size: int) -> None:
-        """Set the current batch size to the datasets."""
-        self.batch_size = batch_size
-        for dataset in self._datasets:
-            dataset.set_batch_size(batch_size)
+    def set_batch_size(self, batch_size: Union[int, Sequence[int]]) -> None:
+        """Set the current batch size.
+
+        This method now supports either:
+
+        1. a single ``int`` applied to all wrapped datasets (previous behaviour), or
+        2. a ``Sequence[int]`` that specifies one batch size per wrapped dataset.
+
+        The length of the sequence must match the number of wrapped datasets.
+        """
+        # Defer the import to avoid overhead when not required
+        from collections.abc import Sequence
+
+        self.batch_size = batch_size  # store as-is for later access
+
+        if isinstance(batch_size, Sequence):
+            if len(batch_size) != len(self._datasets):
+                raise ValueError(
+                    "The length of `batch_size` must match the number of datasets when passing a sequence."
+                )
+            for bs, dataset in zip(batch_size, self._datasets):
+                dataset.set_batch_size(bs)
+        else:
+            for dataset in self._datasets:
+                dataset.set_batch_size(int(batch_size))

     def set_num_workers(self, num_workers: int) -> None:
         """Set the current number of workers to the datasets."""

@@ -97,8 +117,20 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None:
             self._num_samples_yielded = state_dict["num_samples_yielded"]

     def _get_len(self, d: Any) -> int:
+        # mypy: ``self.batch_size`` can be a ``Sequence[int]`` now, but the
+        # underlying datasets still expect a plain ``int`` for their
+        # ``get_len`` signature. We pass an ``int`` in both cases and use the
+        # first element of the sequence when a per-dataset list is provided.
+
+        from collections.abc import Sequence
+
+        if isinstance(self.batch_size, Sequence):
+            bs_int: int = int(self.batch_size[0] if self.batch_size else 1)
+        else:
+            bs_int = int(self.batch_size)
+
         if isinstance(d, StreamingDataset):
-            return d.get_len(self.num_workers, self.batch_size)
+            return d.get_len(self.num_workers, bs_int)
         return len(d)

     @abstractmethod
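As a quick behavioural illustration of the new set_batch_size dispatch, the helper below (written only for this explanation, not part of the library) shows what each wrapped dataset would receive and when the validation error fires.

from collections.abc import Sequence
from typing import Union


def fan_out_batch_size(batch_size: Union[int, Sequence[int]], num_datasets: int) -> list[int]:
    """Return the per-dataset batch sizes the wrapper would forward to each dataset."""
    if isinstance(batch_size, Sequence):
        if len(batch_size) != num_datasets:
            raise ValueError("The length of `batch_size` must match the number of datasets.")
        return [int(bs) for bs in batch_size]
    return [int(batch_size)] * num_datasets


print(fan_out_batch_size(4, 2))       # [4, 4]: previous behaviour, one int for all
print(fan_out_batch_size([2, 6], 2))  # [2, 6]: new per-dataset behaviour
# fan_out_batch_size([2], 3) raises ValueError because the lengths do not match.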

tests/streaming/test_combined.py

Lines changed: 60 additions & 0 deletions
@@ -596,3 +596,63 @@ def test_combined_dataset_dataloader_states_partial_iterations(combined_dataset,
         assert dataloader.current_epoch == 2, "Current epoch should be 2 in the second iteration"
         samples_yielded += len(batch)
     assert samples_yielded == len(combined_dataset), "All samples should be yielded in the second epoch."
+
+
+# -----------------------------------------------------------------------------
+# New tests: per-dataset batch sizes with batching_method="per_stream"
+# -----------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("batch_sizes", [[1, 2], [2, 3]])
+def test_combined_dataset_per_dataset_batch_size(batch_sizes):
+    """Validate that when individual batch sizes are provided for each inner dataset.
+
+    The iterator respects these limits when *batching_method='per_stream'*.
+    """
+    # Build two trivial iterable datasets that produce easily distinguishable values
+    dataset1 = SimpleDataset(0, 200)  # dataset 0 values 0-199
+    dataset2 = SimpleDataset(1000, 1200)  # dataset 1 values 1000-1199
+
+    cds = TestCombinedStreamingDataset(
+        datasets=[dataset1, dataset2],
+        weights=[0.5, 0.5],
+        batching_method="per_stream",
+        iterate_over_all=False,
+        seed=123,
+    )
+
+    # Apply the per-dataset batch sizes
+    cds.set_batch_size(batch_sizes)
+
+    # Iterate a reasonable number of samples to observe several switches
+    num_samples = 300
+    iterator = iter(cds)
+
+    # Helper to map value -> dataset index
+    def get_ds_id(val):
+        return 0 if val < 1000 else 1
+
+    current_ds = None
+    run_length = 0
+
+    for _ in range(num_samples):
+        val = next(iterator)
+        ds_id = get_ds_id(val)
+
+        if current_ds is None:
+            # first sample
+            current_ds = ds_id
+            run_length = 1
+        elif ds_id == current_ds:
+            run_length += 1
+        else:
+            # dataset switch: verify previous run respected its quota
+            assert run_length <= batch_sizes[current_ds], (
+                f"Dataset {current_ds} emitted {run_length} consecutive samples (limit {batch_sizes[current_ds]})"
+            )
+            current_ds = ds_id
+            run_length = 1
+
+    # Final run check at loop end
+    if current_ds is not None:
+        assert run_length <= batch_sizes[current_ds]
