@@ -88,8 +88,11 @@ def __init__(self, model_service: "HuggingFaceLlmModel") -> None:
         self._model_service = model_service
         self._model_name = model_service.model_name
         self._model_pack_path = model_service._model_pack_path
-        self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained",
-                                                  self._model_name.replace(" ", "_"))
+        self._retrained_models_dir = os.path.join(
+            model_service._model_parent_dir,
+            "retrained",
+            self._model_name.replace(" ", "_"),
+        )
         self._model_manager = ModelManager(type(model_service), model_service._config)
         self._max_length = model_service.model.config.max_position_embeddings
         os.makedirs(self._retrained_models_dir, exist_ok=True)
@@ -306,7 +309,7 @@ def run(
             logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
             raise ExtraDependencyRequiredException("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")

-        copied_model_pack_path = None
+        trained_model_pack_path = None
         redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true"
         skip_save_model = self._config.SKIP_SAVE_MODEL == "true"
         results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results"))
@@ -319,15 +322,16 @@ def run(

         if not eval_mode:
             try:
-                logger.info("Loading a new model copy for training...")
-                copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id)
-                model, tokenizer = self._model_service.load_model(
-                    copied_model_pack_path,
-                    load_in_4bit=True,  # for memory efficient training
+                logger.info("Loading a PEFT model for training...")
+                model_pack_file_ext = get_model_data_package_extension(self._model_pack_path)
+                trained_model_pack_path = self._model_pack_path.replace(
+                    model_pack_file_ext,
+                    f"_trained_{run_id}{model_pack_file_ext}",
                 )
-                copied_model_directory = os.path.join(
-                    os.path.dirname(copied_model_pack_path),
-                    get_model_data_package_base_name(copied_model_pack_path),
+                model, tokenizer = self._model_service.model, self._model_service.tokenizer
+                trained_model_directory = os.path.join(
+                    os.path.dirname(trained_model_pack_path),
+                    get_model_data_package_base_name(trained_model_pack_path),
                 )

                 if non_default_device_is_available(self._config.DEVICE):
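As a quick illustration of the path rewrite introduced above: the trained pack path is derived by splicing a `_trained_<run_id>` suffix in front of the package extension. The path, run id, and extension below are made-up examples, not values from this repository:

```python
# Hypothetical values; the real path and run id come from the model service and the training run.
model_pack_path = "/models/huggingface_llm_model.zip"
run_id = "abc123"
ext = ".zip"  # stand-in for get_model_data_package_extension(model_pack_path)

trained_model_pack_path = model_pack_path.replace(ext, f"_trained_{run_id}{ext}")
print(trained_model_pack_path)  # /models/huggingface_llm_model_trained_abc123.zip
```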
@@ -355,7 +359,7 @@ def run(
                     ],
                 )

-                model = get_peft_model(model, lora_config)
+                peft_model = get_peft_model(model, lora_config)

                 mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
                 cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
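For readers unfamiliar with `get_peft_model`, a minimal sketch of the LoRA wrapping step is below. The base model, rank, alpha, and target modules are illustrative placeholders, not the values used by this trainer (the actual `LoraConfig` is defined just above this hunk):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Tiny stand-in model for illustration only; the trainer uses the model held by the model service.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")

lora_config = LoraConfig(
    r=16,            # adapter rank (assumed value)
    lora_alpha=32,   # scaling factor (assumed value)
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],  # assumed attention projections
)

peft_model = get_peft_model(model, lora_config)  # base weights stay frozen, only adapters train
peft_model.print_trainable_parameters()
```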
@@ -378,27 +382,26 @@ def run(
                     training_args = GRPOConfig(
                         output_dir=results_path,
                         logging_dir=logs_path,
+                        logging_steps=log_frequency,
                         learning_rate=5e-6,
                         adam_beta1=0.9,
                         adam_beta2=0.99,
                         weight_decay=0.1,
                         warmup_ratio=0.1,
                         lr_scheduler_type="cosine",
                         optim="paged_adamw_8bit",
-                        logging_steps=1,
                         per_device_train_batch_size=6,  # This global batch size must be divisible by the number of generations
                         gradient_accumulation_steps=1,
                         num_generations=6,
                         max_prompt_length=max_prompt_length,
                         max_completion_length=max_seq_length - max_prompt_length,
                         num_train_epochs=training_params["nepochs"],
-                        max_steps=250,
                         save_steps=250,
                         max_grad_norm=0.1,
                         report_to="none",
                     )
                     trainer = GRPOTrainer(
-                        model=model,
+                        model=peft_model,
                         processing_class=tokenizer,
                         reward_funcs=self._get_reward_functions(),
                         args=training_args,
@@ -409,7 +412,7 @@ def run(
                 else:
                     raise ConfigurationException(f"Unsupported trainer type: {trainer_type}")

-                self._tracker_client.log_model_config(model.config.to_dict())
+                self._tracker_client.log_model_config({**model.config.to_dict(), **peft_model.peft_config})
                 self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version)

                 logger.info(f"Performing {trainer_type.upper()} training...")
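For orientation, here is a minimal, self-contained sketch of the same GRPO wiring with a toy prompt dataset and a dummy reward function. The model id, dataset, and hyperparameter values are placeholders rather than the ones configured above:

```python
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

# Toy dataset; GRPOTrainer samples several completions per "prompt" and scores them.
train_dataset = Dataset.from_dict({"prompt": ["What is 2 + 2?", "Name a prime number."] * 8})

def brevity_reward(completions, **kwargs):
    # Dummy reward: shorter completions score higher.
    return [-float(len(c)) for c in completions]

training_args = GRPOConfig(
    output_dir="grpo_sketch",
    per_device_train_batch_size=2,  # effective batch size must be divisible by num_generations
    num_generations=2,
    max_prompt_length=64,
    max_completion_length=64,
    logging_steps=1,
    report_to="none",
)

trainer = GRPOTrainer(
    model="hf-internal-testing/tiny-random-LlamaForCausalLM",  # placeholder model id
    reward_funcs=brevity_reward,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
```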
@@ -422,11 +425,13 @@ def run(
                     model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE)
                     model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}"
                     retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name)
+                    model = peft_model.merge_and_unload()
                     model.save_pretrained(
-                        copied_model_directory,
+                        trained_model_directory,
                         safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
                     )
-                    create_model_data_package(copied_model_directory, retrained_model_pack_path)
+                    tokenizer.save_pretrained(trained_model_directory)
+                    create_model_data_package(trained_model_directory, retrained_model_pack_path)
                     model_uri = self._tracker_client.save_model(
                         retrained_model_pack_path,
                         self._model_name,
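As a rough sketch of the export step: `merge_and_unload()` folds the LoRA adapter weights back into the base model so the saved directory loads as a plain `transformers` checkpoint, after which the tokenizer is saved alongside it. The path below is a placeholder, and `peft_model`/`tokenizer` are assumed to be the trained objects from the surrounding code:

```python
merged_model = peft_model.merge_and_unload()  # fold adapters into the base weights

output_dir = "/tmp/trained_model_dir"  # placeholder directory
merged_model.save_pretrained(output_dir, safe_serialization=True)
tokenizer.save_pretrained(output_dir)

# The result reloads without peft installed:
# from transformers import AutoModelForCausalLM
# reloaded = AutoModelForCausalLM.from_pretrained(output_dir)
```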
@@ -475,7 +480,7 @@ def run(
                 with self._training_lock:
                     self._training_in_progress = False
                 self._clean_up_training_cache()
-                self._housekeep_file(copied_model_pack_path)
+                self._housekeep_file(trained_model_pack_path)
                 if trainer is not None:
                     del trainer
                     gc.collect()
@@ -505,6 +510,7 @@ def run(
                 training_args = GRPOConfig(
                     output_dir=results_path,
                     logging_dir=logs_path,
+                    logging_steps=log_frequency,
                     per_device_eval_batch_size=6,
                     num_generations=2,
                     max_prompt_length=max_prompt_length,
@@ -607,19 +613,19 @@ def correctness_reward_func(
             )
             return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

-        def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             responses = [completion[0]["content"] for completion in completions]
             extracted_responses = [extract_xml_answer(r) for r in responses]
             return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

-        def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             """Reward function that checks if the completion has a specific format."""
             pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
             responses = [completion[0]["content"] for completion in completions]
             matches = [re.match(pattern, r) for r in responses]
             return [0.5 if match else 0.0 for match in matches]

-        def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             """Reward function that checks if the completion has a specific format."""
             pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
             responses = [completion[0]["content"] for completion in completions]
@@ -640,7 +646,7 @@ def count_xml(text: str) -> float:
             count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
             return count

-        def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             contents = [completion[0]["content"] for completion in completions]
             return [count_xml(c) for c in contents]

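For reference, a small usage sketch of the reward-function interface assumed above: each completion is a list of chat messages, `completion[0]["content"]` is the generated reply, and every function returns one float per completion. The sample data is made up, the functions are assumed to be in scope, and the expected scores assume `extract_xml_answer` pulls the text between the `<answer>` tags:

```python
completions = (
    [{"role": "assistant",
      "content": "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n"}],
    [{"role": "assistant", "content": "The answer is four."}],
)

print(int_reward_func(completions))            # expected [0.5, 0.0]: only the first answer is a digit
print(strict_format_reward_func(completions))  # expected [0.5, 0.0]: only the first matches the strict layout
print(xmlcount_reward_func(completions))       # partial credit per well-formed <reasoning>/<answer> tag
```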