
Commit 0b364e1

Fixed a bug where sample packing was never used in do_predict
1 parent 1d3b08d commit 0b364e1

1 file changed (+5, -8 lines)


src/cehrbert/runners/hf_cehrbert_finetune_runner.py

Lines changed: 5 additions & 8 deletions
@@ -19,7 +19,7 @@
 from tqdm import tqdm
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
 from transformers.trainer_utils import is_main_process
-from transformers.utils import logging
+from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
 from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset
@@ -168,7 +168,9 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
     torch_dtype = get_torch_dtype(model_args.torch_dtype)
     try:
         model = finetune_model_cls.from_pretrained(
-            model_name_or_path, torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation
+            model_name_or_path,
+            torch_dtype=torch_dtype,
+            attn_implementation=("flash_attention_2" if is_flash_attn_2_available() else "eager"),
         )
         if torch_dtype == torch.bfloat16:
             return model.bfloat16()
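
The hunk above stops trusting `model_args.attn_implementation` and instead picks the attention backend at load time: FlashAttention-2 when the flash-attn package is installed, `eager` otherwise. A minimal sketch of that selection pattern, independent of the CEHR-BERT model class (the from_pretrained call is shown only as a comment, since loading a real checkpoint is out of scope here):

from transformers.utils import is_flash_attn_2_available

# Choose the attention backend at runtime instead of from a config flag:
# use FlashAttention-2 only when the flash-attn package is actually installed.
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"
print(f"attention backend: {attn_implementation}")

# The runner forwards this choice when loading the fine-tuned model, roughly:
# model = finetune_model_cls.from_pretrained(
#     model_name_or_path,
#     torch_dtype=torch_dtype,
#     attn_implementation=attn_implementation,
# )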
@@ -336,12 +338,7 @@ def main():
             dataset=processed_dataset["test"],
             batch_size=per_device_eval_batch_size,
             num_workers=training_args.dataloader_num_workers,
-            collate_fn=CehrBertDataCollator(
-                tokenizer=tokenizer,
-                max_length=config.max_position_embeddings,
-                is_pretraining=False,
-                mlm_probability=config.mlm_probability,
-            ),
+            collate_fn=data_collator,
             pin_memory=training_args.dataloader_pin_memory,
         )
         do_predict(test_dataloader, model_args, training_args)
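
Before this commit, the predict branch built a fresh CehrBertDataCollator for the test DataLoader, so any packing-aware collator configured earlier in main() was silently ignored; passing the shared data_collator means do_predict now gets the same batching (including sample packing, when enabled) as training and evaluation. A hypothetical helper sketching that reuse pattern (the function name and signature are illustrative, not part of the runner):

from torch.utils.data import DataLoader

def build_test_dataloader(dataset, data_collator, training_args, batch_size):
    # Reuse whatever collator main() already configured; when sample packing is
    # enabled that is the packing-aware collator, so prediction sees packed
    # batches identical to those used during training and evaluation.
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=training_args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=training_args.dataloader_pin_memory,
    )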
