
Commit 3bb7a62

awaelchli authored and rasbt committed
Update adapter scripts (#1121)
1 parent fc12441 commit 3bb7a62

8 files changed (+129 −52 lines)


litgpt/finetune/adapter.py

Lines changed: 59 additions & 21 deletions
@@ -1,5 +1,6 @@
-# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import dataclasses
+import math
 import os
 import time
 from pathlib import Path
@@ -8,16 +9,16 @@
 
 import lightning as L
 import torch
-from lightning.fabric.loggers import CSVLogger
 from lightning.fabric.plugins import BitsandbytesPrecision
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities import ThroughputMonitor
 from torch.utils.data import DataLoader
+from torchmetrics import RunningMean
 
-from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
 from litgpt.args import EvalArgs, TrainArgs
-from litgpt.data import Alpaca, DataModule
+from litgpt.data import DataModule, Alpaca
 from litgpt.generate.base import generate
+from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
 from litgpt.prompts import save_prompt_style
 from litgpt.tokenizer import Tokenizer
 from litgpt.utils import (
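
Note: the new torchmetrics RunningMean import is used further down to smooth the logged training loss over the gradient-accumulation window, replacing the raw per-iteration loss.item(). A minimal, standalone sketch of its behavior (the window size and values below are illustrative, not taken from this commit):

    import torch
    from torchmetrics import RunningMean

    # Keep a sliding mean over the last 4 values; skip cross-process sync on compute().
    running_loss = RunningMean(window=4, sync_on_compute=False)

    for value in [2.0, 1.5, 1.0, 0.5, 0.25]:
        running_loss.update(torch.tensor(value))

    # Mean of the last 4 updates: (1.5 + 1.0 + 0.5 + 0.25) / 4 = 0.8125
    print(running_loss.compute())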
@@ -31,35 +32,53 @@
     parse_devices,
     copy_config_files,
     save_hyperparameters,
+    choose_logger,
 )
 
 
 def setup(
+    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
+    out_dir: Path = Path("out/finetune/adapter"),
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
     devices: Union[int, str] = 1,
-    seed: int = 1337,
     data: Optional[DataModule] = None,
-    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
-    out_dir: Path = Path("out/finetune/adapter"),
     train: TrainArgs = TrainArgs(
         save_interval=1000,
         log_interval=1,
-        global_batch_size=64,
+        global_batch_size=128,
         micro_batch_size=4,
         lr_warmup_steps=100,
         epochs=5,
         learning_rate=1e-3,
         max_seq_length=None,
     ),
-    eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
+    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
+    logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
+    seed: int = 1337,
 ) -> None:
+    """Finetune a model using the Adapter method.
+
+    Arguments:
+        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
+        out_dir: Directory in which to save checkpoints and logs.
+        precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
+        quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
+        devices: How many devices/GPUs to use.
+        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
+        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
+        eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
+        logger_name: The name of the logger to send metrics to.
+        seed: The random seed to use for reproducibility.
+    """
 
     pprint(locals())
     data = Alpaca() if data is None else data
     devices = parse_devices(devices)
+    config = Config.from_name(name=checkpoint_dir.name)
 
     precision = precision or get_default_supported_precision(training=True)
+    logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
 
     plugins = None
     if quantize is not None and quantize.startswith("bnb."):
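
Note: setup() now accepts a logger_name and builds the logger through the new choose_logger helper instead of hard-coding CSVLogger. A rough sketch of what such a dispatcher might look like; the real litgpt.utils.choose_logger may differ in signature and defaults, so treat the names and arguments here as assumptions:

    from pathlib import Path
    from typing import Literal

    from lightning.fabric.loggers import CSVLogger, TensorBoardLogger


    def choose_logger_sketch(
        logger_name: Literal["wandb", "tensorboard", "csv"],
        out_dir: Path,
        name: str,
        log_interval: int = 1,
    ):
        # Hypothetical dispatcher: map the CLI choice onto a Fabric-compatible logger.
        if logger_name == "csv":
            return CSVLogger(root_dir=out_dir / "logs", name="csv", flush_logs_every_n_steps=log_interval)
        if logger_name == "tensorboard":
            return TensorBoardLogger(root_dir=out_dir / "logs", name="tensorboard")
        if logger_name == "wandb":
            # WandbLogger lives in lightning.pytorch.loggers; imported lazily so wandb stays optional.
            from lightning.pytorch.loggers import WandbLogger

            return WandbLogger(project=name, save_dir=str(out_dir))
        raise ValueError(f"Unknown logger: {logger_name!r}")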
@@ -85,14 +104,12 @@ def setup(
     else:
         strategy = "auto"
 
-    logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=train.log_interval)
     fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
-    fabric.launch(main, devices, seed, Config.from_name(name=checkpoint_dir.name), data, checkpoint_dir, out_dir, train, eval)
+    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)
 
 
 def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
     validate_args(train, eval)
-
     check_valid_checkpoint_dir(checkpoint_dir)
 
     tokenizer = Tokenizer(checkpoint_dir)
@@ -133,12 +150,12 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataMo
 
     train_time = time.perf_counter()
     fit(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader, devices, checkpoint_dir, out_dir, train, eval, data)
-    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
+    fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s")
     if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
 
-    # Save the final checkpoint at the end of training
-    save_path = out_dir / "final" / "lit_model.pth"
+    # Save the final Adapter checkpoint at the end of training
+    save_path = out_dir / "final" / "lit_model.pth.adapter"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_adapter_checkpoint(fabric, model, save_path)
     if fabric.global_rank == 0:
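
Note: both the periodic and the final checkpoints are now written as lit_model.pth.adapter, making it explicit that save_adapter_checkpoint persists only the adapter weights rather than the full model. A hedged sketch of how such a filtered save can be done with Fabric; the actual save_adapter_checkpoint and adapter_filter in litgpt may be implemented differently, so the helper names below are illustrative:

    from pathlib import Path

    import lightning as L
    import torch


    def adapter_only_filter(key: str, value: torch.Tensor) -> bool:
        # Assumption: adapter parameters can be recognized by their name.
        return "adapter" in key


    def save_adapter_checkpoint_sketch(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
        fabric.print(f"Saving adapter weights to {str(file_path)!r}")
        # Fabric.save accepts a per-state-key filter; only tensors matching the filter are written.
        fabric.save(file_path, {"model": model}, filter={"model": adapter_only_filter})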
@@ -174,6 +191,9 @@ def fit(
 
     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)
+    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+        fabric.device
+    )
     max_steps = train.max_steps or float("inf")
     step_count = 0
     iter_num = 0
@@ -184,7 +204,6 @@
     while step_count < max_steps and train_iterator.epoch < train.epochs:
         iter_num += 1
         iter_t0 = time.perf_counter()
-
         batch = next(train_iterator)
         input_ids, targets = batch["input_ids"], batch["labels"]
 
@@ -196,6 +215,8 @@ def fit(
             loss = chunked_cross_entropy(logits, targets[..., 1:])
             fabric.backward(loss / train.gradient_accumulation_iters(devices))
 
+        running_loss.update(loss.detach())
+
         if not is_accumulating:
             optimizer.step()
             optimizer.zero_grad()
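
Note: running_loss is updated with the unscaled loss of every micro-batch, while the scaled loss (loss divided by the accumulation count) is what gets backpropagated. With the new defaults (global_batch_size=128, micro_batch_size=4), the window works out as follows, assuming gradient_accumulation_iters is global_batch_size // devices // micro_batch_size as in litgpt's TrainArgs:

    # Illustrative arithmetic; the formula is an assumption about TrainArgs, not code from this commit.
    global_batch_size = 128   # new default in this commit (was 64)
    micro_batch_size = 4
    devices = 1

    gradient_accumulation_iters = global_batch_size // devices // micro_batch_size
    print(gradient_accumulation_iters)  # 32 micro-batches accumulated per optimizer step

    # RunningMean(window=32) then reports the mean loss over exactly one optimizer step,
    # which is smoother than logging the last micro-batch's loss.item().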
@@ -204,30 +225,46 @@ def fit(
 
         total_lengths += input_ids.numel()
         if iter_num % train.log_interval == 0:
-            loss_item = loss.item()  # expensive device-to-host synchronization
+            loss = running_loss.compute().item()  # expensive device-to-host synchronization
             t1 = time.perf_counter()
             throughput.update(
                 time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
             )
             throughput.compute_and_log(step=iter_num)
+            metrics = {
+                "loss": loss,
+                "iter": iter_num,
+                "step": step_count,
+                "epoch": train_iterator.epoch,
+                "iter_time": t1 - iter_t0,
+                "tokens": iter_num * train.micro_batch_size * model.config.block_size,
+                "total_tokens": (
+                    iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size
+                ),
+                "learning_rate": scheduler.get_last_lr()[0],
+            }
             if isinstance(val_loss, torch.Tensor):
                 val_loss = f"{val_loss:.3f}"
             fabric.print(
-                f"Epoch {train_iterator.epoch+1} | iter {iter_num} step {step_count} |"
-                f" loss train: {loss_item:.3f},"
+                f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
+                f" loss train: {metrics['loss']:.3f},"
                 f" val: {val_loss} |"
-                f" iter time: {(t1 - iter_t0) * 1000:.2f} ms"
+                f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
                 f"{' (step)' if not is_accumulating else ''}"
             )
+            fabric.log_dict(metrics, step=iter_num)
 
         if not is_accumulating and step_count % eval.interval == 0:
             t0 = time.perf_counter()
             val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
             fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
+            metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+            fabric.log_dict(metrics, step=iter_num)
             fabric.barrier()
+
         if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
-            checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
+            checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.adapter"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_checkpoint(fabric, model, checkpoint_file)
             if fabric.global_rank == 0:
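
Note: the evaluation path now also logs val_ppl = math.exp(val_loss) alongside the raw loss via fabric.log_dict, which forwards the dictionary to whichever logger was chosen in setup(). A small self-contained example of the perplexity conversion; the numbers are made up for illustration:

    import math

    import torch

    # A mean token-level cross-entropy of ~2.0 nats corresponds to a perplexity of ~7.4,
    # i.e. the model is roughly as uncertain as a uniform choice over ~7 tokens.
    val_loss = torch.tensor(2.0)
    val_ppl = math.exp(val_loss)
    print(f"val loss {val_loss.item():.4f}, val ppl {val_ppl:.2f}")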
@@ -250,6 +287,7 @@ def validate(
         input_ids, targets = batch["input_ids"], batch["labels"]
         logits = model(input_ids)
         losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
+
     val_loss = losses.mean()
 
     # produce an example:
