1 file changed: +1 −4 lines changed

@@ -14,7 +14,6 @@
 from transformers import TrainingArguments
 from transformers.utils import is_flash_attn_2_available, logging

-from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
 from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset
 from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import (
     CehrBertDataCollator,
@@ -29,7 +28,6 @@
 from cehrbert.runners.runner_util import (
     convert_dataset_to_iterable_dataset,
     generate_prepared_ds_path,
-    get_last_hf_checkpoint,
     get_meds_extension_path,
     load_parquet_as_dataset,
     parse_runner_args,
@@ -108,7 +106,6 @@ def main():
         CehrBert.from_pretrained(
             model_args.model_name_or_path,
             attn_implementation=("flash_attention_2" if is_flash_attn_2_available() else "eager"),
-            torch_dtype=(torch.bfloat16 if is_flash_attn_2_available() else torch.float32),
         )
         .eval()
         .to(device)
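
Note on the removed torch_dtype argument: without it, transformers loads the weights in its default dtype (float32) rather than bfloat16. A minimal sketch of that behavior, using the public bert-base-uncased checkpoint purely as a stand-in for the CehrBert weights:

import torch
from transformers import AutoModel

# Without torch_dtype, transformers loads weights in its default dtype
# (float32), regardless of how the checkpoint was saved.
model_default = AutoModel.from_pretrained("bert-base-uncased")
print(next(model_default.parameters()).dtype)   # torch.float32

# Passing torch_dtype casts the weights at load time, as the removed line did
# when flash attention was available.
model_bf16 = AutoModel.from_pretrained("bert-base-uncased", torch_dtype=torch.bfloat16)
print(next(model_bf16.parameters()).dtype)      # torch.bfloat16
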
@@ -290,7 +287,7 @@ def main():
         cls_token_indices = batch["input_ids"] == cehrgpt_tokenizer.cls_token_index
         if cehrbert_args.sample_packing:
             features = (
-                cehrgpt_output.last_hidden_state[..., cls_token_indices, :]
+                cehrgpt_output.last_hidden_state[cls_token_indices]
                 .cpu()
                 .float()
                 .detach()
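
The second hunk switches the sample-packing branch to plain boolean-mask indexing on the hidden states. A minimal sketch with toy shapes (not the project's actual batch), assuming last_hidden_state is (batch, seq, hidden) and cls_token_indices is a (batch, seq) boolean mask:

import torch

batch, seq, hidden = 2, 5, 4
last_hidden_state = torch.randn(batch, seq, hidden)

# Boolean mask marking the [CLS] positions; here we pretend the first token
# of every sequence is [CLS].
cls_token_indices = torch.zeros(batch, seq, dtype=torch.bool)
cls_token_indices[:, 0] = True

# Boolean-mask indexing gathers the hidden vectors at the masked positions
# into a (num_selected, hidden) matrix, one row per [CLS] token.
features = last_hidden_state[cls_token_indices]
print(features.shape)   # torch.Size([2, 4])
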