
Commit b1bedf0

added hyperparameter tuning
1 parent aba565e commit b1bedf0

4 files changed: +338 −43 lines

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,8 @@ dependencies = [
     "Werkzeug==3.0.1",
     "wandb>=0.17.8",
     "xgboost==2.0.3",
-    "cehrbert_data>=0.0.5"
+    "cehrbert_data>=0.0.5",
+    "optuna>=4.0.0",
 ]

 [tool.setuptools_scm]
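
The only dependency change is the new optuna pin, which backs the hyperparameter search wired into the fine-tune runner below. As a reminder of what the library provides (an illustrative snippet, not code from this repository), an Optuna study simply optimizes an objective over sampled parameters:

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    # Each trial samples a candidate configuration; here a toy 1-D quadratic.
    x = trial.suggest_float("x", -10.0, 10.0)
    return (x - 2.0) ** 2


# Run 20 trials and report the best parameters found.
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```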

src/cehrbert/runners/hf_cehrbert_finetune_runner.py

Lines changed: 77 additions & 42 deletions
@@ -1,5 +1,7 @@
+import glob
 import json
 import os
+import shutil
 from datetime import datetime
 from functools import partial
 from pathlib import Path
@@ -35,6 +37,7 @@
 )
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, FineTuneModelType, ModelArguments
+from cehrbert.runners.hyperparameter_search_util import perform_hyperparameter_search
 from cehrbert.runners.runner_util import (
     convert_dataset_to_iterable_dataset,
     generate_prepared_ds_path,
@@ -280,50 +283,53 @@ def main():
     )

     if training_args.do_train:
-        model = load_finetuned_model(model_args, model_args.model_name_or_path)
-        if getattr(model.config, "cls_token_id") is None:
-            model.config.cls_token_id = tokenizer.cls_token_index
-        # If lora is enabled, we add LORA adapters to the model
-        if model_args.use_lora:
-            # When LORA is used, the trainer could not automatically find this label,
-            # therefore we need to manually set label_names to "classifier_label" so the model
-            # can compute the loss during the evaluation
-            if training_args.label_names:
-                training_args.label_names.append("classifier_label")
-            else:
-                training_args.label_names = ["classifier_label"]
-
-            if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
-                config = LoraConfig(
-                    r=model_args.lora_rank,
-                    lora_alpha=model_args.lora_alpha,
-                    target_modules=model_args.target_modules,
-                    lora_dropout=model_args.lora_dropout,
-                    bias="none",
-                    modules_to_save=["classifier", "age_batch_norm", "dense_layer"],
-                )
-                model = get_peft_model(model, config)
-            else:
-                raise ValueError(f"The LORA adapter is not supported for {model_args.finetune_model_type}")
-
-        trainer = trainer_class(
-            model=model,
-            data_collator=data_collator,
-            train_dataset=processed_dataset["train"],
-            eval_dataset=processed_dataset["validation"],
-            args=training_args,
-            callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
-        )
-
-        checkpoint = get_last_hf_checkpoint(training_args)
+        output_dir = training_args.output_dir
+        if cehrbert_args.hyperparameter_tuning:
+            training_args, run_id = perform_hyperparameter_search(
+                trainer_class,
+                partial(model_init, model_args, training_args, tokenizer),
+                processed_dataset,
+                data_collator,
+                training_args,
+                model_args,
+                cehrbert_args,
+            )
+            # We enforce retraining if cehrbert_args.hyperparameter_tuning_percentage < 1.0
+            cehrbert_args.retrain_with_full |= cehrbert_args.hyperparameter_tuning_percentage < 1.0
+            output_dir = os.path.join(training_args.output_dir, f"run-{run_id}")
+
+        if cehrbert_args.hyperparameter_tuning and not cehrbert_args.retrain_with_full:
+            folders = glob.glob(os.path.join(output_dir, "checkpoint-*"))
+            if len(folders) == 0:
+                raise RuntimeError(f"There must be a checkpoint folder under {output_dir}")
+            checkpoint_dir = folders[0]
+            LOG.info("Best trial checkpoint folder: %s", checkpoint_dir)
+            for file_name in os.listdir(checkpoint_dir):
+                try:
+                    full_file_name = os.path.join(checkpoint_dir, file_name)
+                    destination = os.path.join(training_args.output_dir, file_name)
+                    if os.path.isfile(full_file_name):
+                        shutil.copy2(full_file_name, destination)
+                except Exception as e:
+                    LOG.error("Failed to copy %s: %s", file_name, str(e))
+        else:
+            trainer = trainer_class(
+                model=model_init(model_args, training_args, tokenizer),
+                data_collator=data_collator,
+                train_dataset=processed_dataset["train"],
+                eval_dataset=processed_dataset["validation"],
+                args=training_args,
+                callbacks=[EarlyStoppingCallback(early_stopping_patience=model_args.early_stopping_patience)],
+            )

-        train_result = trainer.train(resume_from_checkpoint=checkpoint)
-        trainer.save_model()  # Saves the tokenizer too for easy upload
-        metrics = train_result.metrics
+            checkpoint = get_last_hf_checkpoint(training_args)
+            train_result = trainer.train(resume_from_checkpoint=checkpoint)
+            trainer.save_model()  # Saves the tokenizer too for easy upload
+            metrics = train_result.metrics

-        trainer.log_metrics("train", metrics)
-        trainer.save_metrics("train", metrics)
-        trainer.save_state()
+            trainer.log_metrics("train", metrics)
+            trainer.save_metrics("train", metrics)
+            trainer.save_state()

     if training_args.do_predict:
         test_dataloader = DataLoader(
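
The retraining guard above forces retrain_with_full whenever hyperparameter_tuning_percentage < 1.0, i.e. whenever the trials ran on a reduced split. The subsampling itself happens inside perform_hyperparameter_search in hyperparameter_search_util.py, presumably the fourth changed file, which is not shown on this page. A minimal sketch of one plausible way to carve out that subset with the Hugging Face datasets API (names and details are assumptions, not the commit's actual code):

```python
# Hedged sketch only: one way a hyperparameter_tuning_percentage subset could be produced.
# The real logic lives in hyperparameter_search_util.py, which this page does not show.
def sample_dataset(dataset, percentage: float, seed: int = 42):
    # Keep a random `percentage` slice of the split for the search trials.
    n_keep = max(1, int(len(dataset) * percentage))
    return dataset.shuffle(seed=seed).select(range(n_keep))


sampled_train = sample_dataset(
    processed_dataset["train"], cehrbert_args.hyperparameter_tuning_percentage
)
sampled_validation = sample_dataset(
    processed_dataset["validation"], cehrbert_args.hyperparameter_tuning_percentage
)
```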
@@ -341,6 +347,35 @@ def main():
     do_predict(test_dataloader, model_args, training_args)


+def model_init(model_args, training_args, tokenizer):
+    model = load_finetuned_model(model_args, model_args.model_name_or_path)
+    if getattr(model.config, "cls_token_id") is None:
+        model.config.cls_token_id = tokenizer.cls_token_index
+    # If lora is enabled, we add LORA adapters to the model
+    if model_args.use_lora:
+        # When LORA is used, the trainer could not automatically find this label,
+        # therefore we need to manually set label_names to "classifier_label" so the model
+        # can compute the loss during the evaluation
+        if training_args.label_names:
+            training_args.label_names.append("classifier_label")
+        else:
+            training_args.label_names = ["classifier_label"]
+
+        if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
+            config = LoraConfig(
+                r=model_args.lora_rank,
+                lora_alpha=model_args.lora_alpha,
+                target_modules=model_args.target_modules,
+                lora_dropout=model_args.lora_dropout,
+                bias="none",
+                modules_to_save=["classifier", "age_batch_norm", "dense_layer"],
+            )
+            model = get_peft_model(model, config)
+        else:
+            raise ValueError(f"The LORA adapter is not supported for {model_args.finetune_model_type}")
+    return model
+
+
 def do_predict(test_dataloader: DataLoader, model_args: ModelArguments, training_args: TrainingArguments):
     """
     Performs inference on the test dataset using a fine-tuned model, saves predictions and evaluation metrics.
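
Moving the model construction into model_init is what makes the search possible: transformers.Trainer re-invokes model_init for every trial so each candidate starts from fresh weights. The helper perform_hyperparameter_search is defined in hyperparameter_search_util.py and is not part of this page; the sketch below shows the general shape of an Optuna-backed search built on Trainer.hyperparameter_search, reusing the sampled_train/sampled_validation names from the previous sketch. Treat it as an assumption about the wiring, not the helper's actual implementation:

```python
# Hedged sketch of an Optuna-backed search via transformers.Trainer.hyperparameter_search.
def hp_space(trial):
    # Sample each hyperparameter from the candidate lists added to CehrBertArguments.
    return {
        "learning_rate": trial.suggest_categorical(
            "learning_rate", cehrbert_args.hyperparameter_learning_rates
        ),
        "per_device_train_batch_size": trial.suggest_categorical(
            "per_device_train_batch_size", cehrbert_args.hyperparameter_batch_sizes
        ),
        "num_train_epochs": trial.suggest_categorical(
            "num_train_epochs", cehrbert_args.hyperparameter_num_train_epochs
        ),
        "weight_decay": trial.suggest_categorical(
            "weight_decay", cehrbert_args.hyperparameter_weight_decays
        ),
    }


search_trainer = trainer_class(
    model_init=partial(model_init, model_args, training_args, tokenizer),  # fresh model per trial
    args=training_args,
    data_collator=data_collator,
    train_dataset=sampled_train,
    eval_dataset=sampled_validation,
)
best_run = search_trainer.hyperparameter_search(
    direction="minimize",  # e.g. minimize the evaluation loss
    backend="optuna",
    hp_space=hp_space,
    n_trials=cehrbert_args.n_trials,
)
# best_run.run_id names the winning trial's "run-{run_id}" folder under output_dir, which is
# where the runner looks for "checkpoint-*" when it skips the full retrain, and
# best_run.hyperparameters can be copied back onto training_args before the final fit.
```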

src/cehrbert/runners/hf_runner_argument_dataclass.py

Lines changed: 34 additions & 0 deletions
@@ -352,3 +352,37 @@ class CehrBertArguments:
         default=False,
         metadata={"help": "Whether or not to average tokens per sequence"},
     )
+    retrain_with_full: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={
+            "help": "A flag to indicate whether we want to retrain the model on the full set after early stopping"
+        },
+    )
+    hyperparameter_tuning: Optional[bool] = dataclasses.field(
+        default=False,
+        metadata={"help": "A flag to indicate if we want to do hyperparameter tuning."},
+    )
+    hyperparameter_tuning_percentage: Optional[float] = dataclasses.field(
+        default=0.1,
+        metadata={"help": "The percentage of the train/val set that will be used for hyperparameter tuning."},
+    )
+    n_trials: Optional[int] = dataclasses.field(
+        default=10,
+        metadata={"help": "The number of trials that will be used for hyperparameter tuning."},
+    )
+    hyperparameter_batch_sizes: Optional[List[int]] = dataclasses.field(
+        default_factory=lambda: [4, 8, 16],
+        metadata={"help": "Hyperparameter search batch sizes"},
+    )
+    hyperparameter_num_train_epochs: Optional[List[int]] = dataclasses.field(
+        default_factory=lambda: [10],
+        metadata={"help": "Hyperparameter search num_train_epochs"},
+    )
+    hyperparameter_learning_rates: Optional[List[float]] = dataclasses.field(
+        default_factory=lambda: [1e-5],
+        metadata={"help": "Hyperparameter search learning rates"},
+    )
+    hyperparameter_weight_decays: Optional[List[float]] = dataclasses.field(
+        default_factory=lambda: [1e-2],
+        metadata={"help": "Hyperparameter search weight decays"},
+    )
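
These fields surface as command-line flags when the dataclass is parsed the way HF-style runners usually do it, via HfArgumentParser. A hedged usage sketch, assuming that parser setup and that the remaining CehrBertArguments fields all carry defaults like the ones shown here; the flag values are illustrative, not this commit's defaults:

```python
from transformers import HfArgumentParser

from cehrbert.runners.hf_runner_argument_dataclass import CehrBertArguments

# Parse an illustrative flag list; in the real runner these come from sys.argv.
parser = HfArgumentParser(CehrBertArguments)
(cehrbert_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--hyperparameter_tuning", "true",
        "--hyperparameter_tuning_percentage", "0.1",
        "--n_trials", "20",
        "--hyperparameter_batch_sizes", "4", "8", "16",
        "--hyperparameter_learning_rates", "1e-5", "5e-5",
    ]
)
print(cehrbert_args.n_trials, cehrbert_args.hyperparameter_learning_rates)
```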
