44import itertools
55import re
66from abc import ABC , abstractmethod
7+ from collections import defaultdict
78from dataclasses import dataclass
89from enum import Enum
910from typing import Any , Dict , Generator , List , Optional , Union
1718from meds .schema import birth_code , death_code
1819from pandas import Series
1920
21+ from cehrbert .cehrbert_utils import construct_time_sequence
2022from cehrbert .med_extension .schema_extension import Event
2123from cehrbert .models .hf_models .tokenization_hf_cehrbert import CehrBertTokenizer
2224from cehrbert .runners .hf_runner_argument_dataclass import DataTrainingArguments
@@ -284,6 +286,7 @@ def remove_columns(self):
284286 def _update_cehrbert_record (
285287 cehrbert_record : Dict [str , Any ],
286288 code : str ,
289+ time : datetime .datetime ,
287290 visit_segment : int = 0 ,
288291 date : int = 0 ,
289292 age : int = - 1 ,
@@ -304,6 +307,7 @@ def _update_cehrbert_record(
304307 cehrbert_record ["concept_values" ].append (concept_value )
305308 cehrbert_record ["units" ].append (unit )
306309 cehrbert_record ["mlm_skip_values" ].append (mlm_skip_value )
310+ cehrbert_record ["epoch_times" ].append (time .replace (tzinfo = datetime .timezone .utc ).timestamp ())
307311
308312 def transform (self , record : Dict [str , Any ]) -> Dict [str , Any ]:
309313
@@ -320,6 +324,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
320324 "units" : [],
321325 "mlm_skip_values" : [],
322326 "visit_concept_ids" : [],
327+ "epoch_times" : [],
323328 }
324329 # Extract the demographic information
325330 birth_datetime = record ["birth_datetime" ]
@@ -340,7 +345,10 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
340345 year_str = f"year:{ str (first_visit_start_datetime .year )} "
341346 age_str = f"age:{ str (relativedelta (first_visit_start_datetime , birth_datetime ).years )} "
342347
343- self ._update_cehrbert_record (cehrbert_record , year_str )
348+ self ._update_cehrbert_record (
349+ cehrbert_record ,
350+ year_str ,
351+ )
344352 self ._update_cehrbert_record (cehrbert_record , age_str )
345353 self ._update_cehrbert_record (cehrbert_record , gender )
346354 self ._update_cehrbert_record (cehrbert_record , race )
@@ -377,6 +385,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
377385 cehrbert_record ,
378386 code = self ._time_token_function (time_delta ),
379387 visit_concept_order = i + 1 ,
388+ time = visit_start_datetime ,
380389 )
381390
382391 # Add the VS token to the patient timeline to mark the start of a visit
@@ -393,6 +402,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
393402 date = date ,
394403 visit_segment = visit_segment ,
395404 visit_concept_id = visit_type ,
405+ time = date_cursor ,
396406 )
397407
398408 if self ._include_auxiliary_token :
@@ -404,6 +414,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
404414 date = date ,
405415 visit_segment = visit_segment ,
406416 visit_concept_id = visit_type ,
417+ time = date_cursor ,
407418 )
408419 # Keep track of the existing outpatient events, we don't want to add them again
409420 existing_outpatient_events = list ()
@@ -450,6 +461,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
450461 visit_concept_order = i + 1 ,
451462 visit_segment = visit_segment ,
452463 visit_concept_id = visit_type ,
464+ time = date_cursor ,
453465 )
454466 else :
455467 # For outpatient visits, we use the visit time stamp to calculate age and time because we assume
@@ -471,6 +483,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
471483 concept_value = concept_value ,
472484 unit = unit ,
473485 mlm_skip_value = concept_value_mask ,
486+ time = date_cursor ,
474487 )
475488 existing_outpatient_events .append ((date , code , concept_value ))
476489
@@ -496,6 +509,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
496509 visit_concept_order = i + 1 ,
497510 visit_segment = visit_segment ,
498511 visit_concept_id = visit_type ,
512+ time = date_cursor ,
499513 )
500514
501515 # Reuse the age and date calculated for the last event in the patient timeline
@@ -507,6 +521,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
507521 visit_concept_order = i + 1 ,
508522 visit_segment = visit_segment ,
509523 visit_concept_id = visit_type ,
524+ time = date_cursor ,
510525 )
511526
512527 # Toggle visit_segment_indicator
@@ -519,11 +534,17 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
519534 cehrbert_record ["num_of_concepts" ] = len (cehrbert_record ["concept_ids" ])
520535 cehrbert_record ["num_of_visits" ] = len (visits )
521536
537+ if record .get ("index_date" , None ) is not None :
538+ cehrbert_record ["index_date" ] = record ["index_date" ].replace (tzinfo = datetime .timezone .utc ).timestamp ()
522539 if "label" in record :
523540 cehrbert_record ["label" ] = record ["label" ]
524541 if "age_at_index" in record :
525542 cehrbert_record ["age_at_index" ] = record ["age_at_index" ]
526543
544+ assert len (cehrbert_record ["epoch_times" ]) == len (
545+ cehrbert_record ["concept_ids" ]
546+ ), "The number of time stamps must match with the number of concepts in the sequence"
547+
527548 return cehrbert_record
528549
529550
@@ -594,6 +615,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
594615 input_ids = self ._concept_tokenizer .encode (record ["concept_ids" ])
595616 record ["input_ids" ] = input_ids
596617 concept_value_masks = record ["concept_value_masks" ]
618+ record ["epoch_times" ] = construct_time_sequence (record ["concept_ids" ], record .get ("epoch_times" , None ))
597619
598620 # These fields may not exist in the old version of the datasets
599621 if "units" in record :
@@ -651,6 +673,86 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
651673 return record
652674
653675
class ExtractTokenizedSequenceDataMapping:
    """Expand one tokenized patient sequence into per-prediction-time samples.

    For each ``(index_date, label)`` pair registered for a person, the events
    whose ``epoch_times`` fall inside the observation window ending at the
    index date are extracted. Sequence-aligned columns are sliced per sample,
    while every other column is carried over unchanged as a static feature.
    """

    def __init__(
        self,
        person_index_date_map: Dict[int, List[Dict[str, Any]]],
        observation_window: int = 0,
    ):
        # person_id -> list of {"index_date": datetime, "label": ...} entries.
        self.person_index_date_map = person_index_date_map
        # Observation window length in days; <= 0 means "use the full history".
        self.observation_window = observation_window

    def _calculate_prediction_start_time(self, prediction_time: float) -> float:
        """Return the epoch second at which the observation window opens."""
        if self.observation_window and self.observation_window > 0:
            # Window is expressed in days; convert to seconds, clamp at epoch 0.
            return max(prediction_time - self.observation_window * 24 * 3600, 0)
        return 0

    def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Turn a single person's record into a batch of prediction samples.

        Returns a dict of lists with one entry per prediction time: the
        window-sliced time-series columns plus ``classifier_label``,
        ``index_date`` and ``age_at_index``.
        """
        person_id = record["person_id"]
        prediction_times = self.person_index_date_map[person_id]

        # Pre-compute (window_start, window_end, label) per prediction time,
        # converting index_date to an epoch second exactly once.
        # NOTE(review): index_date is assumed to be a naive datetime
        # interpreted as UTC — confirm against the upstream extraction.
        prediction_start_end_times = []
        for prediction_time_label_map in prediction_times:
            window_end = (
                prediction_time_label_map["index_date"]
                .replace(tzinfo=datetime.timezone.utc)
                .timestamp()
            )
            prediction_start_end_times.append(
                (
                    self._calculate_prediction_start_time(window_end),
                    window_end,
                    prediction_time_label_map["label"],
                )
            )

        # Vectorized window membership, shape (n_samples, seq_length):
        # replaces the O(n_samples * seq_length) Python double loop with a
        # single broadcast comparison that yields the same boolean mask.
        epoch_times = np.asarray(record["epoch_times"], dtype=float)
        window_starts = np.asarray(
            [start for start, _, _ in prediction_start_end_times], dtype=float
        )
        window_ends = np.asarray(
            [end for _, end, _ in prediction_start_end_times], dtype=float
        )
        observation_window_indices = (window_starts[:, None] <= epoch_times[None, :]) & (
            epoch_times[None, :] <= window_ends[:, None]
        )

        # Split the record into sequence-aligned columns (sliced per sample)
        # and static columns (copied per sample as-is). Note: a static list
        # that coincidentally has the sequence length is treated as a
        # time-series column — existing behavior, kept.
        seq_length = len(record["epoch_times"])
        time_series_columns = ["concept_ids", "input_ids"]
        static_inputs = dict()
        for k, v in record.items():
            if k in ("concept_ids", "input_ids"):
                continue
            if isinstance(v, (list, np.ndarray)) and len(v) == seq_length:
                time_series_columns.append(k)
            else:
                static_inputs[k] = v

        # The second token is expected to be "age:<years>" (demographic
        # prompt layout); fall back to -1 on any deviation. Loop-invariant,
        # so computed once instead of once per sample.
        try:
            start_age = int(record["concept_ids"][1].split(":")[1])
        except Exception:  # deliberate best-effort parse
            start_age = -1

        batched_samples = defaultdict(list)
        for (_, index_date, label), observation_window_index in zip(
            prediction_start_end_times, observation_window_indices
        ):
            for k, v in static_inputs.items():
                batched_samples[k].append(v)
            batched_samples["classifier_label"].append(label)
            batched_samples["index_date"].append(index_date)
            batched_samples["age_at_index"].append(start_age)
            for time_series_column in time_series_columns:
                batched_samples[time_series_column].append(
                    np.asarray(record[time_series_column])[observation_window_index]
                )
        return batched_samples

    def batch_transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
        """Apply ``transform`` to each row of a columnar batch and concatenate."""
        all_batched_record = defaultdict(list)
        all_columns = record.keys()
        for i in range(len(record["concept_ids"])):
            one_record = {column: record[column][i] for column in all_columns}
            new_batched_record = self.transform(one_record)
            for k, v in new_batched_record.items():
                all_batched_record[k].extend(v)
        return all_batched_record
755+
654756class HFFineTuningMapping (DatasetMapping ):
655757 """Consider removing this transformation in the future."""
656758
0 commit comments