Commit 5ac69b0

finetune_lora upgrades (#2086)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b00e7b3 commit 5ac69b0

7 files changed: +849 -15 lines changed

litgpt/__main__.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,7 @@
 from litgpt.finetune.adapter_v2 import setup as finetune_adapter_v2_fn
 from litgpt.finetune.full import setup as finetune_full_fn
 from litgpt.finetune.lora import setup as finetune_lora_fn
+from litgpt.finetune.lora_legacy import setup as finetune_lora_legacy_fn
 from litgpt.generate.adapter import main as generate_adapter_fn
 from litgpt.generate.adapter_v2 import main as generate_adapter_v2_fn
 from litgpt.generate.base import main as generate_base_fn
@@ -35,6 +36,7 @@ def main() -> None:
         "chat": chat_fn,
         "finetune": finetune_lora_fn,
         "finetune_lora": finetune_lora_fn,
+        "finetune_lora_legacy": finetune_lora_legacy_fn,
         "finetune_full": finetune_full_fn,
         "finetune_adapter": finetune_adapter_fn,
         "finetune_adapter_v2": finetune_adapter_v2_fn,

litgpt/args.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ class TrainArgs:
     """Total number of tokens to train on"""
     max_steps: Optional[int] = None
     """Limits the number of optimizer steps to run"""
+    max_time: Optional[float] = None
+    """Limits the number of seconds to train for"""
     max_seq_length: Optional[int] = None
     """Limits the length of samples"""
     tie_embeddings: Optional[bool] = None
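The new max_time field gives TrainArgs a wall-clock budget in seconds next to the existing epoch and step limits; the check itself lives in the fit() loop of litgpt/finetune/lora.py below. A short sketch, with all other fields left at their defaults and the 30-minute budget purely illustrative:

# Sketch: cap a LoRA run at roughly 30 minutes of wall-clock time; whichever of
# the epoch, step, or time limits is reached first stops training.
from litgpt.args import TrainArgs

train = TrainArgs(
    epochs=5,
    lr_warmup_steps=100,
    max_time=30 * 60,  # seconds
)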

litgpt/finetune/lora.py

Lines changed: 66 additions & 13 deletions
@@ -11,7 +11,7 @@
 import lightning as L
 import torch
 from lightning.fabric.plugins import BitsandbytesPrecision
-from lightning.fabric.strategies import FSDPStrategy
+from lightning.fabric.strategies import ModelParallelStrategy
 from lightning.fabric.utilities import ThroughputMonitor
 from lightning_utilities.core.imports import RequirementCache
 from torch.utils.data import ConcatDataset, DataLoader
@@ -20,7 +20,7 @@
 from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
-from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
+from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable
 from litgpt.prompts import save_prompt_style
 from litgpt.scripts.merge_lora import merge_lora
 from litgpt.tokenizer import Tokenizer
@@ -70,6 +70,7 @@ def setup(
         lr_warmup_steps=100,
         epochs=5,
         max_seq_length=None,
+        max_time=None,
     ),
     log: LogArgs = LogArgs(),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
@@ -105,6 +106,7 @@ def setup(
         seed: The random seed to use for reproducibility.
         access_token: Optional API token to access models with restrictions.
     """
+
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())
     data = Alpaca() if data is None else data
@@ -152,12 +154,10 @@ def setup(
                 "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
                 " when using the --quantize flag."
             )
-        strategy = FSDPStrategy(
-            auto_wrap_policy={torch.nn.Linear},
-            activation_checkpointing_policy={Block},
-            state_dict_type="full",
-            limit_all_gathers=True,
-            cpu_offload=False,
+        strategy = ModelParallelStrategy(
+            parallelize_fn=parallelize_fn,
+            data_parallel_size=devices * num_nodes,
+            tensor_parallel_size=1,
         )
     else:
         strategy = "auto"
@@ -174,7 +174,9 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)

-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
+    fabric.launch(
+        main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision
+    )


 def main(
@@ -189,6 +191,7 @@ def main(
     eval: EvalArgs,
     optimizer: Union[str, Dict],
     num_nodes: int = 1,
+    precision: Optional[str] = None,
 ) -> None:
     validate_args(train, eval)

@@ -229,7 +232,6 @@ def main(
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)

-    # strict=False because missing keys due to LoRA weights not contained in state dict
     load_checkpoint(fabric, model, checkpoint_path, strict=False)

     train_time = time.perf_counter()
@@ -264,12 +266,19 @@ def main(
     save_path = out_dir / "final" / "lit_model.pth.lora"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_lora_checkpoint(fabric, model, save_path)
+
+    fabric.barrier()
     if fabric.global_rank == 0:
         # Copy checkpoint files from original checkpoint dir
         copy_config_files(checkpoint_dir, save_path.parent)
         save_hyperparameters(setup, save_path.parent)
         save_prompt_style(data.prompt_style, save_path.parent)
-        merge_lora(checkpoint_dir=save_path.parent)
+        merge_lora(
+            checkpoint_dir=save_path.parent,
+            pretrained_checkpoint_dir=checkpoint_dir,
+            precision=precision,
+        )
+    fabric.barrier()


 def fit(
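merge_lora is now called with the pretrained checkpoint directory and precision passed explicitly, and the call is bracketed by fabric.barrier() so every rank has finished writing the LoRA checkpoint before rank 0 merges, and the other ranks wait until the merge is done. A hedged sketch of running the same merge step standalone after training; the directories and precision below are illustrative, not fixed defaults:

# Sketch only: adjust the paths to wherever your run wrote its output and
# wherever the base checkpoint was downloaded.
from pathlib import Path

from litgpt.scripts.merge_lora import merge_lora

merge_lora(
    checkpoint_dir=Path("out/finetune/lora/final"),               # contains lit_model.pth.lora
    pretrained_checkpoint_dir=Path("checkpoints/<org>/<model>"),  # base weights to merge into
    precision="bf16-true",
)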
@@ -316,6 +325,8 @@ def fit(
     total_lengths = 0
     total_t0 = time.perf_counter()

+    max_time = train.max_time or float("inf")
+
     token_counts = {
         "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
         "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
@@ -327,6 +338,12 @@ def fit(
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
         if train_iterator.epoch >= train.epochs:
+            generate_example(fabric, model, tokenizer, eval, data)
+            fabric.print(f"Number of epochs {train.epochs} reached, stopping training...")
+            break
+        if iter_t0 - total_t0 > max_time:
+            generate_example(fabric, model, tokenizer, eval, data)
+            fabric.print(f"Max time ({max_time / 60.0:.2f}m) reached, stopping training...")
             break
         input_ids, targets = batch["input_ids"], batch["labels"]

@@ -497,9 +514,45 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix


+def parallelize_fn(model, device_mesh, activation_checkpointing=True):
+    from torch.distributed._composable.fsdp.fully_shard import fully_shard
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper
+
+    if activation_checkpointing:
+        model.transformer.h = torch.nn.ModuleList(
+            [checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h]
+        )
+
+    dp_mesh = device_mesh["data_parallel"]
+
+    for m in reversed(list(model.modules())):
+        if (
+            (isinstance(m, torch.nn.Linear) and m.weight.requires_grad)
+            or isinstance(m, CheckpointWrapper)
+            or isinstance(m, Block)
+        ):
+            fully_shard(m, mesh=dp_mesh)
+
+    fully_shard(model, mesh=dp_mesh)
+
+    return model
+
+
 def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
-    fabric.print(f"Saving LoRA weights to {str(file_path)!r}")
-    fabric.save(file_path, {"model": model}, filter={"model": lora_filter})
+    cpu_state_dict = {}
+    sharded_sd = model.state_dict()
+    for param_name, param in sharded_sd.items():
+        if "lora_" not in param_name:
+            continue
+        if param.is_cpu:
+            param = param.to(fabric.device)
+        if hasattr(param, "_local_tensor"):
+            param = param.full_tensor()
+        if fabric.is_global_zero:
+            cpu_state_dict[param_name] = param.cpu()
+    fabric.barrier()
+    if fabric.is_global_zero:
+        torch.save({"model": cpu_state_dict}, file_path)


 def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
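save_lora_checkpoint no longer goes through fabric.save with the lora_filter; it walks the sharded state dict, gathers each distributed LoRA parameter to a full tensor, and writes a plain torch checkpoint containing only the lora_ weights from rank 0. A small sketch of inspecting the resulting file; the path is illustrative:

# Sketch: the file written above is an ordinary torch checkpoint whose "model"
# entry holds only the parameters with "lora_" in their names.
import torch

state = torch.load("out/finetune/lora/final/lit_model.pth.lora", map_location="cpu")
lora_weights = state["model"]
print(f"{len(lora_weights)} LoRA tensors saved")
for name, tensor in list(lora_weights.items())[:3]:
    print(name, tuple(tensor.shape))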
