@@ -11,7 +11,7 @@
 import lightning as L
 import torch
 from lightning.fabric.plugins import BitsandbytesPrecision
-from lightning.fabric.strategies import FSDPStrategy
+from lightning.fabric.strategies import ModelParallelStrategy
 from lightning.fabric.utilities import ThroughputMonitor
 from lightning_utilities.core.imports import RequirementCache
 from torch.utils.data import ConcatDataset, DataLoader
@@ -20,7 +20,7 @@
 from litgpt.args import EvalArgs, LogArgs, TrainArgs
 from litgpt.data import Alpaca, DataModule
 from litgpt.generate.base import generate
-from litgpt.lora import GPT, Block, Config, lora_filter, mark_only_lora_as_trainable
+from litgpt.lora import GPT, Block, Config, mark_only_lora_as_trainable
 from litgpt.prompts import save_prompt_style
 from litgpt.scripts.merge_lora import merge_lora
 from litgpt.tokenizer import Tokenizer
@@ -70,6 +70,7 @@ def setup(
         lr_warmup_steps=100,
         epochs=5,
         max_seq_length=None,
+        max_time=None,
     ),
     log: LogArgs = LogArgs(),
     eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100),
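
Note: `max_time` is new in `TrainArgs` and defaults to `None`; `fit()` below compares it against `time.perf_counter()` deltas, so the value is a wall-clock budget in seconds. A minimal sketch of setting it programmatically, assuming the post-change `TrainArgs` (other values illustrative):

from litgpt.args import TrainArgs

# Illustrative: stop finetuning after one wall-clock hour (field added by this change).
train = TrainArgs(
    lr_warmup_steps=100,
    epochs=5,
    max_seq_length=None,
    max_time=3600,  # seconds; fit() stops once time.perf_counter() - total_t0 exceeds this
)
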
@@ -105,6 +106,7 @@ def setup(
         seed: The random seed to use for reproducibility.
         access_token: Optional API token to access models with restrictions.
     """
+
     checkpoint_dir = auto_download_checkpoint(model_name=checkpoint_dir, access_token=access_token)
     pprint(locals())
     data = Alpaca() if data is None else data
@@ -152,12 +154,10 @@ def setup(
                 "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
                 " when using the --quantize flag."
             )
-        strategy = FSDPStrategy(
-            auto_wrap_policy={torch.nn.Linear},
-            activation_checkpointing_policy={Block},
-            state_dict_type="full",
-            limit_all_gathers=True,
-            cpu_offload=False,
+        strategy = ModelParallelStrategy(
+            parallelize_fn=parallelize_fn,
+            data_parallel_size=devices * num_nodes,
+            tensor_parallel_size=1,
         )
     else:
         strategy = "auto"
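
Note: `ModelParallelStrategy` hands Fabric a user-supplied `parallelize_fn` (defined near the bottom of this file) that receives the model and the device mesh built from `data_parallel_size` and `tensor_parallel_size`. A standalone sketch of the wiring, with a placeholder parallelization function and hypothetical sizes:

import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy


def shard_everything(model, device_mesh):
    # Placeholder: shard the whole module over the data-parallel mesh axis (FSDP2-style).
    from torch.distributed._composable.fsdp.fully_shard import fully_shard

    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


strategy = ModelParallelStrategy(
    parallelize_fn=shard_everything,
    data_parallel_size=2,    # devices * num_nodes in the code above
    tensor_parallel_size=1,  # LoRA finetuning here stays purely data-parallel
)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
# fabric.launch(...) applies shard_everything on each rank during model setup.
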
@@ -174,7 +174,9 @@ def setup(
     if torch.cuda.is_available() and devices > 1:
         check_nvlink_connectivity(fabric)
 
-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes)
+    fabric.launch(
+        main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval, optimizer, num_nodes, precision
+    )
 
 
 def main(
@@ -189,6 +191,7 @@ def main(
     eval: EvalArgs,
     optimizer: Union[str, Dict],
     num_nodes: int = 1,
+    precision: Optional[str] = None,
 ) -> None:
     validate_args(train, eval)
 
@@ -229,7 +232,6 @@ def main(
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
 
-    # strict=False because missing keys due to LoRA weights not contained in state dict
     load_checkpoint(fabric, model, checkpoint_path, strict=False)
 
     train_time = time.perf_counter()
@@ -264,12 +266,19 @@ def main(
     save_path = out_dir / "final" / "lit_model.pth.lora"
     save_path.parent.mkdir(parents=True, exist_ok=True)
     save_lora_checkpoint(fabric, model, save_path)
+
+    fabric.barrier()
     if fabric.global_rank == 0:
         # Copy checkpoint files from original checkpoint dir
         copy_config_files(checkpoint_dir, save_path.parent)
         save_hyperparameters(setup, save_path.parent)
         save_prompt_style(data.prompt_style, save_path.parent)
-        merge_lora(checkpoint_dir=save_path.parent)
+        merge_lora(
+            checkpoint_dir=save_path.parent,
+            pretrained_checkpoint_dir=checkpoint_dir,
+            precision=precision,
+        )
+    fabric.barrier()
 
 
 def fit(
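
Note: the two `fabric.barrier()` calls bracket the rank-zero-only post-processing: the first makes sure every rank has finished writing its part of the LoRA checkpoint before rank 0 reads it back, and the second keeps the other ranks from moving on (and tearing down) while rank 0 merges. A minimal sketch of that pattern, assuming an already-launched multi-process Fabric:

import lightning as L


def rank_zero_postprocess(fabric: L.Fabric) -> None:
    fabric.barrier()  # every rank has finished writing its shard of the checkpoint
    if fabric.global_rank == 0:
        # single-process work: copy configs, merge LoRA weights into the base model, ...
        print("rank 0: merging LoRA weights")
    fabric.barrier()  # hold the other ranks until rank 0 is done
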
@@ -316,6 +325,8 @@ def fit(
     total_lengths = 0
     total_t0 = time.perf_counter()
 
+    max_time = train.max_time or float("inf")
+
     token_counts = {
         "raw_tokens": torch.tensor(0, device=fabric.device, dtype=torch.long),
         "raw_tokens_plus_prompt_template": torch.tensor(0, device=fabric.device, dtype=torch.long),
@@ -327,6 +338,12 @@ def fit(
         iter_t0 = time.perf_counter()
         batch = next(train_iterator)
         if train_iterator.epoch >= train.epochs:
+            generate_example(fabric, model, tokenizer, eval, data)
+            fabric.print(f"Number of epochs {train.epochs} reached, stopping training...")
+            break
+        if iter_t0 - total_t0 > max_time:
+            generate_example(fabric, model, tokenizer, eval, data)
+            fabric.print(f"Max time ({max_time / 60.0:.2f} m) reached, stopping training...")
             break
         input_ids, targets = batch["input_ids"], batch["labels"]
 
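
Note: the time check reuses the `iter_t0` timestamp the loop already records, so it adds no extra clock reads per iteration. A self-contained sketch of the same wall-clock budget pattern, with a deliberately tiny limit:

import time

max_time = 0.5 or float("inf")  # mirrors `train.max_time or float("inf")`; 0.5 s only for the demo
total_t0 = time.perf_counter()

for step in range(10_000_000):
    iter_t0 = time.perf_counter()
    if iter_t0 - total_t0 > max_time:
        print(f"Max time ({max_time / 60.0:.2f} m) reached, stopping training...")
        break
    # a real training step would run here
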
@@ -497,9 +514,45 @@ def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]:
     return longest_seq_length, longest_seq_ix
 
 
+def parallelize_fn(model, device_mesh, activation_checkpointing=True):
+    from torch.distributed._composable.fsdp.fully_shard import fully_shard
+    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper, checkpoint_wrapper
+
+    if activation_checkpointing:
+        model.transformer.h = torch.nn.ModuleList(
+            [checkpoint_wrapper(el, preserve_rng_state=False) for el in model.transformer.h]
+        )
+
+    dp_mesh = device_mesh["data_parallel"]
+
+    for m in reversed(list(model.modules())):
+        if (
+            (isinstance(m, torch.nn.Linear) and m.weight.requires_grad)
+            or isinstance(m, CheckpointWrapper)
+            or isinstance(m, Block)
+        ):
+            fully_shard(m, mesh=dp_mesh)
+
+    fully_shard(model, mesh=dp_mesh)
+
+    return model
+
+
 def save_lora_checkpoint(fabric: L.Fabric, model: torch.nn.Module, file_path: Path) -> None:
-    fabric.print(f"Saving LoRA weights to {str(file_path)!r}")
-    fabric.save(file_path, {"model": model}, filter={"model": lora_filter})
+    cpu_state_dict = {}
+    sharded_sd = model.state_dict()
+    for param_name, param in sharded_sd.items():
+        if "lora_" not in param_name:
+            continue
+        if param.is_cpu:
+            param = param.to(fabric.device)
+        if hasattr(param, "_local_tensor"):
+            param = param.full_tensor()
+        if fabric.is_global_zero:
+            cpu_state_dict[param_name] = param.cpu()
+    fabric.barrier()
+    if fabric.is_global_zero:
+        torch.save({"model": cpu_state_dict}, file_path)
 
 
 def validate_args(train: TrainArgs, eval: EvalArgs) -> None:
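
Note: the rewritten `save_lora_checkpoint` detects FSDP2/DTensor parameters via the `_local_tensor` attribute and gathers them with `full_tensor()` before rank 0 saves a plain CPU state dict. A hedged, single-process sketch of that gather, assuming a recent torch (>= 2.5) with the public DTensor API; the one-rank gloo group exists only so the snippet runs as an ordinary script:

import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
full = torch.randn(4, 4)
sharded = distribute_tensor(full, mesh, [Shard(0)])  # a DTensor, like an FSDP2 parameter

assert hasattr(sharded, "_local_tensor")  # the same duck-typed check used above
gathered = sharded.full_tensor()          # collect all shards into a regular tensor
assert torch.equal(gathered, full)

dist.destroy_process_group()
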