-# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import dataclasses
+import math
 import os
 import time
 from pathlib import Path
 
 import lightning as L
 import torch
-from lightning.fabric.loggers import CSVLogger
 from lightning.fabric.plugins import BitsandbytesPrecision
 from lightning.fabric.strategies import FSDPStrategy
 from lightning.fabric.utilities import ThroughputMonitor
 from torch.utils.data import DataLoader
+from torchmetrics import RunningMean
 
-from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
 from litgpt.args import EvalArgs, TrainArgs
-from litgpt.data import Alpaca, DataModule
+from litgpt.data import DataModule, Alpaca
 from litgpt.generate.base import generate
+from litgpt.adapter import GPT, Block, Config, adapter_filter, mark_only_adapter_as_trainable
 from litgpt.prompts import save_prompt_style
 from litgpt.tokenizer import Tokenizer
 from litgpt.utils import (
     parse_devices,
     copy_config_files,
     save_hyperparameters,
+    choose_logger,
 )
 
 
 def setup(
+    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
+    out_dir: Path = Path("out/finetune/adapter"),
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
     devices: Union[int, str] = 1,
-    seed: int = 1337,
     data: Optional[DataModule] = None,
-    checkpoint_dir: Path = Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
-    out_dir: Path = Path("out/finetune/adapter"),
     train: TrainArgs = TrainArgs(
         save_interval=1000,
         log_interval=1,
-        global_batch_size=64,
+        global_batch_size=128,
         micro_batch_size=4,
         lr_warmup_steps=100,
         epochs=5,
         learning_rate=1e-3,
         max_seq_length=None,
     ),
-    eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100),
+    eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
+    logger_name: Literal["wandb", "tensorboard", "csv"] = "csv",
+    seed: int = 1337,
 ) -> None:
+    """Finetune a model using the Adapter method.
+
+    Arguments:
+        checkpoint_dir: The path to the base model's checkpoint directory to load for finetuning.
+        out_dir: Directory in which to save checkpoints and logs.
+        precision: The precision to use for finetuning. Possible choices: "bf16-true", "bf16-mixed", "32-true".
+        quantize: If set, quantize the model with this algorithm. See ``tutorials/quantize.md`` for more information.
+        devices: How many devices/GPUs to use.
+        data: Data-related arguments. If not provided, the default is ``litgpt.data.Alpaca``.
+        train: Training-related arguments. See ``litgpt.args.TrainArgs`` for details.
+        eval: Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details.
+        logger_name: The name of the logger to send metrics to.
+        seed: The random seed to use for reproducibility.
+    """
 
     pprint(locals())
     data = Alpaca() if data is None else data
     devices = parse_devices(devices)
+    config = Config.from_name(name=checkpoint_dir.name)
 
     precision = precision or get_default_supported_precision(training=True)
+    logger = choose_logger(logger_name, out_dir, name=f"finetune-{config.name}", log_interval=train.log_interval)
 
     plugins = None
     if quantize is not None and quantize.startswith("bnb."):
@@ -85,14 +104,12 @@ def setup(
     else:
         strategy = "auto"
 
-    logger = CSVLogger(out_dir.parent, out_dir.name, flush_logs_every_n_steps=train.log_interval)
     fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
-    fabric.launch(main, devices, seed, Config.from_name(name=checkpoint_dir.name), data, checkpoint_dir, out_dir, train, eval)
+    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)
 
 
 def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataModule, checkpoint_dir: Path, out_dir: Path, train: TrainArgs, eval: EvalArgs) -> None:
     validate_args(train, eval)
-
     check_valid_checkpoint_dir(checkpoint_dir)
 
     tokenizer = Tokenizer(checkpoint_dir)
@@ -133,12 +150,12 @@ def main(fabric: L.Fabric, devices: int, seed: int, config: Config, data: DataMo
 
     train_time = time.perf_counter()
     fit(fabric, model, optimizer, scheduler, train_dataloader, val_dataloader, devices, checkpoint_dir, out_dir, train, eval, data)
-    fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s")
+    fabric.print(f"Training time: {(time.perf_counter() - train_time):.2f}s")
    if fabric.device.type == "cuda":
         fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
 
-    # Save the final checkpoint at the end of training
-    save_path = out_dir / "final" / "lit_model.pth"
+    # Save the final Adapter checkpoint at the end of training
+    save_path = out_dir / "final" / "lit_model.pth.adapter"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_adapter_checkpoint(fabric, model, save_path)
     if fabric.global_rank == 0:
@@ -174,6 +191,9 @@ def fit(
 
     train_iterator = CycleIterator(train_dataloader)
     throughput = ThroughputMonitor(fabric, window_size=50)
+    running_loss = RunningMean(window=train.gradient_accumulation_iters(devices), sync_on_compute=False).to(
+        fabric.device
+    )
     max_steps = train.max_steps or float("inf")
     step_count = 0
     iter_num = 0
@@ -184,7 +204,6 @@ def fit(
     while step_count < max_steps and train_iterator.epoch < train.epochs:
         iter_num += 1
         iter_t0 = time.perf_counter()
-
         batch = next(train_iterator)
         input_ids, targets = batch["input_ids"], batch["labels"]
 
@@ -196,6 +215,8 @@ def fit(
             loss = chunked_cross_entropy(logits, targets[..., 1:])
             fabric.backward(loss / train.gradient_accumulation_iters(devices))
 
+        running_loss.update(loss.detach())
+
         if not is_accumulating:
             optimizer.step()
             optimizer.zero_grad()
@@ -204,30 +225,46 @@ def fit(
 
         total_lengths += input_ids.numel()
         if iter_num % train.log_interval == 0:
-            loss_item = loss.item()  # expensive device-to-host synchronization
+            loss = running_loss.compute().item()  # expensive device-to-host synchronization
             t1 = time.perf_counter()
             throughput.update(
                 time=t1 - total_t0, batches=iter_num, samples=iter_num * train.micro_batch_size, lengths=total_lengths
             )
             throughput.compute_and_log(step=iter_num)
+            metrics = {
+                "loss": loss,
+                "iter": iter_num,
+                "step": step_count,
+                "epoch": train_iterator.epoch,
+                "iter_time": t1 - iter_t0,
+                "tokens": iter_num * train.micro_batch_size * model.config.block_size,
+                "total_tokens": (
+                    iter_num * train.micro_batch_size * model.config.block_size * fabric.world_size
+                ),
+                "learning_rate": scheduler.get_last_lr()[0],
+            }
             if isinstance(val_loss, torch.Tensor):
                 val_loss = f"{val_loss:.3f}"
             fabric.print(
-                f"Epoch {train_iterator.epoch + 1} | iter {iter_num} step {step_count} |"
-                f" loss train: {loss_item:.3f},"
+                f"Epoch {metrics['epoch'] + 1} | iter {metrics['iter']} step {metrics['step']} |"
+                f" loss train: {metrics['loss']:.3f},"
                 f" val: {val_loss} |"
-                f" iter time: {(t1 - iter_t0) * 1000:.2f} ms"
+                f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
                 f"{' (step)' if not is_accumulating else ''}"
             )
+            fabric.log_dict(metrics, step=iter_num)
 
         if not is_accumulating and step_count % eval.interval == 0:
             t0 = time.perf_counter()
             val_loss = validate(fabric, model, val_dataloader, tokenizer, eval, data)
             t1 = time.perf_counter() - t0
             fabric.print(f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms")
+            metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)}
+            fabric.log_dict(metrics, step=iter_num)
             fabric.barrier()
+
         if train.save_interval is not None and not is_accumulating and step_count % train.save_interval == 0:
-            checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth"
+            checkpoint_file = out_dir / f"step-{step_count:06d}" / "lit_model.pth.adapter"
             checkpoint_file.parent.mkdir(parents=True, exist_ok=True)
             save_adapter_checkpoint(fabric, model, checkpoint_file)
             if fabric.global_rank == 0:
@@ -250,6 +287,7 @@ def validate(
         input_ids, targets = batch["input_ids"], batch["labels"]
         logits = model(input_ids)
         losses[k] = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:], chunk_size=0)
+
     val_loss = losses.mean()
 
     # produce an example:
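
For context, a minimal sketch of how the updated `setup()` entry point might be invoked directly from Python with the new `logger_name` argument added in this commit. The module path `litgpt.finetune.adapter` is an assumption based on the imports shown in this diff; paths and the choice of data module are purely illustrative.

```python
from pathlib import Path

from litgpt.data import Alpaca
from litgpt.finetune.adapter import setup  # module path assumed from this diff's imports

# Illustrative call: adapter finetuning with metrics sent to TensorBoard
# via the new `logger_name` argument (one of "wandb", "tensorboard", "csv").
setup(
    checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"),
    out_dir=Path("out/finetune/adapter"),
    data=Alpaca(),
    logger_name="tensorboard",
)
```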