Skip to content

Commit ea7068b

Browse files
committed
removed torch_dtype when loading the model
1 parent 5ad9876 commit ea7068b

File tree

1 file changed: 1 addition and 4 deletions.

src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
from transformers import TrainingArguments
1515
from transformers.utils import is_flash_attn_2_available, logging
1616

17-
from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
1817
from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset
1918
from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import (
2019
CehrBertDataCollator,
@@ -29,7 +28,6 @@
2928
from cehrbert.runners.runner_util import (
3029
convert_dataset_to_iterable_dataset,
3130
generate_prepared_ds_path,
32-
get_last_hf_checkpoint,
3331
get_meds_extension_path,
3432
load_parquet_as_dataset,
3533
parse_runner_args,
@@ -108,7 +106,6 @@ def main():
108106
CehrBert.from_pretrained(
109107
model_args.model_name_or_path,
110108
attn_implementation=("flash_attention_2" if is_flash_attn_2_available() else "eager"),
111-
torch_dtype=(torch.bfloat16 if is_flash_attn_2_available() else torch.float32),
112109
)
113110
.eval()
114111
.to(device)
@@ -290,7 +287,7 @@ def main():
290287
cls_token_indices = batch["input_ids"] == cehrgpt_tokenizer.cls_token_index
291288
if cehrbert_args.sample_packing:
292289
features = (
293-
cehrgpt_output.last_hidden_state[..., cls_token_indices, :]
290+
cehrgpt_output.last_hidden_state[cls_token_indices]
294291
.cpu()
295292
.float()
296293
.detach()

0 commit comments

Comments
 (0)