@@ -19,7 +19,7 @@
 from tqdm import tqdm
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed
 from transformers.trainer_utils import is_main_process
-from transformers.utils import logging
+from transformers.utils import is_flash_attn_2_available, logging
 
 from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
 from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset
@@ -168,7 +168,9 @@ def load_finetuned_model(model_args: ModelArguments, model_name_or_path: str) ->
     torch_dtype = get_torch_dtype(model_args.torch_dtype)
     try:
         model = finetune_model_cls.from_pretrained(
-            model_name_or_path, torch_dtype=torch_dtype, attn_implementation=model_args.attn_implementation
+            model_name_or_path,
+            torch_dtype=torch_dtype,
+            attn_implementation=("flash_attention_2" if is_flash_attn_2_available() else "eager"),
         )
         if torch_dtype == torch.bfloat16:
             return model.bfloat16()
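
For context, a minimal sketch of the fallback pattern this hunk introduces: prefer FlashAttention-2 when the kernel is importable, otherwise load the model with the eager attention implementation. `AutoModel` and the checkpoint path below are placeholders for illustration only; in the PR the call goes through `finetune_model_cls.from_pretrained` as shown above.

```python
import torch
from transformers import AutoModel
from transformers.utils import is_flash_attn_2_available

# Pick the attention backend at load time: FlashAttention-2 kernels if they are
# installed in this environment, otherwise the plain eager implementation.
attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else "eager"

model = AutoModel.from_pretrained(
    "path/to/finetuned-checkpoint",  # placeholder; the script passes model_name_or_path
    torch_dtype=torch.bfloat16,      # FlashAttention-2 expects fp16/bf16 weights
    attn_implementation=attn_implementation,
)
```
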
@@ -336,12 +338,7 @@ def main():
         dataset=processed_dataset["test"],
         batch_size=per_device_eval_batch_size,
         num_workers=training_args.dataloader_num_workers,
-        collate_fn=CehrBertDataCollator(
-            tokenizer=tokenizer,
-            max_length=config.max_position_embeddings,
-            is_pretraining=False,
-            mlm_probability=config.mlm_probability,
-        ),
+        collate_fn=data_collator,
         pin_memory=training_args.dataloader_pin_memory,
     )
     do_predict(test_dataloader, model_args, training_args)
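
The intent of this hunk is to reuse the collator that `main()` already builds (`data_collator`) rather than constructing a second `CehrBertDataCollator` just for the test split. A rough sketch of the resulting shape, reusing the argument values from the removed lines; the names `tokenizer`, `config`, `processed_dataset`, `per_device_eval_batch_size`, and `training_args` refer to objects that already exist in `main()`, and the shared collator in the actual script may be configured differently.

```python
from torch.utils.data import DataLoader

# Build the collator once (arguments mirror the removed inline construction;
# CehrBertDataCollator is already imported elsewhere in this script).
data_collator = CehrBertDataCollator(
    tokenizer=tokenizer,
    max_length=config.max_position_embeddings,
    is_pretraining=False,
    mlm_probability=config.mlm_probability,
)

# ...and reuse it for the test DataLoader instead of instantiating a new one.
test_dataloader = DataLoader(
    dataset=processed_dataset["test"],
    batch_size=per_device_eval_batch_size,
    num_workers=training_args.dataloader_num_workers,
    collate_fn=data_collator,
    pin_memory=training_args.dataloader_pin_memory,
)
```
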