Skip to content

Commit c912f27

Browse files
committed
make naming more descriptive, update default in test util
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
1 parent 1dea1e9 commit c912f27

File tree

12 files changed

+33
-29
lines changed

12 files changed

+33
-29
lines changed

src/megatron/bridge/training/checkpointing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def save_checkpoint(
652652
save_strategy = TorchDistSaveShardedStrategy(
653653
"torch_dist",
654654
1,
655-
thread_count=ckpt_cfg.thread_count,
655+
thread_count=ckpt_cfg.storage_writers_per_rank,
656656
)
657657
else:
658658
save_strategy = get_default_save_sharded_strategy(ckpt_cfg.ckpt_format)

src/megatron/bridge/training/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -852,9 +852,9 @@ class CheckpointConfig:
852852
use_checkpoint_args: bool = False
853853
"""Override any command line arguments with arguments from the checkpoint"""
854854

855-
thread_count: int = 1
856-
"""Number of threads to use during saving (torch_dist format only).
857-
Affects the number of checkpoint files: saving_ranks * thread_count."""
855+
storage_writers_per_rank: int = 1
856+
"""Number of storage writers per rank for torch_dist checkpoint format.
857+
Affects the number of checkpoint files: saving_ranks * storage_writers_per_rank."""
858858

859859
exit_on_missing_checkpoint: bool = False
860860
"""If 'load' is set, but checkpoint is not found (e.g., path typo), then exit instead of random initialization."""

tests/functional_tests/recipes/test_gpt_oss_recipes_finetune.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def test_gpt_oss_finetune_recipes(
276276
config.checkpoint.save,
277277
5,
278278
ckpt_format=config.checkpoint.ckpt_format,
279-
thread_count=config.checkpoint.thread_count,
279+
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
280280
)
281281

282282
finally:

tests/functional_tests/recipes/test_llama_recipes_distill_3b-1b.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def run_distill_recipe_test(
166166
config.checkpoint.save,
167167
10,
168168
ckpt_format=config.checkpoint.ckpt_format,
169-
thread_count=config.checkpoint.thread_count,
169+
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
170170
)
171171

172172
finally:

tests/functional_tests/recipes/test_nemotronh_recipes_finetune.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def test_nemotron_nano_v2_finetune_recipes(
304304
config.checkpoint.save,
305305
config.train.train_iters,
306306
ckpt_format=config.checkpoint.ckpt_format,
307-
thread_count=config.checkpoint.thread_count,
307+
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
308308
)
309309

310310
finally:
@@ -582,7 +582,7 @@ def test_nemotron_3_nano_finetune_recipes(
582582
config.checkpoint.save,
583583
config.train.train_iters,
584584
ckpt_format=config.checkpoint.ckpt_format,
585-
thread_count=config.checkpoint.thread_count,
585+
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
586586
)
587587

588588
finally:

tests/functional_tests/recipes/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def run_pretrain_recipe_test(
123123
config.checkpoint.save,
124124
10,
125125
ckpt_format=config.checkpoint.ckpt_format,
126-
thread_count=config.checkpoint.thread_count,
126+
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
127127
)
128128

129129
finally:
@@ -291,7 +291,7 @@ def run_pretrain_vl_recipe_test(
291291
config.checkpoint.save,
292292
config.train.train_iters,
293293
ckpt_format=config.checkpoint.ckpt_format,
294-
thread_count=config.checkpoint.thread_count,
294+
storage_writers_per_rank=config.checkpoint.storage_writers_per_rank,
295295
)
296296

297297
finally:

tests/functional_tests/training/test_finetune_lora.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_pretrain_then_lora_finetune(self, tmp_path):
9191
pretrain_checkpoint_dir,
9292
pretrain_iters,
9393
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
94-
thread_count=pretrain_cfg.checkpoint.thread_count,
94+
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
9595
)
9696

9797
# Create LoRA config and run finetuning
@@ -103,7 +103,7 @@ def test_pretrain_then_lora_finetune(self, tmp_path):
103103
lora_checkpoint_dir,
104104
lora_iters,
105105
ckpt_format=lora_cfg.checkpoint.ckpt_format,
106-
thread_count=lora_cfg.checkpoint.thread_count,
106+
storage_writers_per_rank=lora_cfg.checkpoint.storage_writers_per_rank,
107107
)
108108
verify_peft_checkpoint_smaller(pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, lora_iters)
109109

@@ -143,7 +143,7 @@ def test_lora_save_and_resume(self, tmp_path):
143143
pretrain_checkpoint_dir,
144144
pretrain_iters,
145145
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
146-
thread_count=pretrain_cfg.checkpoint.thread_count,
146+
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
147147
)
148148

