2121from colossal_llama .utils .froze import freeze_non_embeds_parameters
2222from colossal_llama .utils .neftune_patch import activate_neftune , deactivate_neftune
2323from colossal_llama .utils .utils import all_reduce_mean , format_numel_str , get_model_numel
24+ from peft import LoraConfig
2425from torch .utils .tensorboard import SummaryWriter
2526from tqdm import tqdm
2627from transformers import AutoModelForCausalLM , AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
6566 initial_scale = 2 ** 16 ,
6667 max_norm = args .grad_clip ,
6768 enable_gradient_accumulation = (args .accumulation_steps > 1 ),
68- enable_fused_normalization = torch . cuda .is_available (),
69+ enable_fused_normalization = get_accelerator () .is_available (),
6970 enable_flash_attention = args .use_flash_attn ,
7071 )
7172 elif args .plugin == "gemini_auto" :
@@ -75,7 +76,7 @@ def train(args) -> None:
7576 initial_scale = 2 ** 16 ,
7677 max_norm = args .grad_clip ,
7778 enable_gradient_accumulation = (args .accumulation_steps > 1 ),
78- enable_fused_normalization = torch . cuda .is_available (),
79+ enable_fused_normalization = get_accelerator () .is_available (),
7980 enable_flash_attention = args .use_flash_attn ,
8081 )
8182 elif args .plugin == "zero2" :
@@ -101,10 +102,9 @@ def train(args) -> None:
101102 sequence_parallelism_mode = args .sp_mode ,
102103 zero_stage = args .zero_stage ,
103104 enable_flash_attention = args .use_flash_attn ,
104- enable_fused_normalization = torch . cuda .is_available (),
105+ enable_fused_normalization = get_accelerator () .is_available (),
105106 enable_sequence_parallelism = args .enable_sequence_parallelism ,
106107 cpu_offload = True if args .zero_stage >= 1 and args .zero_cpu_offload else False ,
107- parallel_output = False ,
108108 max_norm = args .grad_clip ,
109109 precision = args .mixed_precision ,
110110 microbatch_size = args .microbatch_size ,
@@ -117,11 +117,17 @@ def train(args) -> None:
117117 # ======================================================
118118 # Initialize Tokenizer, Dataset, Collator and Dataloader
119119 # ======================================================
120- tokenizer = AutoTokenizer .from_pretrained (args .pretrained )
120+ tokenizer = AutoTokenizer .from_pretrained (args .pretrained , trust_remote_code = True )
121121 if args .pad_token == "eos" :
122- tokenizer .pad_token = tokenizer .eos_token
122+ try :
123+ tokenizer .pad_token = tokenizer .eos_token
124+ except AttributeError :
125+ coordinator .print_on_master (f"pad_token can't be set" )
123126 elif args .pad_token == "unk" :
124- tokenizer .pad_token = tokenizer .unk_token
127+ try :
128+ tokenizer .pad_token = tokenizer .unk_token
129+ except AttributeError :
130+ coordinator .print_on_master (f"pad_token can't be set" )
125131 tokenizer .add_bos_token = False
126132 tokenizer .add_eos_token = False
127133
@@ -164,33 +170,31 @@ def train(args) -> None:
164170 # ======================================================
165171 # Initialize Model, Objective, Optimizer and LR Scheduler
166172 # ======================================================
173+ # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
167174 init_ctx = (
168175 LazyInitContext (default_device = get_current_device ())
169- if isinstance (plugin , (GeminiPlugin , HybridParallelPlugin ))
176+ if isinstance (plugin , (GeminiPlugin , HybridParallelPlugin )) and args . lora_rank == 0
170177 else nullcontext ()
171178 )
172179 with init_ctx :
173- if args .use_flash_attn :
174- model = AutoModelForCausalLM .from_pretrained (
175- args .pretrained ,
176- attn_implementation = "flash_attention_2" ,
177- torch_dtype = torch .bfloat16 if args .mixed_precision == "bf16" else torch .float16 ,
178- trust_remote_code = True ,
179- )
180- else :
181- model = AutoModelForCausalLM .from_pretrained (
182- args .pretrained ,
183- torch_dtype = torch .bfloat16 if args .mixed_precision == "bf16" else torch .float16 ,
184- trust_remote_code = True ,
185- )
180+ model = AutoModelForCausalLM .from_pretrained (
181+ args .pretrained ,
182+ torch_dtype = torch .bfloat16 if args .mixed_precision == "bf16" else torch .float16 ,
183+ trust_remote_code = True ,
184+ )
186185 # Freeze part of parameters.
187186 if args .freeze_non_embeds_params :
188187 freeze_non_embeds_parameters (model = model )
188+
189+ if args .lora_rank > 0 :
190+ lora_config = LoraConfig (task_type = "CAUSAL_LM" , r = args .lora_rank , lora_alpha = 32 , lora_dropout = 0.1 )
191+ model = booster .enable_lora (model , lora_config = lora_config )
192+
189193 # this is essential, otherwise the grad checkpoint will not work.
190194 model .train ()
191195
192196 if args .use_grad_checkpoint :
193- model .gradient_checkpointing_enable ()
197+ model .gradient_checkpointing_enable (gradient_checkpointing_kwargs = { "use_reentrant" : False } )
194198 coordinator .print_on_master (msg = "Gradient checkpointing enabled successfully" )
195199
196200 model_numel = get_model_numel (model )
@@ -327,6 +331,7 @@ def train(args) -> None:
327331 step = step + 1 ,
328332 batch_size = args .batch_size ,
329333 coordinator = coordinator ,
334+ use_lora = (args .lora_rank > 0 ),
330335 )
331336 coordinator .print_on_master (
332337 f"Saved checkpoint at epoch { epoch } step { step + 1 } at folder { args .save_dir } "
@@ -371,44 +376,45 @@ def train(args) -> None:
371376 total_loss .fill_ (0.0 )
372377 pbar .update ()
373378
374- # Save modeling.
375- save_model_condition = (
376- args .save_interval > 0 and (step + 1 ) % (args .save_interval * args .accumulation_steps ) == 0
377- )
379+ # Save modeling.
380+ save_model_condition = (
381+ args .save_interval > 0 and (step + 1 ) % (args .save_interval * args .accumulation_steps ) == 0
382+ )
378383
379- if not args .skip_save_each_epoch :
380- save_model_condition = save_model_condition or (step + 1 ) == len (dataloader )
384+ if not args .skip_save_each_epoch :
385+ save_model_condition = save_model_condition or (step + 1 ) == len (dataloader )
381386
382- if save_model_condition and not args .benchmark :
383- coordinator .print_on_master ("\n Start saving model checkpoint with running states" )
387+ if save_model_condition and not args .benchmark :
388+ coordinator .print_on_master ("\n Start saving model checkpoint with running states" )
384389
385- if args .use_neft :
386- coordinator .print_on_master ("Deactivate NEFTune before saving model." )
387- deactivate_neftune (model , handle )
390+ if args .use_neft :
391+ coordinator .print_on_master ("Deactivate NEFTune before saving model." )
392+ deactivate_neftune (model , handle )
388393
389- accelerator .empty_cache ()
390- save_checkpoint (
391- save_dir = args .save_dir ,
392- booster = booster ,
393- model = model ,
394- optimizer = optimizer ,
395- lr_scheduler = lr_scheduler ,
396- epoch = epoch ,
397- step = step + 1 ,
398- batch_size = args .batch_size ,
399- coordinator = coordinator ,
400- )
401- coordinator .print_on_master (
402- f"Saved checkpoint at epoch { epoch } step { step + 1 } at folder { args .save_dir } "
403- )
394+ accelerator .empty_cache ()
395+ save_checkpoint (
396+ save_dir = args .save_dir ,
397+ booster = booster ,
398+ model = model ,
399+ optimizer = optimizer ,
400+ lr_scheduler = lr_scheduler ,
401+ epoch = epoch ,
402+ step = step + 1 ,
403+ batch_size = args .batch_size ,
404+ coordinator = coordinator ,
405+ use_lora = (args .lora_rank > 0 ),
406+ )
407+ coordinator .print_on_master (
408+ f"Saved checkpoint at epoch { epoch } step { step + 1 } at folder { args .save_dir } "
409+ )
404410
405- if args .use_neft :
406- coordinator .print_on_master ("Activate NEFTune." )
407- model , handle = activate_neftune (model )
411+ if args .use_neft :
412+ coordinator .print_on_master ("Activate NEFTune." )
413+ model , handle = activate_neftune (model )
408414
409- # Delete cache.
410- # del batch, batch_labels, batch_output, loss
411- accelerator .empty_cache ()
415+ # Delete cache.
416+ # del batch, batch_labels, batch_output, loss
417+ accelerator .empty_cache ()
412418
413419 # the continue epochs are not resumed, so we need to reset the sampler start index and start step
414420 dataloader .sampler .set_start_index (start_index = 0 )
@@ -522,6 +528,7 @@ def train(args) -> None:
522528 parser .add_argument (
523529 "--microbatch_size" , type = int , default = 1 , help = "Batch size for each process in PP, used for 3d plugin."
524530 )
531+ parser .add_argument ("--lora_rank" , type = int , default = 0 , help = "lora rank when using lora to train." )
525532
526533 # Additional arguments for benchmark.
527534 parser .add_argument ("--num_samples" , type = int , default = 500 , help = "Number of samples for benchmarking." )
0 commit comments