Skip to content

Commit 7cb0d4b

Browse files
committed
resolve comments
1 parent 340197c commit 7cb0d4b

File tree

3 files changed

+16
-25
lines changed

3 files changed

+16
-25
lines changed

tests/data/test_dynamic_batching_dataset.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
- test_dynamic_batching_dataset_no_shuffle
1717
"""
1818

19+
import argparse
1920
import os
2021
import subprocess
2122
import sys
@@ -366,8 +367,8 @@ def build_command(shuffle=True, save_by_idx=True):
366367
"--train.rmpad=false",
367368
"--train.rmpad_with_pos_ids=true",
368369
"--train.dyn_bsz=true",
369-
"--train.dyn_bsz_in_worker_loop=false",
370-
f"--train.dyn_bsz_dataset_save_by_idx={str(save_by_idx).lower()}",
370+
"--dyn_bsz_in_worker_loop=false",
371+
f"--save_by_idx={str(save_by_idx).lower()}",
371372
"--train.seed=42",
372373
]
373374
return command
@@ -403,6 +404,12 @@ def main_distributed_test():
403404

404405
def _run_distributed_test():
405406
"""Internal function that runs the actual distributed test."""
407+
_parser = argparse.ArgumentParser()
408+
_parser.add_argument("--save_by_idx", type=lambda x: x.lower() == "true", default=True)
409+
_parser.add_argument("--dyn_bsz_in_worker_loop", type=lambda x: x.lower() == "true", default=True)
410+
test_args, remaining_argv = _parser.parse_known_args()
411+
sys.argv = [sys.argv[0]] + remaining_argv
412+
406413
args = parse_args(Arguments)
407414
world_size = int(os.environ["WORLD_SIZE"])
408415
rank = int(os.environ["RANK"])
@@ -452,11 +459,11 @@ def _run_distributed_test():
452459
train_steps=train_steps,
453460
rmpad=args.train.rmpad,
454461
dyn_bsz=args.train.dyn_bsz,
455-
dyn_bsz_in_worker_loop=args.train.dyn_bsz_in_worker_loop,
462+
dyn_bsz_in_worker_loop=test_args.dyn_bsz_in_worker_loop,
456463
bsz_warmup_ratio=args.train.bsz_warmup_ratio,
457464
rmpad_with_pos_ids=args.train.rmpad_with_pos_ids,
458465
dyn_bsz_buffer_size=READY_FOR_MICRO_BATCH_THRESHOLD,
459-
dyn_bsz_dataset_save_by_idx=args.train.dyn_bsz_dataset_save_by_idx,
466+
dyn_bsz_dataset_save_by_idx=test_args.save_by_idx,
460467
num_workers=2,
461468
drop_last=False,
462469
pin_memory=args.data.pin_memory,
@@ -504,12 +511,12 @@ def _run_distributed_test():
504511

505512
# Print batch info for debugging
506513
"""
507-
logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)}")
514+
logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)} dataset_iter: {dataloader.dataset._data_iter}")
508515
for micro_idx, micro_batch in enumerate(micro_batches):
509516
# Extract sample indices from input_ids (each sample has all same values)
510517
input_ids = micro_batch["input_ids"].squeeze(0) # Remove batch dim
511518
input_ids = set(input_ids.tolist())
512-
logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}")
519+
logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}")
513520
"""
514521

515522
if epoch > save_epoch or (epoch == save_epoch and local_step > save_step):

veomni/arguments/arguments_types.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -396,22 +396,10 @@ class TrainingArguments:
396396
default="worker",
397397
metadata={"help": "Use main process or worker process to run dynamic batch size."},
398398
)
399-
dyn_bsz_in_worker_loop: bool = field(
400-
default=True,
401-
metadata={
402-
"help": "Whether the dynamic batch construction is in DataLoader's worker loop or in Dataset's iterator."
403-
},
404-
)
405399
dyn_bsz_buffer_size: int = field(
406400
default=200,
407401
metadata={"help": "Buffer size for dynamic batch size."},
408402
)
409-
dyn_bsz_dataset_save_by_idx: bool = field(
410-
default=True,
411-
metadata={
412-
"help": "When dyn_bsz_in_worker_loop is False, it is to decide whether to save buffer by index for checkpointing in DynamicBatchingSizeDataset."
413-
},
414-
)
415403
bsz_warmup_ratio: float = field(
416404
default=0,
417405
metadata={"help": "Ratio of batch size warmup steps."},
@@ -718,13 +706,8 @@ def __post_init__(self):
718706
# for:
719707
# - DynamicBatchingSizeDataset and StatefulDataLoader
720708
# - StreamingDataset and StreamingDataLoader
721-
if (self.rmpad or self.rmpad_with_pos_ids) and self.dyn_bsz:
722-
if self.dyn_bsz_in_worker_loop:
723-
self.dataloader_batch_size = 1
724-
else:
725-
self.dataloader_batch_size = self.global_batch_size // (
726-
self.micro_batch_size * self.data_parallel_size
727-
)
709+
if (self.rmpad or self.rmpad_with_pos_ids) and self.dyn_bsz and self.dyn_bsz_runtime == "worker":
710+
self.dataloader_batch_size = 1
728711
else:
729712
self.dataloader_batch_size = self.global_batch_size // self.data_parallel_size # = micro bsz * grad accu
730713

veomni/data/data_loader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def build_native_dataloader(
135135
)
136136
collate_fn = UnpackDataCollator()
137137
else:
138+
dataloader_batch_size = num_micro_batch
138139
dataset = DynamicBatchingSizeDataset(
139140
dataset=dataset,
140141
micro_batch_seq_length=token_micro_bsz,

0 commit comments

Comments
 (0)