
[data] feat: add DynamicBatchingSizeDataset for stateful multi-worker dynamic batching#488

Open
LiuzcEECS wants to merge 3 commits into main from zhichao/dynamic-batching-dataset

Conversation

@LiuzcEECS
Collaborator

What does this PR do?

Add DynamicBatchingSizeDataset, a new dataset-level dynamic batching approach that constructs micro batches inside DataLoader worker processes (as opposed to the existing DynamicBatchSizeDataLoader which does it in the main process). This enables proper checkpoint/resume support via StatefulDataLoader's state_dict()/load_state_dict() for dynamic batching, and is controlled by the new dyn_bsz_in_worker_loop flag.

Test

All 9 tests pass (pytest tests/data/test_dynamic_batching_dataset.py):

  • 5 unit tests: basic dynamic batching with shuffle on and off, forcing a long sequence, draining the last batch, and running without get_item
  • 4 distributed tests via torchrun --nproc_per_node=2, covering the 2×2 combinations of shuffle × save_by_idx; each runs 2 epochs of training with a checkpoint saved at step 2 of epoch 1 and verified after resume

API and Usage Example

New arguments:

# Use dataset-level dynamic batching (DynamicBatchingSizeDataset) instead of main-process batching (DynamicBatchSizeDataLoader)
--train.dyn_bsz_in_worker_loop=false

# Control whether to save buffer by index (way smaller checkpoint) or by full sample
--train.dyn_bsz_dataset_save_by_idx=true

When dyn_bsz_in_worker_loop=false, build_native_dataloader wraps the dataset with DynamicBatchingSizeDataset, which yields pre-batched micro batches from each worker.

When dyn_bsz_in_worker_loop=true (default, existing behavior), the original DynamicBatchSizeDataLoader path is used — no behavior change.
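For orientation, here is a minimal sketch of how the resulting loader could be checkpointed and resumed through StatefulDataLoader's state_dict()/load_state_dict(); the train_step, checkpoint_step, and surrounding wiring are illustrative placeholders, not the exact VeOmni API.

import torch

# Assume `dataloader` was produced by build_native_dataloader(...) with
# dyn_bsz_in_worker_loop=false, i.e. it is a torchdata StatefulDataLoader
# wrapping a DynamicBatchingSizeDataset.
def train_one_epoch(dataloader, train_step, checkpoint_step, ckpt_path="dataloader_state.pt"):
    for step, micro_batch in enumerate(dataloader):
        train_step(micro_batch)
        if step == checkpoint_step:
            # Saves loader progress, including the per-worker packing buffers.
            torch.save(dataloader.state_dict(), ckpt_path)

def resume(dataloader, ckpt_path="dataloader_state.pt"):
    # Restore before iterating; iteration then continues from the saved position.
    dataloader.load_state_dict(torch.load(ckpt_path))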

Design & Code Changes

Core: DynamicBatchingSizeDataset (veomni/data/dynamic_batching.py)

  • New IterableDataset subclass that buffers samples in each worker and yields micro batches when the buffer has ≥ ready_for_micro_batch_threshold samples and ≥ micro_batch_seq_length total tokens.
  • Greedy bin-packing in _get_micro_batch(): iterates the buffer and selects samples that fit within the remaining token budget (a minimal sketch follows this list).
  • Supports state_dict() / load_state_dict() for StatefulDataLoader checkpoint/resume:
    • save_by_idx=True: saves only sample indices (smaller checkpoint), requires dataset to support get_item() and output_refetch_idx.
    • save_by_idx=False: saves full buffer contents via deepcopy.
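To make the worker-side flow concrete, here is a minimal sketch of the buffering and packing idea. It is not the VeOmni implementation; everything beyond the names quoted above (for example the input_ids/idx sample fields and the _ready helper) is an assumption.

from copy import deepcopy
from torch.utils.data import IterableDataset

class DynamicBatchingSketch(IterableDataset):
    """Sketch of worker-side dynamic batching; not the actual DynamicBatchingSizeDataset."""

    def __init__(self, dataset, micro_batch_seq_length, ready_for_micro_batch_threshold, save_by_idx=True):
        self.dataset = dataset
        self.micro_batch_seq_length = micro_batch_seq_length
        self.threshold = ready_for_micro_batch_threshold
        self.save_by_idx = save_by_idx
        self._buffer = []        # samples waiting to be packed
        self._data_iter = None   # bound lazily inside each worker copy

    def _ready(self):
        total_tokens = sum(len(s["input_ids"]) for s in self._buffer)
        return len(self._buffer) >= self.threshold and total_tokens >= self.micro_batch_seq_length

    def _get_micro_batch(self):
        # Greedy bin-packing: keep buffered samples that still fit the token budget.
        batch, kept, budget = [], [], self.micro_batch_seq_length
        for sample in self._buffer:
            n = len(sample["input_ids"])
            if n <= budget:
                batch.append(sample)
                budget -= n
            else:
                kept.append(sample)
        if not batch:            # an over-long sample is emitted alone rather than stalling
            batch, kept = kept[:1], kept[1:]
        self._buffer = kept
        return batch

    def __iter__(self):
        if self._data_iter is None:
            self._data_iter = iter(self.dataset)
        for sample in self._data_iter:
            self._buffer.append(sample)
            if self._ready():
                yield self._get_micro_batch()
        while self._buffer:      # drain the remainder once the dataset is exhausted
            yield self._get_micro_batch()

    def state_dict(self):
        if self.save_by_idx:
            # Assumes samples carry their source index (output_refetch_idx in the PR).
            return {"buffer_idx": [s["idx"] for s in self._buffer]}
        return {"buffer": deepcopy(self._buffer)}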

Integration (veomni/data/data_loader.py)

  • build_native_dataloader now branches on dyn_bsz_in_worker_loop:
    • True: existing TextBatchingStrategy + DynamicBatchSizeDataLoader path.
    • False: wraps the dataset in DynamicBatchingSizeDataset, uses NoopDataCollator, and returns StatefulDataLoader directly, with no DynamicBatchSizeDataLoader wrapper (see the sketch below).
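A rough sketch of the new branch; apart from the class names quoted in this PR (DynamicBatchingSizeDataset, NoopDataCollator, StatefulDataLoader, DynamicBatchSizeDataLoader), the argument names and helper below are assumptions.

from torchdata.stateful_dataloader import StatefulDataLoader
from veomni.data.dynamic_batching import DynamicBatchingSizeDataset  # module path per this PR
# The NoopDataCollator import path is assumed and omitted here.

def build_native_dataloader(dataset, args):
    if args.dyn_bsz_in_worker_loop:
        # Existing path: TextBatchingStrategy + DynamicBatchSizeDataLoader in the main process.
        return build_main_process_dyn_bsz_loader(dataset, args)   # hypothetical helper
    # New path: workers yield pre-packed micro batches, so no batching wrapper is needed.
    packed = DynamicBatchingSizeDataset(
        dataset,
        micro_batch_seq_length=args.max_seq_len,                  # assumed argument name
        save_by_idx=args.dyn_bsz_dataset_save_by_idx,
    )
    return StatefulDataLoader(
        packed,
        batch_size=None,                # micro batches are already formed in the workers
        num_workers=args.num_workers,   # assumed argument name
        collate_fn=NoopDataCollator(),  # pass packed batches through unchanged
    )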

Arguments (veomni/arguments/arguments_types.py)

  • dyn_bsz_in_worker_loop: bool = True — controls which dynamic batching path to use.
  • dyn_bsz_dataset_save_by_idx: bool = True — controls the checkpoint buffer serialization strategy (illustrative field definitions follow this list).
  • Updated dataloader_batch_size calculation for dyn_bsz_in_worker_loop=False case.
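For reference, a hypothetical shape of the new fields; only the field names and defaults come from this PR, while the enclosing dataclass and help strings are illustrative.

from dataclasses import dataclass, field

@dataclass
class TrainArgumentsSketch:   # stand-in for the real class in arguments_types.py
    dyn_bsz_in_worker_loop: bool = field(
        default=True,
        metadata={"help": "True: existing DynamicBatchSizeDataLoader path; False: DynamicBatchingSizeDataset inside workers."},
    )
    dyn_bsz_dataset_save_by_idx: bool = field(
        default=True,
        metadata={"help": "Checkpoint the batching buffer by sample index instead of by full sample contents."},
    )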

Minor refactor (veomni/data/batching_strategy.py)

  • Renamed BaseBatchingStrategy.is_full_filled() → is_ready_for_micro_batch() for clarity.

Tests (tests/data/test_dynamic_batching_dataset.py, tests/data/utils.py)

  • DummyMappingDataset / DummyIterableDataset: test fixtures with configurable sequence lengths, rank/worker sharding (similar to ByteDance's StreamingDataset), and state_dict() support (sketched after this list).
  • Unit tests covering basic batching, long sequence handling, dataset exhaustion.
  • Distributed tests running multi-epoch training with checkpoint save/resume verification across shuffle × save_by_idx combinations.
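A minimal sketch of what such a fixture might look like; the field names, sharding scheme, and _consumed cursor are assumptions, not the PR's actual DummyIterableDataset.

import torch
from torch.utils.data import IterableDataset, get_worker_info

class DummyShardedDataset(IterableDataset):
    def __init__(self, seq_lengths, rank=0, world_size=1):
        self.seq_lengths = seq_lengths
        self.rank = rank
        self.world_size = world_size
        self._consumed = 0   # resumable cursor saved by state_dict()

    def _shard_indices(self):
        # Interleave samples across (rank, worker) pairs so no sample is seen twice.
        info = get_worker_info()
        workers = info.num_workers if info else 1
        worker_id = info.id if info else 0
        stride = self.world_size * workers
        start = self.rank * workers + worker_id
        return range(start, len(self.seq_lengths), stride)

    def __iter__(self):
        for idx in list(self._shard_indices())[self._consumed:]:
            self._consumed += 1
            yield {"idx": idx, "input_ids": torch.zeros(self.seq_lengths[idx], dtype=torch.long)}

    def state_dict(self):
        return {"consumed": self._consumed}

    def load_state_dict(self, state):
        self._consumed = state["consumed"]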

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

Contributor

@gemini-code-assist bot left a comment


Code Review

This pull request introduces DynamicBatchingSizeDataset, a new approach for dynamic batching that operates within DataLoader worker processes. This is a significant feature that enables stateful checkpointing and resumption for dynamic batching with multiple workers. The changes are extensive, including the new dataset implementation, new arguments to control it, integration into the data loader build process, and a comprehensive new test suite with both unit and distributed tests.

Overall, the implementation is solid and the tests are thorough. However, I've identified two critical issues. One is a potential breaking change in argument handling that could affect existing user configurations for dynamic batching. The other is a bug in the new DynamicBatchingSizeDataset's iterator implementation that would prevent correct multi-epoch training. Addressing these issues is crucial for the stability and correctness of this new feature.

Comment on lines +117 to +118
if not self._data_iter:
self._data_iter = iter(self.dataset)

critical

The __iter__ method does not correctly reset the underlying data iterator (self._data_iter) for new epochs. It's initialized only if it's None. After the first epoch, self._data_iter will be an exhausted iterator. When a new epoch starts and __iter__ is called again, it will reuse the exhausted iterator, causing subsequent epochs to yield no data or incomplete data. This breaks multi-epoch training.

The iterator from the upstream dataset should be re-initialized at the beginning of every __iter__ call to ensure each epoch processes the data correctly from the start.

        self._data_iter = iter(self.dataset)

Collaborator Author


I think this is incorrect. In the PyTorch DataLoader, when iteration starts, each worker receives a copy of the dataset as it was initialized in the main process; we only construct it there, so _data_iter is always None in the main process. The __iter__ function then assigns it at the beginning of the loop, so we should be fine.
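For context, a minimal repro of the behavior being debated, assuming the multi-worker setup this dataset targets (the class below is a toy, not the PR's code): each epoch's iter(DataLoader) gives the workers a fresh copy of the dataset as it exists in the main process, where _data_iter is still None, so the lazy initialization re-runs every epoch. With num_workers=0 the dataset is not copied, and the exhausted-iterator concern would indeed apply.

from torch.utils.data import DataLoader, IterableDataset

class LazyIterDataset(IterableDataset):
    def __init__(self):
        self._data_iter = None   # never assigned in the main process

    def __iter__(self):
        if self._data_iter is None:
            self._data_iter = iter(range(4))
        yield from self._data_iter

if __name__ == "__main__":
    loader = DataLoader(LazyIterDataset(), num_workers=2, batch_size=None)
    for epoch in range(2):
        # Both epochs yield data because each epoch's workers start from a fresh copy.
        # (Each worker yields the full range here since the toy omits worker sharding.)
        print(epoch, list(loader))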

@Luosuu
Collaborator

Luosuu commented Feb 18, 2026

Should we simply replace the original one in this case? Why would we need two if their only difference is where the batching is performed?

@LiuzcEECS self-assigned this Feb 18, 2026
Collaborator

@Luosuu left a comment


Let's add a feature switch to hide this from users for now, and let's test it first.

@LiuzcEECS force-pushed the zhichao/dynamic-batching-dataset branch from 7cb0d4b to 469a99a on February 19, 2026 01:59
@LiuzcEECS force-pushed the zhichao/dynamic-batching-dataset branch from 469a99a to 4dc4bbb on February 19, 2026 04:29