
Commit 4dc4bbb

resolve comments
1 parent 1424f90 commit 4dc4bbb

File tree

6 files changed: 211 additions & 113 deletions


tests/data/test_dynamic_batching_dataset.py

Lines changed: 72 additions & 33 deletions
@@ -1,21 +1,35 @@
 """Tests for DynamicBatchingSizeDataset functionality.
 
-This module tests the DynamicBatchingSizeDataset class using DummyIterableDataset
-and DummyMappingDataset. It validates that DynamicBatchingSizeDataset can properly:
+This module tests the ``DynamicBatchingSizeDataset`` class using ``DummyIterableDataset``.
+It validates that ``DynamicBatchingSizeDataset`` can properly:
 
-1. Batch samples based on token count (micro_batch_seq_length)
-2. Handle buffer management with ready_for_micro_batch_threshold
-3. Work with both shuffled and non-shuffled iterable datasets
-4. Support state_dict save/load for checkpointing in distributed environments
+1. Batch samples based on token count (``micro_batch_seq_length``).
+2. Handle buffer management with ``ready_for_micro_batch_threshold``.
+3. Work with both shuffled and non-shuffled iterable datasets.
+4. Drain remaining buffer contents after the upstream dataset is exhausted.
+5. Reject invalid construction arguments (``save_by_idx`` without ``get_item``).
+6. Save and restore buffer state for exact checkpoint / resume in distributed
+   environments, both by storing full samples and by storing only indices.
 
 The test suite includes:
-- Unit tests that can run without distributed setup:
-  - test_dynamic_batching_basic
-- End-to-end tests that require multi-GPU distributed environments:
-  - test_dynamic_batching_dataset_shuffled
-  - test_dynamic_batching_dataset_no_shuffle
+
+Unit tests (run without distributed setup, CPU-compatible):
+- ``test_dynamic_batching_basic`` – core batching logic and expected batch
+  contents for shuffled and non-shuffled data.
+- ``test_force_long_sequence`` – overlong samples are emitted rather than
+  dropped when ``force_generate_long_sequence=True``.
+- ``test_last_batch_on_dataset_end`` – remaining buffer items are yielded
+  after upstream exhaustion.
+- ``test_dynamic_batching_without_get_item`` – ``ValueError`` is raised when
+  ``save_by_idx=True`` but the dataset lacks ``get_item``.
+
+End-to-end distributed tests (require ``torchrun`` with 2 processes):
+- ``test_dynamic_batching_dataset_distributed`` – parametrised over
+  ``shuffle × save_by_idx`` (4 combinations), verifying that resumed
+  batches are byte-for-byte identical to the original run.
 """
 
+import argparse
 import os
 import subprocess
 import sys
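
For readers skimming the diff, a minimal sketch of the buffering behaviour the new docstring lists. The names mirror the docstring, but the code below is illustrative only, not the actual DynamicBatchingSizeDataset implementation: samples accumulate in a buffer until the threshold is reached, front items are then packed up to the token budget, and the buffer is drained once the upstream dataset ends.

def _pack_one(buffer, max_tokens):
    # Pop samples from the front of the buffer until the token budget is reached.
    batch, used = [], 0
    while buffer and used + len(buffer[0]["input_ids"]) <= max_tokens:
        item = buffer.pop(0)
        batch.append(item)
        used += len(item["input_ids"])
    if not batch and buffer:
        # A single overlong sample: emit it on its own instead of stalling.
        batch = [buffer.pop(0)]
    return batch

def iter_dynamic_batches(samples, micro_batch_seq_length=16, ready_for_micro_batch_threshold=4):
    buffer = []
    for sample in samples:
        buffer.append(sample)
        if len(buffer) >= ready_for_micro_batch_threshold:
            yield _pack_one(buffer, micro_batch_seq_length)
    while buffer:
        # Drain whatever is left once the upstream dataset is exhausted (point 4 above).
        yield _pack_one(buffer, micro_batch_seq_length)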
@@ -279,7 +293,7 @@ def test_last_batch_on_dataset_end(setup_dynamic_batching_dataset):
 
 
 def test_dynamic_batching_without_get_item():
-    """Test DynamicBatchingSizeDataset initialization without get_item povided.
+    """Test DynamicBatchingSizeDataset initialization without get_item provided.
 
     Tests that DynamicBatchingSizeDataset cannot be initialized with save_by_idx=True
     when the dataset doesn't have get_item method.
@@ -316,7 +330,7 @@ def __iter__(self):
 def test_dynamic_batching_dataset_distributed(shuffle, save_by_idx):
     """Test DynamicBatchingSizeDataset in distributed setting.
 
-    Runs main_distributed_test() by torchrun with or without data shuffling
+    Runs _main_distributed_test() by torchrun with or without data shuffling
     and with or without save_by_idx for checkpoint buffer saving.
 
     Args:
@@ -358,16 +372,18 @@ def build_command(shuffle=True, save_by_idx=True):
         "--data.train_size=2000",
         "--data.max_seq_len=16",
         "--train.micro_batch_size=2",
-        f"--data.shuffle={str(shuffle).lower()}",
+        # NOTE: Do not rely on veomni_patch adding `data.shuffle` into DataArguments.
+        # Keep this as a test-only flag (parsed via argparse in _run_distributed_test).
+        f"--shuffle={str(shuffle).lower()}",
         "--train.global_batch_size=16",
         "--train.data_parallel_mode=ddp",
         "--train.ckpt_manager=dcp",
         "--train.output_dir=.tests/cache",
         "--train.rmpad=false",
         "--train.rmpad_with_pos_ids=true",
         "--train.dyn_bsz=true",
-        "--train.dyn_bsz_in_worker_loop=false",
-        f"--train.dyn_bsz_dataset_save_by_idx={str(save_by_idx).lower()}",
+        "--dyn_bsz_in_dataloader=false",
+        f"--save_by_idx={str(save_by_idx).lower()}",
         "--train.seed=42",
     ]
     return command
@@ -389,20 +405,47 @@ class Arguments:
     train: "TrainingArguments" = field(default_factory=TrainingArguments)
 
 
-def main_distributed_test():
-    """
-    Tests:
-    - Dynamic batching with shuffled iterable dataset
-    - Checkpoint save/load with buffer state
-    - Multi-process distributed training
+def _main_distributed_test():
+    """Entry point for the distributed test launched by ``torchrun``.
+
+    It wraps ``_run_distributed_test()`` and is expected to be triggered by
+    ``test_dynamic_batching_dataset_distributed()``.
     """
     # Patch empty_cache to avoid AttributeError on CPU
     with patch("veomni.utils.device.empty_cache", _mock_empty_cache):
         _run_distributed_test()
 
 
 def _run_distributed_test():
-    """Internal function that runs the actual distributed test."""
+    """Run a full checkpoint-resume cycle and assert batch reproducibility.
+
+    Procedure
+    ---------
+    1. **Parse CLI flags**
+    2. **Initialise torch distributed state**
+    3. **Build a StatefulDataLoader** wrapping ``DummyIterableDataset`` →
+       ``DynamicBatchingSizeDataset`` with ``num_workers=2``.
+    4. **First pass (2 epochs)** – iterate the dataloader for both epochs. Batches
+       before the designated save point (``epoch=1, step=2``) are discarded; batches
+       *after* that point are stored in ``batches_after_save_step`` as ground truth.
+       At the save point a checkpoint is written via ``Checkpointer.save()``,
+       capturing model weights, ``dataloader.state_dict()``, and
+       ``environ_meter.state_dict()``.
+    5. **Load checkpoint** – ``Checkpointer.load()`` restores all state; the
+       dataloader, dataset and environ-meter are restored through ``load_state_dict()``.
+    6. **Second pass (resume)** – iterate from the saved epoch / step through the
+       end of both epochs, collecting resumed batches in ``batch_after_resume``.
+    7. **Assert equality** – verify that ``batches_after_save_step`` and
+       ``batch_after_resume`` have the same length and that every tensor in every
+       micro-batch is identical element-wise.
+    """
+    _parser = argparse.ArgumentParser()
+    _parser.add_argument("--shuffle", type=lambda x: x.lower() == "true", default=True)
+    _parser.add_argument("--save_by_idx", type=lambda x: x.lower() == "true", default=True)
+    _parser.add_argument("--dyn_bsz_in_dataloader", type=lambda x: x.lower() == "true", default=True)
+    test_args, remaining_argv = _parser.parse_known_args()
+    sys.argv = [sys.argv[0]] + remaining_argv
+
     args = parse_args(Arguments)
     world_size = int(os.environ["WORLD_SIZE"])
     rank = int(os.environ["RANK"])
@@ -433,8 +476,7 @@ def _run_distributed_test():
 
     # Create DummyMappingDataset and DummyIterableDataset
     mapping_dataset = DummyMappingDataset(size=DATASET_SIZE)
-    shuffle = getattr(args.data, "shuffle", True)
-    iterable_dataset = DummyIterableDataset(mapping_dataset, shuffle=shuffle, seed=args.train.seed)
+    iterable_dataset = DummyIterableDataset(mapping_dataset, shuffle=test_args.shuffle, seed=args.train.seed)
 
     # Compute train_steps based on dataset size
     dataset_length = len(mapping_dataset)
@@ -452,11 +494,11 @@ def _run_distributed_test():
         train_steps=train_steps,
         rmpad=args.train.rmpad,
         dyn_bsz=args.train.dyn_bsz,
-        dyn_bsz_in_worker_loop=args.train.dyn_bsz_in_worker_loop,
+        dyn_bsz_in_dataloader=test_args.dyn_bsz_in_dataloader,
         bsz_warmup_ratio=args.train.bsz_warmup_ratio,
         rmpad_with_pos_ids=args.train.rmpad_with_pos_ids,
         dyn_bsz_buffer_size=READY_FOR_MICRO_BATCH_THRESHOLD,
-        dyn_bsz_dataset_save_by_idx=args.train.dyn_bsz_dataset_save_by_idx,
+        dyn_bsz_dataset_save_by_idx=test_args.save_by_idx,
         num_workers=2,
         drop_last=False,
         pin_memory=args.data.pin_memory,
@@ -472,7 +514,6 @@ def _run_distributed_test():
         empty_cache_steps=args.train.empty_cache_steps,
     )
 
-    batches_before_save_step = []
     batches_after_save_step = []
     epoch_num = 2  # Run 2 epochs
     start_epoch, start_step, global_step = 0, 0, 0
@@ -504,18 +545,16 @@ def _run_distributed_test():
 
         # Print batch info for debugging
         """
-        logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)}")
+        logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)} dataset_iter: {dataloader.dataset._data_iter}")
         for micro_idx, micro_batch in enumerate(micro_batches):
             # Extract sample indices from input_ids (each sample has all same values)
             input_ids = micro_batch["input_ids"].squeeze(0)  # Remove batch dim
             input_ids = set(input_ids.tolist())
-            logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}")
+            logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}")
         """
 
         if epoch > save_epoch or (epoch == save_epoch and local_step > save_step):
             batches_after_save_step.append(micro_batches)
-        else:
-            batches_before_save_step.append(micro_batches)
 
         for _, micro_batch in enumerate(micro_batches):
             environ_meter.add(micro_batch)
@@ -623,4 +662,4 @@ def _run_distributed_test():
 
 
 if __name__ == "__main__":
-    main_distributed_test()
+    _main_distributed_test()
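
The step-7 equality assertion described in the new _run_distributed_test docstring lies outside the hunks shown above. Conceptually it reduces to something like the following sketch (illustrative only, not the exact code in this file; variable names are taken from the docstring):

import torch

def assert_resume_matches(batches_after_save_step, batch_after_resume):
    # Both arguments hold one list of micro-batch dicts per step after the save point.
    assert len(batches_after_save_step) == len(batch_after_resume)
    for original_step, resumed_step in zip(batches_after_save_step, batch_after_resume):
        assert len(original_step) == len(resumed_step)
        for original_mb, resumed_mb in zip(original_step, resumed_step):
            for key in original_mb:
                assert torch.equal(original_mb[key], resumed_mb[key])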

tests/data/utils.py

Lines changed: 58 additions & 7 deletions
@@ -17,20 +17,39 @@
 
 
 class DummyMappingDataset(Dataset):
-    """Mapping-style dataset that generates dummy data based on index."""
+    """Mapping-style dataset that generates deterministic dummy samples by index.
+
+    * Sample at 0-based index ``i`` contains **i + 1** tokens, each with value
+      ``i + 1``. For example index 0 → ``[1]``, index 4 → ``[5, 5, 5, 5, 5]``.
+    """
 
     def __init__(self, size: int = 100):
         """
         Args:
-            size: Total number of samples in the dataset
+            size: Total number of samples in the dataset.
         """
         self.size = size
 
     def __len__(self):
         return self.size
 
     def __getitem__(self, idx):
-        """Generate data following the same pattern as DummyDataset.generate_data"""
+        """Return the dummy sample at position *idx*.
+
+        Args:
+            idx: 0-based integer index into the dataset.
+
+        Returns:
+            dict with keys:
+
+            * ``"input_ids"`` – 1-D ``LongTensor`` of length ``idx + 1``, filled
+              with the scalar value ``idx + 1``.
+            * ``"attention_mask"`` – all-ones tensor of the same shape.
+            * ``"labels"`` – clone of ``input_ids``.
+
+        Raises:
+            IndexError: If ``idx`` is outside ``[0, size)``.
+        """
         if idx < 0 or idx >= self.size:
             raise IndexError(f"Index {idx} out of range [0, {self.size})")
 
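The body that actually builds the sample sits below the bounds check and is not part of this diff; from the new docstring alone it would be equivalent to roughly the following (a sketch consistent with the description, not necessarily the exact implementation):

import torch

# Sketch matching the documented pattern: idx 0 -> [1], idx 4 -> [5, 5, 5, 5, 5].
def _dummy_sample(idx: int) -> dict:
    input_ids = torch.full((idx + 1,), idx + 1, dtype=torch.long)
    return {
        "input_ids": input_ids,
        "attention_mask": torch.ones_like(input_ids),
        "labels": input_ids.clone(),
    }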
@@ -41,14 +60,34 @@ def __getitem__(self, idx):
 
 
 class DummyIterableDataset(IterableDataset):
-    """Iterable dataset that reads from DummyMappingDataset sequentially or with shuffle."""
+    """Iterable wrapper around ``DummyMappingDataset`` with built-in sharding and optional shuffle.
+
+    Designed to be tested with ``DynamicBatchingSizeDataset`` and ``StatefulDataLoader`` checkpointing:
+
+    * **Sharding** – samples are distributed across distributed ranks *and* DataLoader
+      workers using a round-robin interleave strategy (rank-major, then worker-minor),
+      so each dataloader worker on each rank sees a disjoint, deterministic subset of the data.
+    * **Shuffle** – when ``shuffle=True``, a fixed ``torch.randperm`` generated from
+      ``seed`` at construction time is used so that the shuffled order is reproducible
+      and consistent across checkpoint / resume cycles.
+    * **Index output** – when ``output_refetch_idx`` is set to ``True`` (by
+      ``DynamicBatchingSizeDataset`` when ``save_by_idx=True``), each ``__iter__``
+      yield is a ``(sample_dict, original_index)`` tuple instead of a bare dict,
+      allowing the consumer to store the indices instead of the full samples when
+      saving checkpoints, and to reconstruct the buffer from indices on resume.
+    * **State dict** – ``state_dict()`` / ``load_state_dict()`` persist
+      ``_current_idx`` so that ``StatefulDataLoader`` can snapshot and restore the
+      exact position of the iterator.
+    """
 
     def __init__(self, mapping_dataset: DummyMappingDataset, shuffle: bool = False, seed: int = 42):
         """
         Args:
-            mapping_dataset: The underlying DummyMappingDataset to read from
-            shuffle: Whether to shuffle the reading order
-            seed: Random seed for shuffling
+            mapping_dataset: The upstream ``DummyMappingDataset`` to read from.
+            shuffle: Whether to shuffle the reading order. Shuffling is performed
+                once at construction time using ``seed`` so that it is stable across
+                distributed workers.
+            seed: Random seed used to generate the permutation when ``shuffle=True``.
         """
         self.mapping_dataset = mapping_dataset
         self.shuffle = shuffle
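
The __iter__ body implementing the sharding described above is outside this hunk. Under the docstring's description it behaves roughly like the sketch below; the exact interleave arithmetic is an assumption, not copied from the source:

import torch
from torch.utils.data import get_worker_info

# Rough shape of the documented rank-major, worker-minor round-robin sharding.
def _shard_indices(dataset_len, rank, world_size, shuffle, seed):
    if shuffle:
        generator = torch.Generator().manual_seed(seed)
        order = torch.randperm(dataset_len, generator=generator).tolist()
    else:
        order = list(range(dataset_len))
    info = get_worker_info()
    num_workers = info.num_workers if info else 1
    worker_id = info.id if info else 0
    stride = world_size * num_workers
    offset = rank * num_workers + worker_id  # rank-major, then worker-minor
    for pos in range(offset, dataset_len, stride):
        yield order[pos]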
@@ -111,6 +150,18 @@ def __iter__(self):
             yield self.mapping_dataset[idx]
 
     def get_item(self, idx):
+        """Fetch a single sample by its original dataset index.
+
+        Used by ``DynamicBatchingSizeDataset.load_state_dict()`` to reconstruct
+        buffer contents when ``save_by_idx=True``: the saved indices are passed
+        back here one-by-one to rebuild the exact pre-checkpoint buffer.
+
+        Args:
+            idx: 0-based integer index into the underlying ``DummyMappingDataset``.
+
+        Returns:
+            Sample as returned by ``DummyMappingDataset.__getitem__``.
+        """
         return self.mapping_dataset[idx]
 
     def state_dict(self):
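
The state_dict body itself is not included in this hunk. Given the class docstring above, it presumably reduces to persisting the iterator cursor, roughly as follows (an assumption, not the actual code):

    # Presumed shape of the two methods (assumption; actual bodies are not in this diff).
    def state_dict(self):
        return {"_current_idx": self._current_idx}

    def load_state_dict(self, state):
        self._current_idx = state["_current_idx"]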

veomni/arguments/arguments_types.py

Lines changed: 2 additions & 20 deletions
@@ -420,22 +420,10 @@ class TrainingArguments:
         default="worker",
         metadata={"help": "Use main process or worker process to run dynamic batch size."},
     )
-    dyn_bsz_in_worker_loop: bool = field(
-        default=True,
-        metadata={
-            "help": "Whether the dynamic batch construction is in DataLoader's worker loop or in Dataset's iterator."
-        },
-    )
     dyn_bsz_buffer_size: int = field(
         default=200,
         metadata={"help": "Buffer size for dynamic batch size."},
     )
-    dyn_bsz_dataset_save_by_idx: bool = field(
-        default=True,
-        metadata={
-            "help": "When dyn_bsz_in_worker_loop is False, it is to decide whether to save buffer by index for checkpointing in DynamicBatchingSizeDataset."
-        },
-    )
     bsz_warmup_ratio: float = field(
         default=0,
         metadata={"help": "Ratio of batch size warmup steps."},
@@ -740,15 +728,9 @@ def __post_init__(self):
 
         # calculate dataloader batch size
         # for:
-        #   - DynamicBatchingSizeDataset and StatefulDataLoader
         #   - StreamingDataset and StreamingDataLoader
-        if (self.rmpad or self.rmpad_with_pos_ids) and self.dyn_bsz:
-            if self.dyn_bsz_in_worker_loop:
-                self.dataloader_batch_size = 1
-            else:
-                self.dataloader_batch_size = self.global_batch_size // (
-                    self.micro_batch_size * self.data_parallel_size
-                )
+        if (self.rmpad or self.rmpad_with_pos_ids) and self.dyn_bsz_runtime == "worker" and self.dyn_bsz:
+            self.dataloader_batch_size = 1
         else:
             self.dataloader_batch_size = self.global_batch_size // self.data_parallel_size  # = micro bsz * grad accu
 
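
For context, a worked example of the two branches using the values from the test's build_command; the data_parallel_size of 2 is assumed from the 2-process DDP launch, so treat the numbers as illustrative:

# Illustrative arithmetic only; mirrors the branch logic above.
global_batch_size, micro_batch_size, data_parallel_size = 16, 2, 2

# dyn_bsz branch (rmpad_with_pos_ids=True, dyn_bsz=True, dyn_bsz_runtime == "worker"):
dataloader_batch_size = 1  # the dataset already emits packed micro-batches

# otherwise:
dataloader_batch_size = global_batch_size // data_parallel_size  # 16 // 2 = 8 = micro bsz (2) * grad accu (4)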

veomni/data/batching_strategy.py

Lines changed: 18 additions & 18 deletions
@@ -92,24 +92,6 @@ def merge(self, buffer_to_merge: "DynBszBuffer"):
             self.append(item)
 
 
-class IdentityPacker:
-    def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
-        self.token_micro_bsz = token_micro_bsz
-        self.bsz_warmup_steps = bsz_warmup_steps
-        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
-
-    def __call__(self, samples):
-        return samples
-
-    def get_token_num_to_request(self, cur_step, warmup):
-        return (
-            (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
-            + self.bsz_warmup_init_mbtoken
-            if warmup
-            else self.token_micro_bsz
-        )
-
-
 class BaseBatchingStrategy:
     """
     Base class for batching strategy.
@@ -128,6 +110,24 @@ def empty(self) -> bool:
         raise NotImplementedError("should implement `empty`")
 
 
+class IdentityPacker:
+    def __init__(self, token_micro_bsz, bsz_warmup_steps, bsz_warmup_init_mbtoken):
+        self.token_micro_bsz = token_micro_bsz
+        self.bsz_warmup_steps = bsz_warmup_steps
+        self.bsz_warmup_init_mbtoken = bsz_warmup_init_mbtoken
+
+    def __call__(self, samples):
+        return samples
+
+    def get_token_num_to_request(self, cur_step, warmup):
+        return (
+            (self.token_micro_bsz - self.bsz_warmup_init_mbtoken) * cur_step // self.bsz_warmup_steps
+            + self.bsz_warmup_init_mbtoken
+            if warmup
+            else self.token_micro_bsz
+        )
+
+
 class TextBatchingStrategy(BaseBatchingStrategy):
     """ "
     Batching strategy for text data.
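
The move itself does not change IdentityPacker; its get_token_num_to_request linearly ramps the per-micro-batch token budget during warmup. A quick walk-through with illustrative values (not taken from the source):

token_micro_bsz = 8192          # full token budget per micro-batch
bsz_warmup_init_mbtoken = 1024  # budget at step 0
bsz_warmup_steps = 100

def token_budget(cur_step, warmup):
    if not warmup:
        return token_micro_bsz
    return (token_micro_bsz - bsz_warmup_init_mbtoken) * cur_step // bsz_warmup_steps + bsz_warmup_init_mbtoken

assert token_budget(0, warmup=True) == 1024    # start of warmup
assert token_budget(50, warmup=True) == 4608   # halfway: 7168 * 50 // 100 + 1024
assert token_budget(100, warmup=True) == 8192  # end of the ramp
assert token_budget(5, warmup=False) == 8192   # warmup disabled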
