Skip to content

Commit 24b2980

Browse files
authored
Allow mixed-precision training in DDP/single-device scenarios (#36)
* Support a more recent torchtitan version, up to commit 0b44d4c, and allow AMP training in non-FSDP mode (as in pytorch/torchtitan@0b44d4c#diff-54e6f3c870acaf438db326aba3c3462b1848b4600cc37204de946da020805dd3)
* Fix --checkpoint.initial_load_model_weights_only
1 parent 6e49e20 commit 24b2980

File tree

3 files changed

+67
-12
lines changed

3 files changed

+67
-12
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https
3030

3131
[Important] Install specific version of torchtitan
3232
```
33-
pip install git+https://github.com/pytorch/torchtitan.git@5e2033c
33+
pip install git+https://github.com/pytorch/torchtitan.git@0b44d4c
3434
```
3535

3636

flame/config_manager.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,18 @@ def __init__(self):
183183
self.parser.add_argument(
184184
"--optimizer.lr", type=float, default=8e-4, help="Learning rate to use"
185185
)
186+
self.parser.add_argument(
187+
"--optimizer.beta1", type=float, default=0.9,
188+
help="Exponential moving average hyperparameters to use"
189+
)
190+
self.parser.add_argument(
191+
"--optimizer.beta2", type=float, default=0.95,
192+
help="Exponential moving average hyperparameters to use"
193+
)
194+
self.parser.add_argument(
195+
"--optimizer.weight_decay", type=float, default=0.1,
196+
help="Weight decay to use"
197+
)
186198
self.parser.add_argument(
187199
"--optimizer.implementation",
188200
type=str,
@@ -407,8 +419,10 @@ def __init__(self):
407419
default="bfloat16",
408420
choices=["bfloat16", "float32"],
409421
help="""
410-
torch dtype to use for parameters when applying mixed precision via FSDP.
411-
This feature only takes effect when data_parallel_shard_degree > 1
422+
torch dtype to use for parameters when applying mixed precision via fully_shard or torch.autocast.
423+
This feature takes effect via fully_shard when data_parallel_shard_degree > 1 or
424+
context_parallel_degree > 1; it takes effect via torch.autocast when data_replicate_degree >= 1
425+
and no other parallelism is enabled, i.e. under DDP or single-device training.
412426
""",
413427
)
414428
self.parser.add_argument(
@@ -606,19 +620,54 @@ def __init__(self):
606620
When enable_checkpoint is set to true, checkpoints will be in {--job.dump_folder}/{--checkpoint.folder}.
607621
""",
608622
)
623+
self.parser.add_argument(
624+
"--checkpoint.initial_load_path", type=str, default=None,
625+
help="""
626+
This option specifies the path to the initial checkpoint to load, which is
627+
particularly useful for resuming training from a previous run with a
628+
different output path or when loading a checkpoint from a pre-trained model.
629+
If the checkpoint folder for the current run is not empty,
630+
located at {--job.dump_folder}/{--checkpoint.folder}, this option will be ignored.
631+
This feature allows users to load an initial checkpoint from a different folder and
632+
continue training, saving new checkpoints to the specified folder without affecting
633+
the existing ones.
634+
635+
Note that the path should contain the full path to the checkpoint folder,
636+
including the step number, if any; for example,
637+
"//pre_train/checkpoints/llama3/llama3_8b/step_10000".
638+
"""
639+
)
640+
self.parser.add_argument(
641+
"--checkpoint.initial_load_model_weights_only",
642+
dest='checkpoint.initial_load_model_weights_only', action="store_true", default=True,
643+
help="""
644+
This option specifies if only the model weights should be loaded during the initial
645+
checkpoint load. The option is only used when `initial_load_path` is specified, and
646+
only applies to a model_weights_only checkpoint. Loading a periodic checkpoint
647+
may lead to unexpected behavior if this option is set to True.
648+
If False, the checkpoint at `initial_load_path` is treated as a standard training
649+
checkpoint, including optimizer and training states.
650+
The default setting for this option is True. Note that you will have to use
651+
`--checkpoint.no_initial_load_model_weights_only` to override the default setting.
652+
"""
653+
)
654+
self.parser.add_argument(
655+
"--checkpoint.no_initial_load_model_weights_only",
656+
dest='checkpoint.initial_load_model_weights_only', action="store_false",
657+
)
609658
self.parser.add_argument(
610659
"--checkpoint.interval",
611660
type=int,
612661
default=500,
613662
help="Checkpointing interval in steps.",
614663
)
615664
self.parser.add_argument(
616-
"--checkpoint.model_weights_only",
665+
"--checkpoint.last_save_model_weights_only",
617666
action="store_true",
618667
help="""
619-
When model_weights_only=True, only model weights will be saved at the end of training.
620-
With this, checkpoints can be loaded using `torch.load(..., weights_only=True)` after conversion.
621-
When model_weights_only=False, the full checkpoint will be saved.
668+
When last_save_model_weights_only=True, only model weights will be saved at the end of training,
669+
the last save. With this, checkpoints can be loaded using `torch.load(..., weights_only=True)`
670+
after conversion. When last_save_model_weights_only=False, the full checkpoint will be saved.
622671
A full checkpoint includes model, optimizer and train_state, which can be used to resume training.
623672
The default value is false.
624673
""",

flame/train.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,11 @@ def main(job_config: JobConfig):
350350
parallel_dims.loss_parallel_enabled,
351351
job_config.experimental.enable_compiled_autograd,
352352
)
353+
maybe_enable_amp = dist_utils.maybe_enable_amp(
354+
parallel_dims,
355+
job_config.training.mixed_precision_param,
356+
device_type,
357+
)
353358

354359
# variables used to keep info for metrics logging
355360
device_memory_monitor.reset_peak_stats()
@@ -484,11 +489,12 @@ def main(job_config: JobConfig):
484489
else:
485490
# Non-PP forward / backward
486491
with train_context(optional_context_parallel_ctx):
487-
output = model(
488-
input_ids=input_ids,
489-
labels=labels,
490-
position_ids=position_ids,
491-
cu_seqlens=cu_seqlens,
492+
with maybe_enable_amp:
493+
output = model(
494+
input_ids=input_ids,
495+
labels=labels,
496+
position_ids=position_ids,
497+
cu_seqlens=cu_seqlens,
492498
)
493499
loss = (
494500
output.loss

0 commit comments

Comments (0)