Skip to content

Commit 500dfba

Browse files
committed
set None unit to a default value N/A
1 parent 422481c commit 500dfba

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from meds.schema import birth_code, death_code
1818
from pandas import Series
1919

20-
from cehrbert.med_extension.schema_extension import Event, Visit
20+
from cehrbert.med_extension.schema_extension import Event
2121
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
2222
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
2323

@@ -573,11 +573,23 @@ def __init__(self, concept_tokenizer: CehrBertTokenizer, is_pretraining: bool):
573573
self._is_pretraining = is_pretraining
574574
self._lab_token_ids = self._concept_tokenizer.lab_token_ids
575575

576+
@staticmethod
def fill_na_value(values, value_to_fill):
    """Replace every ``None`` entry in *values* with *value_to_fill*.

    The input is never mutated: when at least one entry is ``None`` a new
    container is returned, otherwise the original object is returned as-is.

    Args:
        values: A sequence of entries that may contain ``None``. Either a
            plain ``list`` or an array-like object that supports ``.copy()``
            and boolean-mask assignment (e.g. a NumPy array).
        value_to_fill: The replacement used for each ``None`` entry.

    Returns:
        A container of the same kind as *values* with ``None`` replaced.
    """
    # Plain lists cannot be indexed with a boolean mask, so handle them
    # separately instead of failing with a TypeError.
    if isinstance(values, list):
        if any(x is None for x in values):
            return [value_to_fill if x is None else x for x in values]
        return values

    # Explicit bool dtype keeps the mask well-typed even for empty input.
    none_values = np.array([x is None for x in values], dtype=bool)
    if none_values.any():
        # Copy first so the caller's array is left untouched.
        values = values.copy()
        values[none_values] = value_to_fill
    return values
583+
576584
def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
577585

578586
input_ids = self._concept_tokenizer.encode(record["concept_ids"])
579587
record["input_ids"] = input_ids
580588
concept_value_masks = record["concept_value_masks"]
589+
590+
record["units"] = self.fill_na_value(record["units"], NA)
591+
record["concept_as_values"] = self.fill_na_value(record["concept_as_values"], NA)
592+
581593
# Backward compatibility
582594
if "concept_values" not in record:
583595
record["concept_values"] = record["number_as_values"]

0 commit comments

Comments
 (0)