File tree Expand file tree Collapse file tree 5 files changed +5
-12
lines changed
Expand file tree Collapse file tree 5 files changed +5
-12
lines changed Original file line number Diff line number Diff line change 1+ # Foundation Model Stack Community Code of Conduct
2+
3+ Please refer to [ Foundation Model Stack Community Code of Conduct] ( https://github.com/foundation-model-stack/foundation-model-stack/blob/main/code-of-conduct.md ) .
Original file line number Diff line number Diff line change @@ -166,7 +166,7 @@ def train(
166166 ddp_stats .zero_ ()
167167 torch .cuda .reset_peak_memory_stats (device = torch .cuda .current_device ())
168168
169- if batch_idx % cfg .checkpoint_interval == 0 :
169+ if batch_idx % cfg .checkpoint_interval == 0 or batch_idx == cfg . num_steps :
170170 checkpointer .save (
171171 batch_idx ,
172172 model ,
Original file line number Diff line number Diff line change @@ -169,8 +169,6 @@ def main(**kwargs):
169169 tokens_seen ,
170170 )
171171
172- checkpointer .save_single_file (cfg .num_steps , model )
173-
174172 dist .barrier ()
175173 dist .destroy_process_group ()
176174
Original file line number Diff line number Diff line change @@ -169,8 +169,6 @@ def main(**kwargs):
169169 tokens_seen ,
170170 )
171171
172- checkpointer .save_single_file (cfg .num_steps , model )
173-
174172 dist .barrier ()
175173 dist .destroy_process_group ()
176174
Original file line number Diff line number Diff line change @@ -412,6 +412,7 @@ def train_speculator(
412412
413413 if (
414414 batch_idx % cfg .checkpoint_interval == 0
415+ or batch_idx == cfg .num_steps
415416 or do_ckpt (cfg .ckpt_save_path ) is True
416417 ):
417418 torch .cuda .empty_cache ()
@@ -425,13 +426,6 @@ def train_speculator(
425426 torch .cuda .empty_cache ()
426427 do_ckpt (cfg .ckpt_save_path , reset = True )
427428
428- checkpointer .save_single_file (
429- batch_idx ,
430- speculator ,
431- tokens_seen = elapsed_tokens + n_tok ,
432- is_compiled = cfg .use_torch_compile ,
433- )
434-
435429
436430class EmbedGPTBigCode (GPTBigCode ):
437431 # Overrides the forward function of GPTBigCode to allow returning embedding vectors
You can’t perform that action at this time.
0 commit comments