Skip to content

Commit 467e98f

Browse files
committed
fixed the integration test
1 parent 3c9216e commit 467e98f

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

src/cehrbert/runners/hf_cehrbert_finetune_runner.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -27,9 +27,12 @@
2727
)
2828
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping
2929
from cehrbert.data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader
30-
from cehrbert.data_generators.hf_data_generator.sample_packing_sampler import SamplePackingBatchSampler
3130
from cehrbert.models.hf_models.config import CehrBertConfig
32-
from cehrbert.models.hf_models.hf_cehrbert import CehrBertForClassification, CehrBertLstmForClassification
31+
from cehrbert.models.hf_models.hf_cehrbert import (
32+
CehrBertForClassification,
33+
CehrBertLstmForClassification,
34+
CehrBertPreTrainedModel,
35+
)
3336
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
3437
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, FineTuneModelType, ModelArguments
3538
from cehrbert.runners.runner_util import (
@@ -148,7 +151,7 @@ def load_pretrained_tokenizer(
148151
raise ValueError(f"Can not load the pretrained tokenizer from {tokenizer_name_or_path}")
149152

150153

151-
def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) -> PreTrainedModel:
154+
def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) -> CehrBertPreTrainedModel:
152155
if model_args.finetune_model_type == FineTuneModelType.POOLING.value:
153156
finetune_model_cls = CehrBertForClassification
154157
elif model_args.finetune_model_type == FineTuneModelType.LSTM.value:
@@ -204,6 +207,12 @@ def main():
204207
# Organize them into a single DatasetDict
205208
final_splits = prepare_finetune_dataset(data_args, training_args, cache_file_collector)
206209

210+
# TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
211+
if not data_args.streaming:
212+
all_columns = final_splits["train"].column_names
213+
if "visit_concept_ids" in all_columns:
214+
final_splits = final_splits.remove_columns(["visit_concept_ids"])
215+
207216
processed_dataset = create_cehrbert_finetuning_dataset(
208217
dataset=final_splits,
209218
concept_tokenizer=tokenizer,

tests/integration_tests/runners/hf_cehrbert_pretrain_runner_test.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,16 @@ def test_train_model(self):
4949
self.dataset_prepared_path,
5050
"--max_steps",
5151
"10",
52+
"--save_strategy",
53+
"steps",
54+
"--eval_strategy",
55+
"steps",
56+
"--do_train",
57+
"true",
58+
"--do_predict",
59+
"true",
60+
"--load_best_model_at_end",
61+
"true",
5262
"--report_to",
5363
"none",
5464
]

0 commit comments

Comments
 (0)