149149
# Second run: LoRA finetuning initial phase (will be "interrupted")
@@ -165,7 +165,7 @@ def test_lora_save_and_resume(self, tmp_path):
165165
lora_checkpoint_dir,
166166
initial_lora_iters,
167167
ckpt_format=lora_initial_cfg.checkpoint.ckpt_format,
168-
thread_count=lora_initial_cfg.checkpoint.thread_count,
168+
storage_writers_per_rank=lora_initial_cfg.checkpoint.storage_writers_per_rank,
169169
)
170170

171171
# Third run: Resume LoRA finetuning from checkpoint (adapter-only states)
@@ -189,7 +189,7 @@ def test_lora_save_and_resume(self, tmp_path):
189189
lora_checkpoint_dir,
190190
total_lora_iters,
191191
ckpt_format=lora_resume_cfg.checkpoint.ckpt_format,
192-
thread_count=lora_resume_cfg.checkpoint.thread_count,
192+
storage_writers_per_rank=lora_resume_cfg.checkpoint.storage_writers_per_rank,
193193
)
194194
verify_peft_checkpoint_smaller(
195195
pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, initial_lora_iters
@@ -227,7 +227,7 @@ def test_lora_finetune_with_packed_sequences(self, tmp_path):
227227
pretrain_checkpoint_dir,
228228
pretrain_iters,
229229
ckpt_format=pretrain_cfg.checkpoint.ckpt_format,
230-
thread_count=pretrain_cfg.checkpoint.thread_count,
230+
storage_writers_per_rank=pretrain_cfg.checkpoint.storage_writers_per_rank,
231231
)
232232

233233
# Create LoRA config with packed sequences and run finetuning
@@ -248,7 +248,7 @@ def test_lora_finetune_with_packed_sequences(self, tmp_path):
248248
lora_checkpoint_dir,
249249
lora_iters,
250250
ckpt_format=lora_cfg.checkpoint.ckpt_format,
251-
thread_count=lora_cfg.checkpoint.thread_count,
251+
storage_writers_per_rank=lora_cfg.checkpoint.storage_writers_per_rank,
252252
)
253253
verify_peft_checkpoint_smaller(pretrain_checkpoint_dir, lora_checkpoint_dir, pretrain_iters, lora_iters)
254254

tests/functional_tests/training/test_megatron_fsdp.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_fsdp_pretrain_with_checkpoint(self, tmp_path):
315315
checkpoint_dir,
316316
total_iters,
317317
ckpt_format=cfg.checkpoint.ckpt_format,
318-
thread_count=cfg.checkpoint.thread_count,
318+
storage_writers_per_rank=cfg.checkpoint.storage_writers_per_rank,
319319
)
320320

321321
finally:
@@ -364,7 +364,7 @@ def test_fsdp_pretrain_save_resume(self, tmp_path):
364364
checkpoint_dir,
365365
checkpoint_iters,
366366
ckpt_format=cfg_first.checkpoint.ckpt_format,
367-
thread_count=cfg_first.checkpoint.thread_count,
367+
storage_writers_per_rank=cfg_first.checkpoint.storage_writers_per_rank,
368368
)
369369

370370
torch.distributed.barrier()
@@ -390,7 +390,7 @@ def test_fsdp_pretrain_save_resume(self, tmp_path):
390390
checkpoint_dir,
391391
total_iters,
392392
ckpt_format=cfg_second.checkpoint.ckpt_format,
393-
thread_count=cfg_second.checkpoint.thread_count,
393+
storage_writers_per_rank=cfg_second.checkpoint.storage_writers_per_rank,
394394
)
395395

396396
finally:

tests/functional_tests/training/test_pretrain_resume.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_pretrain_save_load(self, tmp_path):
167167
checkpoint_dir,
168168
checkpoint_iters,
169169
ckpt_format=cfg_first.checkpoint.ckpt_format,
170-
thread_count=cfg_first.checkpoint.thread_count,
170+
storage_writers_per_rank=cfg_first.checkpoint.storage_writers_per_rank,
171171
)
172172

173173
torch.distributed.barrier()
@@ -257,7 +257,7 @@ def test_pretrain_save_load(self, tmp_path):
257257
checkpoint_dir,
258258
total_iters,
259259
ckpt_format=cfg_second.checkpoint.ckpt_format,
260-
thread_count=cfg_second.checkpoint.thread_count,
260+
storage_writers_per_rank=cfg_second.checkpoint.storage_writers_per_rank,
261261
)
262262

263263
finally:

tests/functional_tests/training/test_seqpacking_cp_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_sft_example_runs_with_cp_and_packing(self, tmp_path):
103103
checkpoint_dir,
104104
cfg.train.train_iters,
105105
ckpt_format=cfg.checkpoint.ckpt_format,
106-
thread_count=cfg.checkpoint.thread_count,
106+
storage_writers_per_rank=cfg.checkpoint.storage_writers_per_rank,
107107
)
108108
finally:
109109
clear_directories(shared_dir)

0 commit comments

Comments (0)