+import glob
 import json
 import os
+import shutil
 from datetime import datetime
 from functools import partial
 from pathlib import Path
 )
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, FineTuneModelType, ModelArguments
+from cehrbert.runners.hyperparameter_search_util import perform_hyperparameter_search
 from cehrbert.runners.runner_util import (
     convert_dataset_to_iterable_dataset,
     generate_prepared_ds_path,
@@ -280,50 +283,53 @@ def main():
     )

     if training_args.do_train:
-        model = load_finetuned_model(model_args, model_args.model_name_or_path)
-        if getattr(model.config, "cls_token_id") is None:
-            model.config.cls_token_id = tokenizer.cls_token_index
-        # If lora is enabled, we add LORA adapters to the model
-        if model_args.use_lora:
-            # When LORA is used, the trainer could not automatically find this label,
-            # therefore we need to manually set label_names to "classifier_label" so the model
-            # can compute the loss during the evaluation
-            if training_args.label_names:
-                training_args.label_names.append("classifier_label")
-            else:
-                training_args.label_names = ["classifier_label"]
-
-            if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
-                config = LoraConfig(
-                    r=model_args.lora_rank,
-                    lora_alpha=model_args.lora_alpha,
-                    target_modules=model_args.target_modules,
-                    lora_dropout=model_args.lora_dropout,
-                    bias="none",
-                    modules_to_save=["classifier", "age_batch_norm", "dense_layer"],
-                )
-                model = get_peft_model(model, config)
-            else:
-                raise ValueError(f"The LORA adapter is not supported for {model_args.finetune_model_type}")
-
-        trainer = trainer_class(
-            model=model,
-            data_collator=data_collator,
-            train_dataset=processed_dataset["train"],
-            eval_dataset=processed_dataset["validation"],
-            args=training_args,
-            callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
-        )
-
-        checkpoint = get_last_hf_checkpoint(training_args)
+        output_dir = training_args.output_dir
+        if cehrbert_args.hyperparameter_tuning:
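+            # Run the hyperparameter search; it returns updated training_args and the run_id of the best trial.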
+            training_args, run_id = perform_hyperparameter_search(
+                trainer_class,
+                partial(model_init, model_args, training_args, tokenizer),
+                processed_dataset,
+                data_collator,
+                training_args,
+                model_args,
+                cehrbert_args,
+            )
+            # Force retraining on the full training set if only a subset was used for tuning (hyperparameter_tuning_percentage < 1.0)
+            cehrbert_args.retrain_with_full |= cehrbert_args.hyperparameter_tuning_percentage < 1.0
+            output_dir = os.path.join(training_args.output_dir, f"run-{run_id}")
+
+        if cehrbert_args.hyperparameter_tuning and not cehrbert_args.retrain_with_full:
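+            # Skip retraining and reuse the checkpoint the best trial produced during the search.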
+            folders = glob.glob(os.path.join(output_dir, "checkpoint-*"))
+            if len(folders) == 0:
+                raise RuntimeError(f"There must be a checkpoint folder under {output_dir}")
+            checkpoint_dir = folders[0]
+            LOG.info("Best trial checkpoint folder: %s", checkpoint_dir)
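+            # Copy the checkpoint files up into the top-level output_dir.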
+            for file_name in os.listdir(checkpoint_dir):
+                try:
+                    full_file_name = os.path.join(checkpoint_dir, file_name)
+                    destination = os.path.join(training_args.output_dir, file_name)
+                    if os.path.isfile(full_file_name):
+                        shutil.copy2(full_file_name, destination)
+                except Exception as e:
+                    LOG.error("Failed to copy %s: %s", file_name, str(e))
+        else:
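+            # Train (or retrain on the full training set) with the selected hyperparameters.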
+            trainer = trainer_class(
+                model=model_init(model_args, training_args, tokenizer),
+                data_collator=data_collator,
+                train_dataset=processed_dataset["train"],
+                eval_dataset=processed_dataset["validation"],
+                args=training_args,
+                callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
+            )

-        train_result = trainer.train(resume_from_checkpoint=checkpoint)
-        trainer.save_model()  # Saves the tokenizer too for easy upload
-        metrics = train_result.metrics
+            checkpoint = get_last_hf_checkpoint(training_args)
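+            # Resume from the most recent checkpoint in output_dir, if any.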
+            train_result = trainer.train(resume_from_checkpoint=checkpoint)
+            trainer.save_model()  # Saves the tokenizer too for easy upload
+            metrics = train_result.metrics

-        trainer.log_metrics("train", metrics)
-        trainer.save_metrics("train", metrics)
-        trainer.save_state()
+            trainer.log_metrics("train", metrics)
+            trainer.save_metrics("train", metrics)
+            trainer.save_state()

     if training_args.do_predict:
         test_dataloader = DataLoader(
@@ -341,6 +347,35 @@ def main():
         do_predict(test_dataloader, model_args, training_args)


+def model_init(model_args, training_args, tokenizer):
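+    # Builds the model used for fine-tuning (optionally wrapping it with LoRA adapters);
+    # also serves as the model_init callable passed to perform_hyperparameter_search.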
+    model = load_finetuned_model(model_args, model_args.model_name_or_path)
+    if getattr(model.config, "cls_token_id") is None:
+        model.config.cls_token_id = tokenizer.cls_token_index
+    # If LoRA is enabled, wrap the model with LoRA adapters
+    if model_args.use_lora:
+        # When LoRA is used, the trainer cannot automatically find this label, so we
+        # manually set label_names to "classifier_label" so the model can compute the
+        # loss during evaluation
+        if training_args.label_names:
+            training_args.label_names.append("classifier_label")
+        else:
+            training_args.label_names = ["classifier_label"]
+
+        if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
+            config = LoraConfig(
+                r=model_args.lora_rank,
+                lora_alpha=model_args.lora_alpha,
+                target_modules=model_args.target_modules,
+                lora_dropout=model_args.lora_dropout,
+                bias="none",
+                modules_to_save=["classifier", "age_batch_norm", "dense_layer"],
+            )
+            model = get_peft_model(model, config)
+        else:
+            raise ValueError(f"The LORA adapter is not supported for {model_args.finetune_model_type}")
+    return model
+
+
 def do_predict(test_dataloader: DataLoader, model_args: ModelArguments, training_args: TrainingArguments):
     """
     Performs inference on the test dataset using a fine-tuned model, saves predictions and evaluation metrics.