[data] feat: add DynamicBatchingSizeDataset for stateful multi-worker dynamic batching#488
Conversation
Code Review
This pull request introduces DynamicBatchingSizeDataset, a new approach for dynamic batching that operates within DataLoader worker processes. This is a significant feature that enables stateful checkpointing and resumption for dynamic batching with multiple workers. The changes are extensive, including the new dataset implementation, new arguments to control it, integration into the data loader build process, and a comprehensive new test suite with both unit and distributed tests.
Overall, the implementation is solid and the tests are thorough. However, I've identified two critical issues. One is a potential breaking change in argument handling that could affect existing user configurations for dynamic batching. The other is a bug in the new DynamicBatchingSizeDataset's iterator implementation that would prevent correct multi-epoch training. Addressing these issues is crucial for the stability and correctness of this new feature.
if not self._data_iter:
    self._data_iter = iter(self.dataset)
The __iter__ method does not correctly reset the underlying data iterator (self._data_iter) for new epochs. It's initialized only if it's None. After the first epoch, self._data_iter will be an exhausted iterator. When a new epoch starts and __iter__ is called again, it will reuse the exhausted iterator, causing subsequent epochs to yield no data or incomplete data. This breaks multi-epoch training.
The iterator from the upstream dataset should be re-initialized at the beginning of every __iter__ call to ensure each epoch processes the data correctly from the start.
self._data_iter = iter(self.dataset)
I think you are wrong: in the PyTorch DataLoader, when iteration starts, each worker receives a copy of the dataset as it was initialized in the main process. In the main process we only construct it, so _data_iter is always None there. The __iter__ function then assigns it at the beginning of the loop, so we should be fine.
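For reference, a minimal standalone sketch (not code from this PR) of that behavior: with the default `persistent_workers=False`, each epoch's `iter(DataLoader)` spawns fresh workers that receive a copy of the dataset from the main process, where `_data_iter` is still `None`, so the lazy initialization runs again every epoch.

```python
from torch.utils.data import DataLoader, IterableDataset


class LazyIterWrapper(IterableDataset):
    """Mimics the lazy `_data_iter` pattern under discussion (illustrative only)."""

    def __init__(self, data):
        self.data = data
        self._data_iter = None  # stays None in the main-process copy

    def __iter__(self):
        if not self._data_iter:               # runs inside each worker's copy
            self._data_iter = iter(self.data)
        yield from self._data_iter


if __name__ == "__main__":
    ds = LazyIterWrapper(list(range(6)))
    # persistent_workers defaults to False, so workers (and their dataset copies)
    # are recreated for every epoch and _data_iter starts out as None each time.
    dl = DataLoader(ds, num_workers=1, batch_size=2)
    for epoch in range(2):
        print(f"epoch {epoch}:", [batch.tolist() for batch in dl])
    print("main-process _data_iter:", ds._data_iter)  # still None
```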
Should we simply replace the original one in this case? Why would we need two if their only difference is where the batching is performed?
Luosuu
left a comment
Let's add a feature switch to hide this from users for now. Let's test it first.
What does this PR do?
Add `DynamicBatchingSizeDataset`, a new dataset-level dynamic batching approach that constructs micro batches inside DataLoader worker processes (as opposed to the existing `DynamicBatchSizeDataLoader`, which does it in the main process). This enables proper checkpoint/resume support via `StatefulDataLoader`'s `state_dict()`/`load_state_dict()` for dynamic batching, and is controlled by the new `dyn_bsz_in_worker_loop` flag.
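As a rough illustration of the checkpoint/resume contract this relies on (not code from this PR): torchdata's `StatefulDataLoader` snapshots and restores any dataset that implements `state_dict()`/`load_state_dict()`, which is what `DynamicBatchingSizeDataset` provides. The toy dataset below is a stand-in, shown single-process for brevity.

```python
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader


class TinyStatefulDataset(IterableDataset):
    """Stand-in for DynamicBatchingSizeDataset: exposes the state_dict()/
    load_state_dict() hooks that StatefulDataLoader snapshots and restores."""

    def __init__(self, n=8):
        self.n = n
        self._next = 0  # how many samples have been yielded so far

    def __iter__(self):
        while self._next < self.n:
            value = self._next
            self._next += 1
            yield value

    def state_dict(self):
        return {"next": self._next}

    def load_state_dict(self, state):
        self._next = state["next"]


if __name__ == "__main__":
    loader = StatefulDataLoader(TinyStatefulDataset(), batch_size=2, num_workers=0)
    it = iter(loader)
    print(next(it), next(it))              # consume two batches
    ckpt = loader.state_dict()             # mid-epoch snapshot, includes dataset state

    resumed = StatefulDataLoader(TinyStatefulDataset(), batch_size=2, num_workers=0)
    resumed.load_state_dict(ckpt)
    print([b.tolist() for b in resumed])   # continues from the third batch
```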
Test

All 9 tests pass (`pytest tests/data/test_dynamic_batching_dataset.py`):
- unit tests covering `get_item`
- distributed tests via `torchrun --nproc_per_node=2`: shuffle × save_by_idx (2×2 combinations), each running 2 epochs of training with a checkpoint save at step 2 of epoch 1 and resume verification

API and Usage Example
New arguments:
When `dyn_bsz_in_worker_loop=false`, `build_native_dataloader` wraps the dataset with `DynamicBatchingSizeDataset`, which yields pre-batched micro batches from each worker.

When `dyn_bsz_in_worker_loop=true` (default, existing behavior), the original `DynamicBatchSizeDataLoader` path is used; no behavior change.

Design & Code Changes
Core: `DynamicBatchingSizeDataset` (`veomni/data/dynamic_batching.py`)
- `IterableDataset` subclass that buffers samples in each worker and yields a micro batch once the buffer holds ≥ `ready_for_micro_batch_threshold` samples and ≥ `micro_batch_seq_length` total tokens (sketched below).
- `_get_micro_batch()`: iterates the buffer and selects samples that fit within the token budget.
- `state_dict()` / `load_state_dict()` for `StatefulDataLoader` checkpoint/resume:
  - `save_by_idx=True`: saves only sample indices (smaller checkpoint); requires the dataset to support `get_item()` and `output_refetch_idx`.
  - `save_by_idx=False`: saves the full buffer contents via `deepcopy`.
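A condensed sketch of that buffering loop, using hypothetical names (`InWorkerDynamicBatching`, `_buffer_tokens`) and a simple greedy selection; the real `DynamicBatchingSizeDataset` additionally implements `state_dict()`/`load_state_dict()`, `save_by_idx`, and worker sharding:

```python
from torch.utils.data import IterableDataset


class InWorkerDynamicBatching(IterableDataset):
    """Hypothetical, simplified stand-in for DynamicBatchingSizeDataset."""

    def __init__(self, dataset, micro_batch_seq_length, ready_for_micro_batch_threshold):
        self.dataset = dataset
        self.micro_batch_seq_length = micro_batch_seq_length
        self.ready_for_micro_batch_threshold = ready_for_micro_batch_threshold
        self._buffer = []        # samples waiting to be packed into micro batches
        self._data_iter = None   # created lazily inside each worker process

    def _buffer_tokens(self):
        return sum(len(sample["input_ids"]) for sample in self._buffer)

    def _get_micro_batch(self):
        # Greedily pick buffered samples that still fit within the token budget.
        micro_batch, remaining, budget = [], [], self.micro_batch_seq_length
        for sample in self._buffer:
            seq_len = len(sample["input_ids"])
            if seq_len <= budget:
                micro_batch.append(sample)
                budget -= seq_len
            else:
                remaining.append(sample)
        self._buffer = remaining
        return micro_batch

    def __iter__(self):
        if not self._data_iter:
            self._data_iter = iter(self.dataset)
        for sample in self._data_iter:
            self._buffer.append(sample)
            if (len(self._buffer) >= self.ready_for_micro_batch_threshold
                    and self._buffer_tokens() >= self.micro_batch_seq_length):
                micro_batch = self._get_micro_batch()
                if micro_batch:
                    yield micro_batch
        while self._buffer:                          # flush leftovers at end of data
            micro_batch = self._get_micro_batch()
            if not micro_batch:                      # a lone sample longer than the budget
                micro_batch = [self._buffer.pop(0)]
            yield micro_batch


if __name__ == "__main__":
    samples = [{"input_ids": list(range(n))} for n in (3, 5, 2, 7, 4, 6)]
    ds = InWorkerDynamicBatching(samples, micro_batch_seq_length=8,
                                 ready_for_micro_batch_threshold=2)
    for mb in ds:
        print([len(s["input_ids"]) for s in mb])
```

Because each worker emits ready-made micro batches, the per-worker buffer (or just its sample indices when `save_by_idx=True`) is the only extra state the `StatefulDataLoader` needs to checkpoint.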
Integration (`veomni/data/data_loader.py`)
- `build_native_dataloader` now branches on `dyn_bsz_in_worker_loop`:
  - `True`: existing `TextBatchingStrategy` + `DynamicBatchSizeDataLoader` path.
  - `False`: wraps the dataset in `DynamicBatchingSizeDataset`, uses `NoopDataCollator`, and returns `StatefulDataLoader` directly (no `DynamicBatchSizeDataLoader` wrapper).
Arguments (`veomni/arguments/arguments_types.py`)
- `dyn_bsz_in_worker_loop: bool = True`: controls which dynamic batching path is used.
- `dyn_bsz_dataset_save_by_idx: bool = True`: controls the checkpoint buffer serialization strategy.
- `dataloader_batch_size` calculation for the `dyn_bsz_in_worker_loop=False` case.
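For illustration only, the two flags could be declared roughly like this (the real fields live in `veomni/arguments/arguments_types.py` and may differ in naming of metadata and help text):

```python
from dataclasses import dataclass, field


@dataclass
class DataArgumentsSketch:
    # Illustrative stand-in for the real arguments class in veomni.
    dyn_bsz_in_worker_loop: bool = field(
        default=True,
        metadata={"help": "True: existing DynamicBatchSizeDataLoader path; "
                          "False: in-worker DynamicBatchingSizeDataset path."},
    )
    dyn_bsz_dataset_save_by_idx: bool = field(
        default=True,
        metadata={"help": "True: checkpoint only sample indices; "
                          "False: deepcopy the full sample buffer."},
    )
```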
Minor refactor (`veomni/data/batching_strategy.py`)
- `BaseBatchingStrategy.is_full_filled()` → `is_ready_for_micro_batch()` for clarity.
Tests (`tests/data/test_dynamic_batching_dataset.py`, `tests/data/utils.py`)
- `DummyMappingDataset` / `DummyIterableDataset`: test fixtures with configurable sequence lengths, rank/worker sharding (like `StreamingDataset` in ByteDance), and `state_dict()` support.
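A rough idea of what such a fixture can look like (illustrative, not the PR's `tests/data/utils.py`): configurable sequence lengths, rank/worker sharding via `get_worker_info()`, and `state_dict()`/`load_state_dict()` support for resume tests.

```python
import torch
from torch.utils.data import IterableDataset, get_worker_info


class DummyIterableDatasetSketch(IterableDataset):
    """Illustrative fixture: yields variable-length samples, shards across
    rank/worker, and tracks how many samples it has emitted for resumption."""

    def __init__(self, seq_lengths, rank=0, world_size=1):
        self.seq_lengths = seq_lengths
        self.rank = rank
        self.world_size = world_size
        self._consumed = 0  # samples already yielded by this shard

    def _shard_indices(self):
        info = get_worker_info()
        workers = info.num_workers if info else 1
        worker_id = info.id if info else 0
        stride = self.world_size * workers
        offset = self.rank * workers + worker_id
        return range(offset, len(self.seq_lengths), stride)

    def __iter__(self):
        for idx in list(self._shard_indices())[self._consumed:]:
            self._consumed += 1
            yield {"input_ids": torch.arange(self.seq_lengths[idx])}

    def state_dict(self):
        return {"consumed": self._consumed}

    def load_state_dict(self, state):
        self._consumed = state["consumed"]


if __name__ == "__main__":
    ds = DummyIterableDatasetSketch(seq_lengths=[3, 5, 2, 7], rank=0, world_size=1)
    print([len(s["input_ids"]) for s in ds])  # [3, 5, 2, 7] with a single worker
    print(ds.state_dict())                    # {'consumed': 4}
```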