Commit 27c0e4f
authored
Cleanup model arguments (#102)
* added attn_implementation to the model arguments
* added a check on the concept_value
* set None unit to a default value N/A
* set None value in concept_values to 0.0
* set _supports_sdpa = True in BertPreTrainedModel
* implemented flash attn
* do not overwrite the attention mask when flash attention is enabled
* upgraded huggingface transformers
* updated the logic for splitting heads
* make sure we load the model using the specified torch_dtype
* set the entire model to the corresponding dtype
* removed keyward arguments from hf_cehrgpt
* updated BertSelfFlashAttention.forward to return a tuple because the BERT layer expects such output
* test gpt2 implementation
* test gpt2 implementation
* pass the attn_implementation and torch_dtype to the model during fine-tuning
* set the default value of torch_dtype to auto
* convert age_at_index to the same data type as the bert output
* added logic to convert float32 to the corresponding precision
* removed mlm_skip_values
* updated the unit test after removing mlm_skip_values
* set the default value of torch_dtype to None
* convert concept_value_masks to torch.bool before using it in torch.where
* convert tensors back to the original dtype in the flash attention implementation
* check if torch_dtype is null before trying to get it from torch1 parent d13338b commit 27c0e4f
File tree
12 files changed
+324
-31
lines changed- src/cehrbert
- data_generators/hf_data_generator
- models/hf_models
- runners
- tests
- integration_tests/runners
- unit_tests/models/hf_models
12 files changed
+324
-31
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
52 | 52 | | |
53 | 53 | | |
54 | 54 | | |
55 | | - | |
| 55 | + | |
56 | 56 | | |
57 | 57 | | |
58 | 58 | | |
| |||
Lines changed: 0 additions & 10 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
111 | 111 | | |
112 | 112 | | |
113 | 113 | | |
114 | | - | |
115 | | - | |
116 | | - | |
117 | | - | |
118 | | - | |
119 | | - | |
120 | | - | |
121 | | - | |
122 | 114 | | |
123 | 115 | | |
124 | | - | |
125 | | - | |
126 | 116 | | |
127 | 117 | | |
128 | 118 | | |
| |||
Lines changed: 25 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
20 | | - | |
| 20 | + | |
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
| |||
573 | 573 | | |
574 | 574 | | |
575 | 575 | | |
| 576 | + | |
| 577 | + | |
| 578 | + | |
| 579 | + | |
| 580 | + | |
| 581 | + | |
| 582 | + | |
| 583 | + | |
576 | 584 | | |
577 | 585 | | |
578 | 586 | | |
579 | 587 | | |
580 | 588 | | |
| 589 | + | |
| 590 | + | |
| 591 | + | |
| 592 | + | |
| 593 | + | |
| 594 | + | |
| 595 | + | |
581 | 596 | | |
582 | 597 | | |
583 | 598 | | |
584 | 599 | | |
585 | | - | |
586 | | - | |
| 600 | + | |
| 601 | + | |
| 602 | + | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
587 | 609 | | |
588 | 610 | | |
589 | 611 | | |
| |||
0 commit comments