Commit 27c0e4f

Cleanup model arguments (#102)
* added attn_implementation to the model arguments
* added a check on the concept_value
* set None unit to a default value N/A
* set None value in concept_values to 0.0
* set _supports_sdpa = True in BertPreTrainedModel
* implemented flash attn
* do not overwrite the attention mask when flash attention is enabled
* upgraded huggingface transformers
* updated the logic for splitting heads
* make sure we load the model using the specified torch_dtype
* set the entire model to the corresponding dtype
* removed keyword arguments from hf_cehrgpt
* updated BertSelfFlashAttention.forward to return a tuple because the BERT layer expects such output
* test gpt2 implementation
* test gpt2 implementation
* pass the attn_implementation and torch_dtype to the model during fine-tuning
* set the default value of torch_dtype to auto
* convert age_at_index to the same data type as the bert output
* added logic to convert float32 to the corresponding precision
* removed mlm_skip_values
* updated the unit test after removing mlm_skip_values
* set the default value of torch_dtype to None
* convert concept_value_masks to torch.bool before using it in torch.where
* convert tensors back to the original dtype in the flash attention implementation
* check if torch_dtype is None before trying to get it from torch
1 parent d13338b commit 27c0e4f
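
A recurring thread in these changes is plumbing attn_implementation and torch_dtype from the run arguments through to the Hugging Face model. A minimal sketch of that pattern, assuming a generic AutoModel checkpoint (the checkpoint name and dtype value below are illustrative, not this repo's actual configuration):

    import torch
    from transformers import AutoModel

    # Resolve the dtype only when it is explicitly set; the string "auto"
    # tells transformers to use the dtype recorded in the checkpoint.
    torch_dtype = "bfloat16"  # illustrative value from a config/CLI argument
    if torch_dtype is not None and torch_dtype != "auto":
        torch_dtype = getattr(torch, torch_dtype)

    model = AutoModel.from_pretrained(
        "bert-base-uncased",                      # placeholder checkpoint
        torch_dtype=torch_dtype,                  # load weights in this precision
        attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
    )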

File tree

12 files changed: +324 −31 lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ dependencies = [
     "tqdm>=4.66.1",
     "torch==2.4.0",
     "tokenizers>=0.19.0",
-    "transformers>=4.40.0",
+    "transformers>=4.41.0",
     "accelerate>=0.31.0",
     "Werkzeug==3.0.1",
     "wandb>=0.17.8",

src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py

Lines changed: 0 additions & 10 deletions

@@ -111,18 +111,8 @@ def __call__(self, examples):

         # This is the most crucial logic for generating the training labels
         if self.is_pretraining:
-
-            batch_mlm_skip_values = [
-                self._convert_to_tensor(example["mlm_skip_values"]).to(torch.bool) for example in examples
-            ]
-            batch["mlm_skip_values"] = pad_sequence(batch_mlm_skip_values, batch_first=True, padding_value=False)
-            # Set the mlm_skip_values of the CLS token to a default value False
-            batch["mlm_skip_values"] = torch.cat([torch.full((batch_size, 1), False), batch["mlm_skip_values"]], dim=1)
-
             # If the labels field is already provided, we will build the MLM labels off of that.
             # The labels value indicates the positions that are not allowed for MLM.
-            # For example, the mlm_skip_values=1, this means this is a lab value and
-            # we don't want to predict the tokens at this position
             if "labels" in examples[0]:
                 batch_labels = [self._convert_to_tensor(example["labels"]) for example in examples]
                 batch["labels"] = pad_sequence(batch_labels, batch_first=True, padding_value=-100)

src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py

Lines changed: 25 additions & 3 deletions

@@ -17,7 +17,7 @@
 from meds.schema import birth_code, death_code
 from pandas import Series

-from cehrbert.med_extension.schema_extension import Event, Visit
+from cehrbert.med_extension.schema_extension import Event
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments

@@ -573,17 +573,39 @@ def __init__(self, concept_tokenizer: CehrBertTokenizer, is_pretraining: bool):
         self._is_pretraining = is_pretraining
         self._lab_token_ids = self._concept_tokenizer.lab_token_ids

+    @staticmethod
+    def fill_na_value(values, value_to_fill):
+        none_values = np.array([x is None for x in values])
+        if none_values.any():
+            values = values.copy()
+            values[none_values] = value_to_fill
+        return values
+
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:

         input_ids = self._concept_tokenizer.encode(record["concept_ids"])
         record["input_ids"] = input_ids
         concept_value_masks = record["concept_value_masks"]
+
+        # These fields may not exist in the old version of the datasets
+        if "units" in record:
+            record["units"] = self.fill_na_value(record["units"], NA)
+        if "concept_as_values" in record:
+            record["concept_as_values"] = self.fill_na_value(record["concept_as_values"], NA)
+
         # Backward compatibility
         if "concept_values" not in record:
             record["concept_values"] = record["number_as_values"]

-        if np.isnan(record["concept_values"]).any():
-            record["concept_values"] = [v if not pd.isna(v) else 0.0 for v in record["concept_values"]]
+        concept_value_is_nan = np.isnan(record["concept_values"])
+        if concept_value_is_nan.any():
+            # Create a writeable copy
+            concept_value_masks = concept_value_masks.copy()
+            concept_value_masks[concept_value_is_nan] = 0
+            record["concept_value_masks"] = concept_value_masks
+            concept_values = record["concept_values"].copy()
+            concept_values[concept_value_is_nan] = 0.0
+            record["concept_values"] = concept_values

         assert len(input_ids) == len(
             record["concept_ids"]
