Commit cabec5f

example for full finetuning with python code (#1331)

Authored-by: Daniel Shats (astrobdr) <Daniel.Shats1@ibm.com>
Co-authored-by: Jirka Borovec (Borda) <[email protected]>

1 parent d8bf47a commit cabec5f

File tree: 1 file changed (+122 −0)

tutorials/full_finetune_example.py
Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
"""
This script is meant to be the simplest possible starting point for full finetuning of a GPT model with Lightning Fabric in plain Python code (not the CLI).

- no checkpoints
- no out dir
- no precision
- no resume
- no train/eval args (or any args in general)
- no logger (prints to the terminal only)
- no grad accumulation
and no other fancy stuff.

To add any of the above, work it in yourself by following the code in litgpt/finetune/full.py or the Lightning Fabric docs (see the commented sketches just below and at the end of this file).
"""

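# A rough, commented-out sketch of how a couple of the omitted pieces (precision, a logger)
# could be wired in with standard Lightning Fabric APIs. The "logs" directory and the logged
# key names are illustrative assumptions, not something this tutorial prescribes:
#
#   from lightning.fabric.loggers import CSVLogger
#   fabric = L.Fabric(devices="auto", strategy="auto", precision="bf16-mixed", loggers=CSVLogger("logs"))
#   fabric.log_dict({"train_loss": loss.item()}, step=iter_num)  # e.g. inside the training loop
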
import os

import lightning as L
import torch
import torch.nn as nn

from litgpt.data import Alpaca
from litgpt.model import GPT, Config
from litgpt.tokenizer import Tokenizer
from litgpt.utils import num_parameters

# training params/args
SEED = 1337
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"  # try also "stabilityai/stablelm-base-alpha-3b"!
BATCH_SIZE = 4
LR_WARMUP_STEPS = 100
MAX_STEPS = 601


def validate(fabric, model, val_dataloader):
    model.eval()
    loss = 0
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids, targets = batch["input_ids"], batch["labels"]
            logits = model(input_ids)
            # shift so each position predicts the next token, then flatten for cross entropy
            logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
            targets = targets[:, 1:].reshape(-1)
            loss += nn.functional.cross_entropy(logits, targets)
    fabric.print(f"Validation loss: {loss / len(val_dataloader)}")


def train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader):
    for iter_num, batch in enumerate(train_dataloader):
        input_ids, targets = batch["input_ids"], batch["labels"]

        # get model preds (logits)
        logits = model(input_ids)

        # get loss: shift logits/targets by one position, then flatten for cross entropy
        loss = nn.functional.cross_entropy(logits[:, :-1, :].reshape(-1, logits.size(-1)), targets[:, 1:].reshape(-1))

        # update weights
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

        # print train loss every 100 steps
        if iter_num % 100 == 0:
            fabric.print(f"Train iter {iter_num} - loss {loss}")

        # validate every 300 steps
        if iter_num % 300 == 0:
            validate(fabric, model, val_dataloader)
            model.train()

        # stop after MAX_STEPS batches
        if iter_num + 1 >= MAX_STEPS:
            break


def main(fabric):
    fabric.seed_everything(SEED)

    # setup data, make tokenizer and make dataloaders
    data = Alpaca()
    tokenizer = Tokenizer(checkpoint_dir=f"checkpoints/{MODEL_NAME}")
    data.connect(tokenizer=tokenizer, batch_size=BATCH_SIZE, max_seq_length=1024)
    data.setup()
    train_dataloader = data.train_dataloader()
    val_dataloader = data.val_dataloader()
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    # print how many steps in an epoch
    fabric.print(f"Steps in an epoch: {len(train_dataloader)}")

    # setup model
    config = Config.from_file(f"checkpoints/{MODEL_NAME}/model_config.yaml")
    model = GPT(config)
    fabric.print(f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}")
    model = fabric.setup(model)

    # setup optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.02, betas=(0.9, 0.95))
    optimizer = fabric.setup_optimizers(optimizer)

    # setup lr scheduler: linear warmup followed by cosine annealing
    scheduler1 = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: step / LR_WARMUP_STEPS)
    scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(MAX_STEPS - LR_WARMUP_STEPS))
    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, [scheduler1, scheduler2], milestones=[LR_WARMUP_STEPS])

    # Start training!!!
    train(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader)


if __name__ == "__main__":
    # check that the model exists (downloaded to ./checkpoints/)
    if not os.path.exists(f"checkpoints/{MODEL_NAME}"):
        print(f"Model {MODEL_NAME} not found. Please download it using `litgpt download --repo {MODEL_NAME}`")
        raise SystemExit(1)

    ### Setup and launch
    fabric = L.Fabric(devices="auto", strategy="auto")
    fabric.launch(main)
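
# A rough, commented-out sketch of how gradient accumulation and checkpoint saving could be
# added to train() above; the accumulation factor of 4 and the checkpoint file name are
# illustrative assumptions, not part of this tutorial:
#
#   is_accumulating = (iter_num + 1) % 4 != 0  # accumulate gradients over 4 micro-batches
#   with fabric.no_backward_sync(model, enabled=is_accumulating):
#       fabric.backward(loss)
#   if not is_accumulating:
#       optimizer.step()
#       optimizer.zero_grad()
#       scheduler.step()
#
#   fabric.save("checkpoints/finetuned.pt", {"model": model})  # e.g. once training finishes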
