Skip to content

Commit 3032035

Browse files
authored
Merge branch 'main' into data-fixes
2 parents a1fe22a + 503da7e commit 3032035

File tree

5 files changed

+5
-12
lines changed

5 files changed

+5
-12
lines changed

code-of-conduct.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1 + # Foundation Model Stack Community Code of Conduct
2 +
3 + Please refer to [Foundation Model Stack Community Code of Conduct](https://github.com/foundation-model-stack/foundation-model-stack/blob/main/code-of-conduct.md).

fms_fsdp/utils/train_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def train(
166166
ddp_stats.zero_()
167167
torch.cuda.reset_peak_memory_stats(device=torch.cuda.current_device())
168168

169-
if batch_idx % cfg.checkpoint_interval == 0:
169+
if batch_idx % cfg.checkpoint_interval == 0 or batch_idx == cfg.num_steps:
170170
checkpointer.save(
171171
batch_idx,
172172
model,

main_training_llama.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,6 @@ def main(**kwargs):
169169
tokens_seen,
170170
)
171171

172-
checkpointer.save_single_file(cfg.num_steps, model)
173-
174172
dist.barrier()
175173
dist.destroy_process_group()
176174

main_training_mamba.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,6 @@ def main(**kwargs):
169169
tokens_seen,
170170
)
171171

172-
checkpointer.save_single_file(cfg.num_steps, model)
173-
174172
dist.barrier()
175173
dist.destroy_process_group()
176174

speculator/train_speculator_utils.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ def train_speculator(
412412

413413
if (
414414
batch_idx % cfg.checkpoint_interval == 0
415+
or batch_idx == cfg.num_steps
415416
or do_ckpt(cfg.ckpt_save_path) is True
416417
):
417418
torch.cuda.empty_cache()
@@ -425,13 +426,6 @@ def train_speculator(
425426
torch.cuda.empty_cache()
426427
do_ckpt(cfg.ckpt_save_path, reset=True)
427428

428-
checkpointer.save_single_file(
429-
batch_idx,
430-
speculator,
431-
tokens_seen=elapsed_tokens + n_tok,
432-
is_compiled=cfg.use_torch_compile,
433-
)
434-
435429

436430
class EmbedGPTBigCode(GPTBigCode):
437431
# Overrides the forward function of GPTBigCode to allow returning embedding vectors

0 commit comments

Comments
 (0)