|
27 | 27 | ) |
28 | 28 | from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping |
29 | 29 | from cehrbert.data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader |
30 | | -from cehrbert.data_generators.hf_data_generator.sample_packing_sampler import SamplePackingBatchSampler |
31 | 30 | from cehrbert.models.hf_models.config import CehrBertConfig |
32 | | -from cehrbert.models.hf_models.hf_cehrbert import CehrBertForClassification, CehrBertLstmForClassification |
| 31 | +from cehrbert.models.hf_models.hf_cehrbert import ( |
| 32 | + CehrBertForClassification, |
| 33 | + CehrBertLstmForClassification, |
| 34 | + CehrBertPreTrainedModel, |
| 35 | +) |
33 | 36 | from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer |
34 | 37 | from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, FineTuneModelType, ModelArguments |
35 | 38 | from cehrbert.runners.runner_util import ( |
@@ -148,7 +151,7 @@ def load_pretrained_tokenizer( |
148 | 151 | raise ValueError(f"Can not load the pretrained tokenizer from {tokenizer_name_or_path}") |
149 | 152 |
|
150 | 153 |
|
151 | | -def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) -> PreTrainedModel: |
| 154 | +def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) -> CehrBertPreTrainedModel: |
152 | 155 | if model_args.finetune_model_type == FineTuneModelType.POOLING.value: |
153 | 156 | finetune_model_cls = CehrBertForClassification |
154 | 157 | elif model_args.finetune_model_type == FineTuneModelType.LSTM.value: |
@@ -204,6 +207,12 @@ def main(): |
204 | 207 | # Organize them into a single DatasetDict |
205 | 208 | final_splits = prepare_finetune_dataset(data_args, training_args, cache_file_collector) |
206 | 209 |
|
| 210 | + # TODO: temporary workaround — this column has mixed types and causes an issue when transforming the data
| 211 | + if not data_args.streaming: |
| 212 | + all_columns = final_splits["train"].column_names |
| 213 | + if "visit_concept_ids" in all_columns: |
| 214 | + final_splits = final_splits.remove_columns(["visit_concept_ids"]) |
| 215 | + |
207 | 216 | processed_dataset = create_cehrbert_finetuning_dataset( |
208 | 217 | dataset=final_splits, |
209 | 218 | concept_tokenizer=tokenizer, |
|
0 commit comments