
Commit 01571d6

added logic to slice out records from the tokenized dataset
1 parent e84b131 commit 01571d6

File tree: 7 files changed (+339, -32 lines)

src/cehrbert/cehrbert_utils.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+import re
+from datetime import datetime, timedelta, timezone
+from typing import List, Optional, Union
+
+from transformers.utils import logging
+
+# Regular expression pattern to match inpatient attendance tokens
+MEDS_CODE_PATTERN = re.compile(r".*/.*")
+INPATIENT_ATT_PATTERN = re.compile(r"(?:VS-|i-)D(\d+)(?:-VE)?")
+DEMOGRAPHIC_PROMPT_SIZE = 4
+logger = logging.get_logger("transformers")
+
+
+def construct_time_sequence(
+    concept_ids: List[str], epoch_times: Optional[List[Union[int, float]]] = None
+) -> List[float]:
+    if epoch_times is not None:
+        return epoch_times
+
+    if concept_ids[0].lower().startswith("year"):
+        year_str = concept_ids[0].split(":")[1]
+    else:
+        year_str = "1985"
+
+    datetime_cursor = datetime(int(year_str), month=1, day=1, hour=0, minute=0, second=0).replace(tzinfo=timezone.utc)
+    epoch_times = []
+    for concept_id in concept_ids:
+        if is_att_token(concept_id):
+            att_days = extract_time_interval_in_days(concept_id)
+            datetime_cursor += timedelta(days=att_days)
+        epoch_times.append(datetime_cursor.timestamp())
+    return epoch_times
+
+
+def is_att_token(token: str):
+    """
+    Check if the token is an attention token.
+
+    :param token: Token to check.
+    :return: True if the token is an attention token, False otherwise.
+    """
+    if bool(re.match(r"^D\d+", token)):  # day tokens
+        return True
+    elif bool(re.match(r"^W\d+", token)):  # week tokens
+        return True
+    elif bool(re.match(r"^M\d+", token)):  # month tokens
+        return True
+    elif bool(re.match(r"^Y\d+", token)):  # year tokens
+        return True
+    elif token == "LT":
+        return True
+    elif token[:3] == "VS-":  # VS-D7-VE
+        return True
+    elif token[:2] == "i-" and not token.startswith("i-H"):  # i-D7 and exclude hour tokens
+        return True
+    return False
+
+
+def extract_time_interval_in_days(token: str):
+    """
+    Extract the time interval in days from a token.
+
+    :param token: Token to extract from.
+    :return: Time interval in days.
+    :raises ValueError: If the token is invalid.
+    """
+    try:
+        if token[0] == "D":  # day tokens
+            return int(token[1:])
+        elif token[0] == "W":  # week tokens
+            return int(token[1:]) * 7
+        elif token[0] == "M":  # month tokens
+            return int(token[1:]) * 30
+        elif token[0] == "Y":  # year tokens
+            return int(token[1:]) * 365
+        elif token == "LT":
+            return 365 * 3
+        elif token[:3] == "VS-":  # VS-D7-VE
+            part = token.split("-")[1]
+            if part.startswith("LT"):
+                return 365 * 3
+            return int(part[1:])
+        elif token[:2] == "i-":  # i-D7
+            part = token.split("-")[1]
+            if part.startswith("LT"):
+                return 365 * 3
+            return int(token.split("-")[1][1:])
+    except Exception:
+        raise ValueError(f"Invalid time token: {token}")
+    raise ValueError(f"Invalid time token: {token}")

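For orientation, here is a small hypothetical sketch (not part of the commit) of how these helpers behave; the concept IDs are made-up OMOP-style tokens:

# Hypothetical usage: rebuilding epoch times for a sequence that was
# tokenized before the epoch_times column existed.
from cehrbert.cehrbert_utils import (
    construct_time_sequence,
    extract_time_interval_in_days,
    is_att_token,
)

concept_ids = ["year:2010", "age:45", "8507", "0", "VS", "320128", "VE", "D7", "VS", "201826", "VE"]
epoch_times = construct_time_sequence(concept_ids)  # cursor anchored at 2010-01-01 UTC

assert is_att_token("D7") and not is_att_token("320128")
assert extract_time_interval_in_days("D7") == 7
assert extract_time_interval_in_days("W2") == 14           # weeks -> days
assert extract_time_interval_in_days("VS-LT-VE") == 365 * 3  # long-term token
# Every token before "D7" is stamped 2010-01-01; "D7" and everything after it
# are stamped seven days later, since the cursor advances before appending.
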
src/cehrbert/data_generators/hf_data_generator/hf_dataset.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
     "num_of_visits",
     "number_as_values",
     "concept_as_values",
+    "epoch_times",
 ]
 
 TRANSFORMER_COLUMNS = ["input_ids", "labels"]

src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py

Lines changed: 103 additions & 1 deletion
@@ -4,6 +4,7 @@
 import itertools
 import re
 from abc import ABC, abstractmethod
+from collections import defaultdict
 from dataclasses import dataclass
 from enum import Enum
 from typing import Any, Dict, Generator, List, Optional, Union
@@ -17,6 +18,7 @@
 from meds.schema import birth_code, death_code
 from pandas import Series
 
+from cehrbert.cehrbert_utils import construct_time_sequence
 from cehrbert.med_extension.schema_extension import Event
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
 from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
@@ -284,6 +286,7 @@ def remove_columns(self):
     def _update_cehrbert_record(
         cehrbert_record: Dict[str, Any],
         code: str,
+        time: datetime.datetime,
         visit_segment: int = 0,
         date: int = 0,
         age: int = -1,
@@ -304,6 +307,7 @@ def _update_cehrbert_record(
         cehrbert_record["concept_values"].append(concept_value)
         cehrbert_record["units"].append(unit)
         cehrbert_record["mlm_skip_values"].append(mlm_skip_value)
+        cehrbert_record["epoch_times"].append(time.replace(tzinfo=datetime.timezone.utc).timestamp())
 
     def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
@@ -320,6 +324,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
             "units": [],
             "mlm_skip_values": [],
             "visit_concept_ids": [],
+            "epoch_times": [],
         }
         # Extract the demographic information
         birth_datetime = record["birth_datetime"]
@@ -340,7 +345,10 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         year_str = f"year:{str(first_visit_start_datetime.year)}"
         age_str = f"age:{str(relativedelta(first_visit_start_datetime, birth_datetime).years)}"
 
-        self._update_cehrbert_record(cehrbert_record, year_str)
+        self._update_cehrbert_record(
+            cehrbert_record,
+            year_str,
+        )
         self._update_cehrbert_record(cehrbert_record, age_str)
         self._update_cehrbert_record(cehrbert_record, gender)
         self._update_cehrbert_record(cehrbert_record, race)
@@ -377,6 +385,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                 cehrbert_record,
                 code=self._time_token_function(time_delta),
                 visit_concept_order=i + 1,
+                time=visit_start_datetime,
             )
 
             # Add the VS token to the patient timeline to mark the start of a visit
@@ -393,6 +402,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                 date=date,
                 visit_segment=visit_segment,
                 visit_concept_id=visit_type,
+                time=date_cursor,
             )
 
             if self._include_auxiliary_token:
@@ -404,6 +414,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                     date=date,
                     visit_segment=visit_segment,
                     visit_concept_id=visit_type,
+                    time=date_cursor,
                )
             # Keep track of the existing outpatient events, we don't want to add them again
             existing_outpatient_events = list()
@@ -450,6 +461,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                         visit_concept_order=i + 1,
                         visit_segment=visit_segment,
                         visit_concept_id=visit_type,
+                        time=date_cursor,
                     )
                 else:
                     # For outpatient visits, we use the visit time stamp to calculate age and time because we assume
@@ -471,6 +483,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                         concept_value=concept_value,
                         unit=unit,
                         mlm_skip_value=concept_value_mask,
+                        time=date_cursor,
                     )
                     existing_outpatient_events.append((date, code, concept_value))
 
@@ -496,6 +509,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                 visit_concept_order=i + 1,
                 visit_segment=visit_segment,
                 visit_concept_id=visit_type,
+                time=date_cursor,
             )
 
             # Reuse the age and date calculated for the last event in the patient timeline
@@ -507,6 +521,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
                 visit_concept_order=i + 1,
                 visit_segment=visit_segment,
                 visit_concept_id=visit_type,
+                time=date_cursor,
             )
 
         # Toggle visit_segment_indicator
@@ -519,11 +534,17 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         cehrbert_record["num_of_concepts"] = len(cehrbert_record["concept_ids"])
         cehrbert_record["num_of_visits"] = len(visits)
 
+        if record.get("index_date", None) is not None:
+            cehrbert_record["index_date"] = record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
         if "label" in record:
             cehrbert_record["label"] = record["label"]
         if "age_at_index" in record:
             cehrbert_record["age_at_index"] = record["age_at_index"]
 
+        assert len(cehrbert_record["epoch_times"]) == len(
+            cehrbert_record["concept_ids"]
+        ), "The number of time stamps must match with the number of concepts in the sequence"
+
         return cehrbert_record
@@ -594,6 +615,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         input_ids = self._concept_tokenizer.encode(record["concept_ids"])
         record["input_ids"] = input_ids
         concept_value_masks = record["concept_value_masks"]
+        record["epoch_times"] = construct_time_sequence(record["concept_ids"], record.get("epoch_times", None))
 
         # These fields may not exist in the old version of the datasets
         if "units" in record:
@@ -651,6 +673,86 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
         return record
 
 
+class ExtractTokenizedSequenceDataMapping:
+    def __init__(
+        self,
+        person_index_date_map: Dict[int, List[Dict[str, Any]]],
+        observation_window: int = 0,
+    ):
+        self.person_index_date_map = person_index_date_map
+        self.observation_window = observation_window
+
+    def _calculate_prediction_start_time(self, prediction_time: float):
+        if self.observation_window and self.observation_window > 0:
+            return max(prediction_time - self.observation_window * 24 * 3600, 0)
+        return 0
+
+    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        person_id = record["person_id"]
+        prediction_times = self.person_index_date_map[person_id]
+        prediction_start_end_times = [
+            (
+                self._calculate_prediction_start_time(
+                    prediction_time_label_map["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
+                ),
+                prediction_time_label_map["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp(),
+                prediction_time_label_map["label"],
+            )
+            for prediction_time_label_map in prediction_times
+        ]
+        observation_window_indices = np.zeros((len(prediction_times), len(record["epoch_times"])), dtype=bool)
+        for i, epoch_time in enumerate(record["epoch_times"]):
+            for sample_n, (
+                feature_extraction_time_start,
+                feature_extraction_end_end,
+                _,
+            ) in enumerate(prediction_start_end_times):
+                if feature_extraction_time_start <= epoch_time <= feature_extraction_end_end:
+                    observation_window_indices[sample_n][i] = True
+
+        seq_length = len(record["epoch_times"])
+        time_series_columns = ["concept_ids", "input_ids"]
+        static_inputs = dict()
+        for k, v in record.items():
+            if k in ["concept_ids", "input_ids"]:
+                continue
+            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
+                time_series_columns.append(k)
+            else:
+                static_inputs[k] = v
+
+        batched_samples = defaultdict(list)
+        for (_, index_date, label), observation_window_index in zip(
+            prediction_start_end_times, observation_window_indices
+        ):
+            for k, v in static_inputs.items():
+                batched_samples[k].append(v)
+            batched_samples["classifier_label"].append(label)
+            batched_samples["index_date"].append(index_date)
+            try:
+                start_age = int(record["concept_ids"][1].split(":")[1])
+            except Exception:
+                start_age = -1
+            batched_samples["age_at_index"].append(start_age)
+            for time_series_column in time_series_columns:
+                batched_samples[time_series_column].append(
+                    np.asarray(record[time_series_column])[observation_window_index]
+                )
+        return batched_samples
+
+    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
+        all_batched_record = defaultdict(list)
+        all_columns = record.keys()
+        for i in range(len(record["concept_ids"])):
+            one_record = {}
+            for column in all_columns:
+                one_record[column] = record[column][i]
+            new_batched_record = self.transform(one_record)
+            for k, v in new_batched_record.items():
+                all_batched_record[k].extend(v)
+        return all_batched_record
+
+
 class HFFineTuningMapping(DatasetMapping):
     """Consider removing this transformation in the future."""

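The new ExtractTokenizedSequenceDataMapping does the slicing the commit message describes: for each (person, prediction time) pair it builds a boolean mask over epoch_times, keeps only the positions inside [prediction_time - observation_window, prediction_time], slices every sequence-aligned column with that mask, and copies static columns through. A hypothetical walkthrough on a single record (all values are illustrative, not from the commit):

# Hypothetical walkthrough of transform() on one record.
import datetime

from cehrbert.cehrbert_utils import construct_time_sequence
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
    ExtractTokenizedSequenceDataMapping,
)

record = {
    "person_id": 1,
    "concept_ids": ["year:2020", "age:60", "8507", "0", "VS", "320128", "VE", "D30", "VS", "201826", "VE"],
    "input_ids": [10, 11, 12, 13, 2, 105, 3, 40, 2, 106, 3],
}
# Reconstruct timestamps with the new helper: the first visit is stamped
# 2020-01-01 UTC, the second visit 30 days later.
record["epoch_times"] = construct_time_sequence(record["concept_ids"])

mapping = ExtractTokenizedSequenceDataMapping(
    person_index_date_map={1: [{"index_date": datetime.datetime(2020, 1, 15), "label": 1}]},
    observation_window=14,  # days of history before each prediction time
)
samples = mapping.transform(record)
# The 14-day window ending 2020-01-15 contains the first visit but not the
# second, so the sliced sequence keeps only the first seven tokens.
assert list(samples["concept_ids"][0]) == ["year:2020", "age:60", "8507", "0", "VS", "320128", "VE"]
assert samples["classifier_label"] == [1] and samples["age_at_index"] == [60]
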
src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 19 additions & 15 deletions
@@ -22,6 +22,7 @@
 from cehrbert.data_generators.hf_data_generator.sample_packing_sampler import SamplePackingBatchSampler
 from cehrbert.models.hf_models.hf_cehrbert import CehrBertForPreTraining
 from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
+from cehrbert.runners.data_utils import extract_cohort_sequences
 from cehrbert.runners.hf_cehrbert_finetune_runner import prepare_finetune_dataset
 from cehrbert.runners.runner_util import generate_prepared_ds_path, parse_runner_args
 
@@ -85,21 +86,24 @@ def main():
         LOG.info("Prepared dataset loaded from disk...")
 
     if processed_dataset is None:
-        # Organize them into a single DatasetDict
-        final_splits = prepare_finetune_dataset(data_args, training_args, cache_file_collector)
-
-        # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
-        if not data_args.streaming:
-            all_columns = final_splits["train"].column_names
-            if "visit_concept_ids" in all_columns:
-                final_splits = final_splits.remove_columns(["visit_concept_ids"])
-
-        processed_dataset = create_cehrbert_finetuning_dataset(
-            dataset=final_splits,
-            concept_tokenizer=cehrgpt_tokenizer,
-            data_args=data_args,
-            cache_file_collector=cache_file_collector,
-        )
+        if cehrbert_args.tokenized_full_dataset_path is not None:
+            processed_dataset = extract_cohort_sequences(data_args, cehrbert_args, cache_file_collector)
+        else:
+            # Organize them into a single DatasetDict
+            final_splits = prepare_finetune_dataset(data_args, training_args, cache_file_collector)
+
+            # TODO: temp solution, this column is mixed typed and causes an issue when transforming the data
+            if not data_args.streaming:
+                all_columns = final_splits["train"].column_names
+                if "visit_concept_ids" in all_columns:
+                    final_splits = final_splits.remove_columns(["visit_concept_ids"])
+
+            processed_dataset = create_cehrbert_finetuning_dataset(
+                dataset=final_splits,
+                concept_tokenizer=cehrgpt_tokenizer,
+                data_args=data_args,
+                cache_file_collector=cache_file_collector,
+            )
     if not data_args.streaming:
         processed_dataset.save_to_disk(prepared_ds_path)
         processed_dataset.cleanup_cache_files()

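extract_cohort_sequences itself lives in cehrbert.runners.data_utils and is not shown in this diff; presumably it applies the new mapping over the pre-tokenized dataset. A hypothetical sketch of such wiring with Hugging Face datasets (the path, the cohort map, and the helper's internals are all assumptions):

# Hypothetical sketch (not part of the commit) of applying the new mapping
# to a pre-tokenized dataset loaded from disk.
import datetime

from datasets import load_from_disk

from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
    ExtractTokenizedSequenceDataMapping,
)

tokenized_dataset = load_from_disk("/path/to/tokenized_full_dataset")
mapping = ExtractTokenizedSequenceDataMapping(
    # Assumes every person_id in the dataset has an entry here; the mapping
    # indexes this dict directly and would raise KeyError otherwise.
    person_index_date_map={1: [{"index_date": datetime.datetime(2021, 6, 1), "label": 1}]},
    observation_window=365,  # days of history to keep before each prediction time
)
# batched=True lets batch_transform return a different number of rows than it
# received: one output sample per (person, prediction time) pair.
cohort_dataset = tokenized_dataset.map(
    mapping.batch_transform,
    batched=True,
    remove_columns=tokenized_dataset.column_names,
)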