from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
+ from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
- enable_fused_normalization=torch.cuda.is_available(),
+ enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
@@ -75,7 +76,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
- enable_fused_normalization=torch.cuda.is_available(),
+ enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
@@ -101,10 +102,9 @@ def train(args) -> None:
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
- enable_fused_normalization=torch.cuda.is_available(),
+ enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
- parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
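
The three hunks above replace the hard-coded torch.cuda.is_available() check with ColossalAI's accelerator abstraction, so fused normalization is enabled on whatever backend is present rather than only on NVIDIA GPUs. A minimal sketch of the pattern, assuming the colossalai.accelerator import that the rest of the script presumably provides:

    from colossalai.accelerator import get_accelerator

    # Resolves to the accelerator backend detected in the current environment
    # (CUDA, NPU, CPU, ...), so the same capability check works off-NVIDIA.
    accelerator = get_accelerator()
    fused_norm_ok = accelerator.is_available()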
@@ -117,11 +117,17 @@ def train(args) -> None:
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
- tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
+ tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
if args.pad_token == "eos":
- tokenizer.pad_token = tokenizer.eos_token
+ try:
+ tokenizer.pad_token = tokenizer.eos_token
+ except AttributeError:
+ coordinator.print_on_master(f"pad_token can't be set")
elif args.pad_token == "unk":
- tokenizer.pad_token = tokenizer.unk_token
+ try:
+ tokenizer.pad_token = tokenizer.unk_token
+ except AttributeError:
+ coordinator.print_on_master(f"pad_token can't be set")
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False
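
Wrapping the pad_token assignment in try/except matters because some tokenizers loaded with trust_remote_code define pad_token as a property without a setter and raise AttributeError on assignment (ChatGLM's tokenizer is the usual example). A small illustrative helper, not part of the script, showing the same guard in isolation:

    def set_pad_token(tokenizer, fallback: str = "eos") -> None:
        """Point pad_token at eos/unk when the tokenizer allows it; warn otherwise."""
        token = tokenizer.eos_token if fallback == "eos" else tokenizer.unk_token
        try:
            tokenizer.pad_token = token
        except AttributeError:
            print("pad_token can't be set for this tokenizer")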
@@ -164,33 +170,31 @@ def train(args) -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
+ # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
- if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+ if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
- if args.use_flash_attn:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrained,
- attn_implementation="flash_attention_2",
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- trust_remote_code=True,
- )
- else:
- model = AutoModelForCausalLM.from_pretrained(
- args.pretrained,
- torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
- trust_remote_code=True,
- )
+ model = AutoModelForCausalLM.from_pretrained(
+ args.pretrained,
+ torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+ trust_remote_code=True,
+ )
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)
+
+ if args.lora_rank > 0:
+ lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
+ model = booster.enable_lora(model, lora_config=lora_config)
+
# this is essential, otherwise the grad checkpoint will not work.
model.train()

if args.use_grad_checkpoint:
- model.gradient_checkpointing_enable()
+ model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

model_numel = get_model_numel(model)
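
Two details in this hunk interact: LoRA is layered on through PEFT's LoraConfig and the booster's enable_lora hook, and gradient checkpointing is switched to the non-reentrant implementation. The non-reentrant variant is the safer choice once most base weights are frozen, since reentrant checkpointing needs at least one input that requires grad and can otherwise silently drop adapter gradients. A hedged sketch of the same wiring outside the script; the model path is a placeholder and the hyperparameters simply mirror the ones hard-coded above:

    import torch
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/base-model",  # placeholder, not a real checkpoint
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    # Rank-8 adapters on PEFT's default target modules for the architecture;
    # train.py instead calls booster.enable_lora(model, lora_config=lora_config)
    # so the plugin wraps the model itself.
    lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
    model = get_peft_model(model, lora_config)

    # Non-reentrant checkpointing plays nicely with frozen base weights.
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})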
@@ -327,6 +331,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
+ use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +376,45 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()

- # Save modeling.
- save_model_condition = (
- args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
- )
+ # Save modeling.
+ save_model_condition = (
+ args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+ )

- if not args.skip_save_each_epoch:
- save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+ if not args.skip_save_each_epoch:
+ save_model_condition = save_model_condition or (step + 1) == len(dataloader)

- if save_model_condition and not args.benchmark:
- coordinator.print_on_master("\nStart saving model checkpoint with running states")
+ if save_model_condition and not args.benchmark:
+ coordinator.print_on_master("\nStart saving model checkpoint with running states")

- if args.use_neft:
- coordinator.print_on_master("Deactivate NEFTune before saving model.")
- deactivate_neftune(model, handle)
+ if args.use_neft:
+ coordinator.print_on_master("Deactivate NEFTune before saving model.")
+ deactivate_neftune(model, handle)

- accelerator.empty_cache()
- save_checkpoint(
- save_dir=args.save_dir,
- booster=booster,
- model=model,
- optimizer=optimizer,
- lr_scheduler=lr_scheduler,
- epoch=epoch,
- step=step + 1,
- batch_size=args.batch_size,
- coordinator=coordinator,
- )
- coordinator.print_on_master(
- f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
- )
+ accelerator.empty_cache()
+ save_checkpoint(
+ save_dir=args.save_dir,
+ booster=booster,
+ model=model,
+ optimizer=optimizer,
+ lr_scheduler=lr_scheduler,
+ epoch=epoch,
+ step=step + 1,
+ batch_size=args.batch_size,
+ coordinator=coordinator,
+ use_lora=(args.lora_rank > 0),
+ )
+ coordinator.print_on_master(
+ f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+ )

- if args.use_neft:
- coordinator.print_on_master("Activate NEFTune.")
- model, handle = activate_neftune(model)
+ if args.use_neft:
+ coordinator.print_on_master("Activate NEFTune.")
+ model, handle = activate_neftune(model)

- # Delete cache.
- # del batch, batch_labels, batch_output, loss
- accelerator.empty_cache()
+ # Delete cache.
+ # del batch, batch_labels, batch_output, loss
+ accelerator.empty_cache()

# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +528,7 @@ def train(args) -> None:
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
+ parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")