
Commit 3a982fa

awaelchli authored and rasbt committed
Enable continued pretraining (#1109)
1 parent 0546dcd commit 3a982fa

File tree

5 files changed: +51 additions, -4 deletions

config_hub/pretrain/debug.yaml
config_hub/pretrain/tinyllama.yaml
config_hub/pretrain/tinystories.yaml
litgpt/pretrain.py
tests/test_pretrain.py


config_hub/pretrain/debug.yaml

Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,10 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/debug

+# Optional path to a checkpoint directory to initialize the model from.
+# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
+initial_checkpoint_dir:
+
 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
 # from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
 resume: false
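
For illustration, a filled-in version of the new field might look like the sketch below. The checkpoint path is hypothetical; it only needs to be a directory containing a ``lit_model.pth``, for example one produced by a checkpoint download/conversion or by a previous pretraining run.

# Hypothetical continued-pretraining setup: point at an existing checkpoint directory
initial_checkpoint_dir: checkpoints/EleutherAI/pythia-160m

# resume stays disabled; the two options are mutually exclusive
resume: false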

config_hub/pretrain/tinyllama.yaml

Lines changed: 4 additions & 0 deletions

@@ -11,6 +11,10 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/tiny-llama

+# Optional path to a checkpoint directory to initialize the model from.
+# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
+initial_checkpoint_dir:
+
 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
 # from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
 resume: false

config_hub/pretrain/tinystories.yaml

Lines changed: 4 additions & 0 deletions

@@ -27,6 +27,10 @@ model_config:
 # /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
 out_dir: out/pretrain/stories15M

+# Optional path to a checkpoint directory to initialize the model from.
+# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
+initial_checkpoint_dir:
+
 # Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
 # from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
 resume: false
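
Since these YAML fields mirror the parameters of litgpt.pretrain.setup (see the next file), the new option should also be reachable from the command line. Assuming the usual litgpt entry point and config handling, which are not part of this diff, a continued-pretraining run might be launched with something like (the checkpoint path is a placeholder):

litgpt pretrain --config config_hub/pretrain/tinyllama.yaml --initial_checkpoint_dir checkpoints/some-org/some-base-model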

litgpt/pretrain.py

Lines changed: 12 additions & 3 deletions

@@ -39,6 +39,7 @@ def setup(
     model_name: Optional[str] = None,
     model_config: Optional[Config] = None,
     out_dir: Path = Path("out/pretrain"),
+    initial_checkpoint_dir: Optional[Path] = None,
     resume: Union[bool, Path] = False,
     data: Optional[DataModule] = None,
     train: TrainArgs = TrainArgs(
@@ -71,6 +72,8 @@ def setup(
             ``model_config``.
         out_dir: Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
             /teamspace/jobs/<job-name>/share.
+        initial_checkpoint_dir: Optional path to a checkpoint directory to initialize the model from.
+            Useful for continued pretraining. Mutually exclusive with ``resume``.
         resume: Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
             from the latest checkpoint in ``out_dir``.
         data: Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
@@ -107,13 +110,14 @@ def setup(
     if logger_name in ("tensorboard", "wandb"):
         fabric.logger.log_hyperparams(hparams)

-    main(fabric, devices, seed, resume, config, data, out_dir, tokenizer_dir, tokenizer, train, eval)
+    main(fabric, devices, seed, initial_checkpoint_dir, resume, config, data, out_dir, tokenizer_dir, tokenizer, train, eval)


 def main(
     fabric: L.Fabric,
     devices: int,
     seed: int,
+    initial_checkpoint_dir: Optional[Path],
     resume: Union[bool, Path],
     config: Config,
     data: DataModule,
@@ -123,7 +127,7 @@ def main(
     train: TrainArgs,
     eval: EvalArgs,
 ) -> None:
-    validate_args(train, eval)
+    validate_args(train, eval, initial_checkpoint_dir, resume)

     if fabric.global_rank == 0:
         out_dir.mkdir(parents=True, exist_ok=True)
@@ -157,6 +161,9 @@ def main(
     train_dataloader, val_dataloader = get_dataloaders(fabric, data, tokenizer, train, model.max_seq_length)
     train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

+    if initial_checkpoint_dir:
+        fabric.load_raw(initial_checkpoint_dir / "lit_model.pth", model)
+
     state = {
         "model": model,
         "optimizer": optimizer,
@@ -376,7 +383,7 @@ def init_out_dir(out_dir: Path) -> Path:
     return out_dir


-def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
+def validate_args(train: TrainArgs, eval: EvalArgs, initial_checkpoint_dir, resume) -> None:
     issues = []
     unsupported = [
         (train, ["max_steps", "epochs"]),
@@ -391,6 +398,8 @@ def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
         for name in names:
             if getattr(args, name) is None:
                 issues.append(f"{__file__} requires the {name!r} argument. This is set in {args}")
+    if initial_checkpoint_dir and resume:
+        issues.append("Can't provide both `--resume` and `--initial_checkpoint_dir`. Choose one.")
     if issues:
         raise ValueError("\n".join(issues))
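
Taken together, ``initial_checkpoint_dir`` only restores model weights via ``fabric.load_raw`` (no optimizer or step state), while ``resume`` is reserved for continuing an interrupted run from ``out_dir``, and ``validate_args`` now rejects combining the two. A minimal sketch of driving this through the Python API follows; the model name and paths are illustrative assumptions, ``tokenizer_dir`` is assumed to be a setup parameter (it appears in the main(...) call above), and ``data`` is left to its documented default, ``litgpt.data.TinyLlama``.

from pathlib import Path

from litgpt import pretrain

# Continued pretraining: initialize the weights from an existing checkpoint directory
# (it must contain a ``lit_model.pth``), then train from step 0 in a fresh out_dir.
pretrain.setup(
    model_name="pythia-160m",                                           # hypothetical model choice
    initial_checkpoint_dir=Path("checkpoints/EleutherAI/pythia-160m"),  # hypothetical checkpoint path
    tokenizer_dir=Path("checkpoints/EleutherAI/pythia-160m"),           # hypothetical tokenizer location
    out_dir=Path("out/pretrain/continued"),
)

# Resuming an interrupted run is different: it restores optimizer and step state,
# and is mutually exclusive with initial_checkpoint_dir:
# pretrain.setup(model_name="pythia-160m", resume=True, out_dir=Path("out/pretrain/continued"))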

tests/test_pretrain.py

Lines changed: 27 additions & 1 deletion

@@ -5,7 +5,7 @@
 from io import StringIO
 from pathlib import Path
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import Mock, ANY

 import pytest
 import torch
@@ -63,6 +63,32 @@ def test_pretrain(_, tmp_path):
     torch.distributed.barrier()


+
+@RunIf(min_cuda_gpus=2, standalone=True)
+# Set CUDA_VISIBLE_DEVICES for FSDP hybrid-shard, if fewer GPUs are used than are available
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
+@mock.patch("litgpt.pretrain.L.Fabric.load_raw")
+def test_initial_checkpoint_dir(load_mock, tmp_path):
+    from litgpt import pretrain
+    from litgpt.config import Config
+
+    model_config = Config(block_size=2, n_layer=2, n_embd=8, n_head=4, padded_vocab_size=8)
+
+    dataset = torch.tensor([[0, 1, 2], [3, 4, 5], [0, 1, 2]])
+    dataloader = DataLoader(dataset)
+    pretrain.get_dataloaders = Mock(return_value=(dataloader, dataloader))
+    pretrain.fit = Mock()
+
+    pretrain.setup(
+        initial_checkpoint_dir=tmp_path,
+        devices=2,
+        model_config=model_config,
+        out_dir=tmp_path,
+    )
+
+    load_mock.assert_called_once_with(tmp_path / "lit_model.pth", ANY)
+
+
 def test_pretrain_model_name_and_config():
     from litgpt import pretrain
     from litgpt.config import Config
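
The new test checks that ``load_raw`` is called with the expected checkpoint path; the mutual-exclusivity check added to ``validate_args`` is not exercised here. A sketch of how it could be, assuming ``TrainArgs`` and ``EvalArgs`` live in ``litgpt.args`` and construct with their defaults:

import pytest

from litgpt.args import EvalArgs, TrainArgs
from litgpt.pretrain import validate_args


def test_initial_checkpoint_dir_conflicts_with_resume(tmp_path):
    # Passing both options should surface the new error message added in validate_args;
    # other validation issues may be reported alongside it, so we only match the substring.
    with pytest.raises(ValueError, match="Can't provide both"):
        validate_args(TrainArgs(), EvalArgs(), initial_checkpoint_dir=tmp_path, resume=True)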
