11"""Tests for DynamicBatchingSizeDataset functionality.
22
3- This module tests the DynamicBatchingSizeDataset class using DummyIterableDataset
4- and DummyMappingDataset. It validates that DynamicBatchingSizeDataset can properly:
3+ This module tests the `` DynamicBatchingSizeDataset`` class using `` DummyIterableDataset``.
4+ It validates that `` DynamicBatchingSizeDataset`` can properly:
55
6- 1. Batch samples based on token count (micro_batch_seq_length)
7- 2. Handle buffer management with ready_for_micro_batch_threshold
8- 3. Work with both shuffled and non-shuffled iterable datasets
9- 4. Support state_dict save/load for checkpointing in distributed environments
6+ 1. Batch samples based on token count (``micro_batch_seq_length``).
7+ 2. Handle buffer management with ``ready_for_micro_batch_threshold``.
8+ 3. Work with both shuffled and non-shuffled iterable datasets.
9+ 4. Drain remaining buffer contents after the upstream dataset is exhausted.
10+ 5. Reject invalid construction arguments (``save_by_idx`` without ``get_item``).
11+ 6. Save and restore buffer state for exact checkpoint / resume in distributed
12+ environments, both by storing full samples and by storing only indices.
1013
1114The test suite includes:
12- - Unit tests that can run without distributed setup:
13- - test_dynamic_batching_basic
14- - End-to-end tests that require multi-GPU distributed environments:
15- - test_dynamic_batching_dataset_shuffled
16- - test_dynamic_batching_dataset_no_shuffle
15+
16+ Unit tests (run without distributed setup, CPU-compatible):
17+ - ``test_dynamic_batching_basic`` – core batching logic and expected batch
18+ contents for shuffled and non-shuffled data.
19+ - ``test_force_long_sequence`` – overlong samples are emitted rather than
20+ dropped when ``force_generate_long_sequence=True``.
21+ - ``test_last_batch_on_dataset_end`` – remaining buffer items are yielded
22+ after upstream exhaustion.
23+ - ``test_dynamic_batching_without_get_item`` – ``ValueError`` is raised when
24+ ``save_by_idx=True`` but the dataset lacks ``get_item``.
25+
26+ End-to-end distributed tests (require ``torchrun`` with 2 processes):
27+ - ``test_dynamic_batching_dataset_distributed`` – parametrised over
28+ ``shuffle × save_by_idx`` (4 combinations), verifying that resumed
29+ batches are byte-for-byte identical to the original run.
1730"""
1831
32+ import argparse
1933import os
2034import subprocess
2135import sys
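
For orientation, the batching behaviour these tests exercise can be pictured with a small
standalone sketch; this is not the veomni implementation. The parameter names follow the
module docstring above, and samples are assumed to be dicts carrying an ``input_ids`` list:

```python
from collections import deque
from typing import Deque, Dict, Iterable, Iterator, List


def _take_micro_batch(buffer: Deque[Dict], budget: int) -> List[Dict]:
    """Pop samples from the buffer head until the token budget would be exceeded."""
    batch: List[Dict] = []
    used = 0
    while buffer and used + len(buffer[0]["input_ids"]) <= budget:
        sample = buffer.popleft()
        batch.append(sample)
        used += len(sample["input_ids"])
    if not batch and buffer:
        # A single overlong sample is emitted on its own instead of stalling the
        # stream (roughly the force_generate_long_sequence=True behaviour).
        batch.append(buffer.popleft())
    return batch


def dynamic_token_batches(
    samples: Iterable[Dict],
    micro_batch_seq_length: int,
    ready_for_micro_batch_threshold: int,
) -> Iterator[List[Dict]]:
    """Yield micro-batches whose summed token count stays within the budget."""
    buffer: Deque[Dict] = deque()
    for sample in samples:
        buffer.append(sample)
        # Only start emitting once the buffer holds enough candidates.
        if len(buffer) >= ready_for_micro_batch_threshold:
            yield _take_micro_batch(buffer, micro_batch_seq_length)
    # Drain remaining buffer contents once the upstream dataset is exhausted.
    while buffer:
        yield _take_micro_batch(buffer, micro_batch_seq_length)
```

Except for the deliberate overlong-sample case, every yielded batch keeps the summed
``input_ids`` length at or below ``micro_batch_seq_length``, which is the invariant the
unit tests above verify.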
@@ -279,7 +293,7 @@ def test_last_batch_on_dataset_end(setup_dynamic_batching_dataset):


 def test_dynamic_batching_without_get_item():
-    """Test DynamicBatchingSizeDataset initialization without get_item povided.
+    """Test DynamicBatchingSizeDataset initialization without get_item provided.

     Tests that DynamicBatchingSizeDataset cannot be initialized with save_by_idx=True
     when the dataset doesn't have get_item method.
@@ -316,7 +330,7 @@ def __iter__(self):
 def test_dynamic_batching_dataset_distributed(shuffle, save_by_idx):
     """Test DynamicBatchingSizeDataset in distributed setting.

-    Runs main_distributed_test() by torchrun with or without data shuffling
+    Runs _main_distributed_test() by torchrun with or without data shuffling
     and with or without save_by_idx for checkpoint buffer saving.

     Args:
@@ -358,16 +372,18 @@ def build_command(shuffle=True, save_by_idx=True):
358372 "--data.train_size=2000" ,
359373 "--data.max_seq_len=16" ,
360374 "--train.micro_batch_size=2" ,
361- f"--data.shuffle={ str (shuffle ).lower ()} " ,
375+ # NOTE: Do not rely on veomni_patch adding `data.shuffle` into DataArguments.
376+ # Keep this as a test-only flag (parsed via argparse in _run_distributed_test).
377+ f"--shuffle={ str (shuffle ).lower ()} " ,
362378 "--train.global_batch_size=16" ,
363379 "--train.data_parallel_mode=ddp" ,
364380 "--train.ckpt_manager=dcp" ,
365381 "--train.output_dir=.tests/cache" ,
366382 "--train.rmpad=false" ,
367383 "--train.rmpad_with_pos_ids=true" ,
368384 "--train.dyn_bsz=true" ,
369- "--train.dyn_bsz_in_worker_loop =false" ,
370- f"--train.dyn_bsz_dataset_save_by_idx ={ str (save_by_idx ).lower ()} " ,
385+ "--dyn_bsz_in_dataloader =false" ,
386+ f"--save_by_idx ={ str (save_by_idx ).lower ()} " ,
371387 "--train.seed=42" ,
372388 ]
373389 return command
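
As a quick illustration of why these flags use bare names (``--shuffle``, ``--save_by_idx``,
``--dyn_bsz_in_dataloader``) rather than ``--data.*`` / ``--train.*`` prefixes: the worker
strips them with ``argparse.parse_known_args()`` before handing the rest to veomni's
``parse_args``. The argv values below are invented for the example; only the flag names and
the parsing pattern come from this diff:

```python
import argparse

# Hypothetical command line for one parametrisation of the test.
argv = ["--data.max_seq_len=16", "--shuffle=true", "--save_by_idx=false", "--train.seed=42"]

parser = argparse.ArgumentParser()
parser.add_argument("--shuffle", type=lambda x: x.lower() == "true", default=True)
parser.add_argument("--save_by_idx", type=lambda x: x.lower() == "true", default=True)
test_args, remaining = parser.parse_known_args(argv)

assert test_args.shuffle is True and test_args.save_by_idx is False
# The framework-owned arguments pass through untouched.
assert remaining == ["--data.max_seq_len=16", "--train.seed=42"]
```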
@@ -389,20 +405,47 @@ class Arguments:
     train: "TrainingArguments" = field(default_factory=TrainingArguments)


-def main_distributed_test():
-    """
-    Tests:
-    - Dynamic batching with shuffled iterable dataset
-    - Checkpoint save/load with buffer state
-    - Multi-process distributed training
+def _main_distributed_test():
+    """Entry point for the distributed test launched by ``torchrun``.
+
+    It wraps ``_run_distributed_test()`` and is intended to be triggered by
+    ``test_dynamic_batching_dataset_distributed()``.
398413 """
399414 # Patch empty_cache to avoid AttributeError on CPU
400415 with patch ("veomni.utils.device.empty_cache" , _mock_empty_cache ):
401416 _run_distributed_test ()
402417
403418
404419def _run_distributed_test ():
405- """Internal function that runs the actual distributed test."""
420+ """Run a full checkpoint-resume cycle and assert batch reproducibility.
421+
422+ Procedure
423+ ---------
424+ 1. **Parse CLI flags**
425+ 2. **Initialise torch distributed state**
426+ 3. **Build a StatefulDataLoader** wrapping ``DummyIterableDataset`` →
427+ ``DynamicBatchingSizeDataset`` with ``num_workers=2``.
428+ 4. **First pass (2 epochs)** – iterate the dataloader for both epochs. Batches
429+ before the designated save point (``epoch=1, step=2``) are discarded; batches
430+ *after* that point are stored in ``batches_after_save_step`` as ground truth.
431+ At the save point a checkpoint is written via ``Checkpointer.save()``,
432+ capturing model weights, ``dataloader.state_dict()``, and
433+ ``environ_meter.state_dict()``.
434+ 5. **Load checkpoint** – ``Checkpointer.load()`` restores all state; the
435+ dataloader, dataset and environ-meter are restored through ``load_state_dict()``.
436+ 6. **Second pass (resume)** – iterate from the saved epoch / step through the
437+ end of both epochs, collecting resumed batches in ``batch_after_resume``.
438+ 7. **Assert equality** – verify that ``batches_after_save_step`` and
439+ ``batch_after_resume`` have the same length and that every tensor in every
440+ micro-batch is identical element-wise.
441+ """
+    _parser = argparse.ArgumentParser()
+    _parser.add_argument("--shuffle", type=lambda x: x.lower() == "true", default=True)
+    _parser.add_argument("--save_by_idx", type=lambda x: x.lower() == "true", default=True)
+    _parser.add_argument("--dyn_bsz_in_dataloader", type=lambda x: x.lower() == "true", default=True)
+    test_args, remaining_argv = _parser.parse_known_args()
+    sys.argv = [sys.argv[0]] + remaining_argv
+
     args = parse_args(Arguments)
     world_size = int(os.environ["WORLD_SIZE"])
     rank = int(os.environ["RANK"])
@@ -433,8 +476,7 @@ def _run_distributed_test():

     # Create DummyMappingDataset and DummyIterableDataset
     mapping_dataset = DummyMappingDataset(size=DATASET_SIZE)
-    shuffle = getattr(args.data, "shuffle", True)
-    iterable_dataset = DummyIterableDataset(mapping_dataset, shuffle=shuffle, seed=args.train.seed)
+    iterable_dataset = DummyIterableDataset(mapping_dataset, shuffle=test_args.shuffle, seed=args.train.seed)

     # Compute train_steps based on dataset size
     dataset_length = len(mapping_dataset)
@@ -452,11 +494,11 @@ def _run_distributed_test():
         train_steps=train_steps,
         rmpad=args.train.rmpad,
         dyn_bsz=args.train.dyn_bsz,
-        dyn_bsz_in_worker_loop=args.train.dyn_bsz_in_worker_loop,
+        dyn_bsz_in_dataloader=test_args.dyn_bsz_in_dataloader,
         bsz_warmup_ratio=args.train.bsz_warmup_ratio,
         rmpad_with_pos_ids=args.train.rmpad_with_pos_ids,
         dyn_bsz_buffer_size=READY_FOR_MICRO_BATCH_THRESHOLD,
-        dyn_bsz_dataset_save_by_idx=args.train.dyn_bsz_dataset_save_by_idx,
+        dyn_bsz_dataset_save_by_idx=test_args.save_by_idx,
         num_workers=2,
         drop_last=False,
         pin_memory=args.data.pin_memory,
@@ -472,7 +514,6 @@ def _run_distributed_test():
         empty_cache_steps=args.train.empty_cache_steps,
     )

-    batches_before_save_step = []
     batches_after_save_step = []
     epoch_num = 2  # Run 2 epochs
     start_epoch, start_step, global_step = 0, 0, 0
@@ -504,18 +545,16 @@ def _run_distributed_test():

         # Print batch info for debugging
         """
-        logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)}")
+        logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)} dataset_iter: {dataloader.dataset._data_iter}")
         for micro_idx, micro_batch in enumerate(micro_batches):
             # Extract sample indices from input_ids (each sample has all same values)
             input_ids = micro_batch["input_ids"].squeeze(0)  # Remove batch dim
             input_ids = set(input_ids.tolist())
-            logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}")
+            logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}")
         """

         if epoch > save_epoch or (epoch == save_epoch and local_step > save_step):
             batches_after_save_step.append(micro_batches)
-        else:
-            batches_before_save_step.append(micro_batches)

         for _, micro_batch in enumerate(micro_batches):
             environ_meter.add(micro_batch)
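
The final comparison (step 7 of the docstring above) happens further down in
``_run_distributed_test`` and is not part of this hunk. A check along the following lines is
what the docstring describes; the batch structure (a list of steps, each a list of
micro-batch dicts of tensors) is assumed for illustration:

```python
import torch


def assert_batches_identical(expected, resumed):
    """Compare two runs' collected micro-batches tensor by tensor."""
    assert len(expected) == len(resumed), "resumed run yielded a different number of steps"
    for step_expected, step_resumed in zip(expected, resumed):
        assert len(step_expected) == len(step_resumed), "different number of micro-batches"
        for micro_expected, micro_resumed in zip(step_expected, step_resumed):
            assert micro_expected.keys() == micro_resumed.keys()
            for key in micro_expected:
                assert torch.equal(micro_expected[key], micro_resumed[key]), f"mismatch in '{key}'"
```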
@@ -623,4 +662,4 @@ def _run_distributed_test():


 if __name__ == "__main__":
-    main_distributed_test()
+    _main_distributed_test()