16 | 16 | - test_dynamic_batching_dataset_no_shuffle |
17 | 17 | """ |
18 | 18 |
| 19 | +import argparse |
19 | 20 | import os |
20 | 21 | import subprocess |
21 | 22 | import sys |
@@ -366,8 +367,8 @@ def build_command(shuffle=True, save_by_idx=True): |
366 | 367 | "--train.rmpad=false", |
367 | 368 | "--train.rmpad_with_pos_ids=true", |
368 | 369 | "--train.dyn_bsz=true", |
369 | | - "--train.dyn_bsz_in_worker_loop=false", |
370 | | - f"--train.dyn_bsz_dataset_save_by_idx={str(save_by_idx).lower()}", |
| 370 | + "--dyn_bsz_in_worker_loop=false", |
| 371 | + f"--save_by_idx={str(save_by_idx).lower()}", |
371 | 372 | "--train.seed=42", |
372 | 373 | ] |
373 | 374 | return command |
@@ -403,6 +404,12 @@ def main_distributed_test(): |
403 | 404 |
404 | 405 | def _run_distributed_test(): |
405 | 406 | """Internal function that runs the actual distributed test.""" |
| 407 | + _parser = argparse.ArgumentParser() |
| 408 | + _parser.add_argument("--save_by_idx", type=lambda x: x.lower() == "true", default=True) |
| 409 | + _parser.add_argument("--dyn_bsz_in_worker_loop", type=lambda x: x.lower() == "true", default=True) |
| 410 | + test_args, remaining_argv = _parser.parse_known_args() |
| 411 | + sys.argv = [sys.argv[0]] + remaining_argv |
| 412 | + |
406 | 413 | args = parse_args(Arguments) |
407 | 414 | world_size = int(os.environ["WORLD_SIZE"]) |
408 | 415 | rank = int(os.environ["RANK"]) |
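The added pre-pass uses argparse's parse_known_args to peel the two test-only flags off the command line and hands the remaining argv back to the framework-level parse_args, so the shared Arguments schema no longer needs dyn_bsz_in_worker_loop / dyn_bsz_dataset_save_by_idx fields. A minimal, standard-library-only sketch of that splitting pattern (the names _str2bool, split_test_args, and the demo argv are illustrative, not part of this diff):

import argparse

def _str2bool(x):
    # argparse passes the raw string after "="; treat "true" (any case) as True.
    return x.lower() == "true"

def split_test_args(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--save_by_idx", type=_str2bool, default=True)
    parser.add_argument("--dyn_bsz_in_worker_loop", type=_str2bool, default=True)
    # parse_known_args returns (parsed namespace, list of unrecognized args).
    return parser.parse_known_args(argv)

test_args, remaining = split_test_args(
    ["--save_by_idx=false", "--train.seed=42", "--dyn_bsz_in_worker_loop=true"]
)
print(test_args.save_by_idx, test_args.dyn_bsz_in_worker_loop)  # False True
print(remaining)  # ['--train.seed=42']
# In the test itself, sys.argv is then rewritten to [sys.argv[0]] + remaining
# before the framework parser runs.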
@@ -452,11 +459,11 @@ def _run_distributed_test(): |
452 | 459 | train_steps=train_steps, |
453 | 460 | rmpad=args.train.rmpad, |
454 | 461 | dyn_bsz=args.train.dyn_bsz, |
455 | | - dyn_bsz_in_worker_loop=args.train.dyn_bsz_in_worker_loop, |
| 462 | + dyn_bsz_in_worker_loop=test_args.dyn_bsz_in_worker_loop, |
456 | 463 | bsz_warmup_ratio=args.train.bsz_warmup_ratio, |
457 | 464 | rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, |
458 | 465 | dyn_bsz_buffer_size=READY_FOR_MICRO_BATCH_THRESHOLD, |
459 | | - dyn_bsz_dataset_save_by_idx=args.train.dyn_bsz_dataset_save_by_idx, |
| 466 | + dyn_bsz_dataset_save_by_idx=test_args.save_by_idx, |
460 | 467 | num_workers=2, |
461 | 468 | drop_last=False, |
462 | 469 | pin_memory=args.data.pin_memory, |
@@ -504,12 +511,12 @@ def _run_distributed_test(): |
504 | 511 |
505 | 512 | # Print batch info for debugging |
506 | 513 | """ |
507 | | - logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)}") |
| 514 | + logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} num_micro_batches:{len(micro_batches)} dataset_iter: {dataloader.dataset._data_iter}") |
508 | 515 | for micro_idx, micro_batch in enumerate(micro_batches): |
509 | 516 | # Extract sample indices from input_ids (each sample has all same values) |
510 | 517 | input_ids = micro_batch["input_ids"].squeeze(0) # Remove batch dim |
511 | 518 | input_ids = set(input_ids.tolist()) |
512 | | - logger.info(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}") |
| 519 | + logger.error(f"[rank{rank}] epoch:{epoch} step:{local_step} global_step:{global_step} micro_batch[{micro_idx}]: {input_ids}") |
513 | 520 | """ |
514 | 521 |
514 | 521 |
515 | 522 | if epoch > save_epoch or (epoch == save_epoch and local_step > save_step): |