@@ -39,6 +39,7 @@ def setup(
     model_name: Optional[str] = None,
     model_config: Optional[Config] = None,
     out_dir: Path = Path("out/pretrain"),
+    initial_checkpoint_dir: Optional[Path] = None,
     resume: Union[bool, Path] = False,
     data: Optional[DataModule] = None,
     train: TrainArgs = TrainArgs(
@@ -71,6 +72,8 @@ def setup(
             ``model_config``.
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs/<job-name>/share.
+        initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
+            Useful for continued pretraining. Mutually exclusive with ``resume``.
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
             from the latest checkpoint in ``out_dir``.
         data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
@@ -107,13 +110,14 @@ def setup(
     if logger_name in ("tensorboard", "wandb"):
         fabric.logger.log_hyperparams(hparams)
 
-    main(fabric, devices, seed, resume, config, data, out_dir, tokenizer_dir, tokenizer, train, eval)
+    main(fabric, devices, seed, initial_checkpoint_dir, resume, config, data, out_dir, tokenizer_dir, tokenizer, train, eval)
 
 
 def main(
     fabric: L.Fabric,
     devices: int,
     seed: int,
+    initial_checkpoint_dir: Optional[Path],
     resume: Union[bool, Path],
     config: Config,
     data: DataModule,
@@ -123,7 +127,7 @@ def main(
     train: TrainArgs,
     eval: EvalArgs,
 ) -> None:
-    validate_args(train, eval)
+    validate_args(train, eval, initial_checkpoint_dir, resume)
 
     if fabric.global_rank == 0:
         out_dir.mkdir(parents=True, exist_ok=True)
@@ -157,6 +161,9 @@ def main(
     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
     train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)
 
+    if initial_checkpoint_dir:
+        fabric.load_raw(initial_checkpoint_dir / "lit_model.pth", model)
+
     state = {
         "model": model,
         "optimizer": optimizer,
@@ -376,7 +383,7 @@ def init_out_dir(out_dir: Path) -> Path:
     return out_dir
 
 
-def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
+def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:
     issues = []
     unsupported = [
         (train, ["max_steps", "epochs"]),
@@ -391,6 +398,8 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
         for name in names:
             if getattr(args, name) is None:
                 issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
+    if initial_checkpoint_dir and resume:
+        issues.append("Can't provide both `--resume` and `--initial_checkpoint_dir`. Choose one.")
     if issues:
         raise ValueError("\n".join(issues))
 
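As a usage sketch, not part of this commit: assuming this diff targets litgpt's pretrain.py (so `setup` is importable from `litgpt.pretrain`), continued pretraining could be started as below. The model name and paths are placeholders; only `initial_checkpoint_dir` and its interaction with `resume` come from the change itself.

# Hypothetical example: initialize a fresh pretraining run from existing weights.
from pathlib import Path

from litgpt.pretrain import setup  # assumes this diff applies to litgpt/pretrain.py

setup(
    model_name="pythia-160m",  # placeholder; any supported config name
    # The directory must contain lit_model.pth; only the model weights are
    # loaded (via the fabric.load_raw call added above), nothing else.
    initial_checkpoint_dir=Path("checkpoints/my-base-model"),
    out_dir=Path("out/continued-pretrain"),
    resume=False,  # validate_args() rejects combining resume with initial_checkpoint_dir
)

The mutual exclusion reflects an asymmetry visible in the diff: `initial_checkpoint_dir` seeds only the model weights before the `state` dict is assembled, whereas `resume` restores that full `state` (model, optimizer, and the rest of the dict), so allowing both would leave it ambiguous which weights take effect.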