Skip to content

Commit d7ab89c

Browse files
committed
added observation_window to the meds data conversion
1 parent 34b71e1 commit d7ab89c

File tree

3 files changed

+50
-16
lines changed

3 files changed

+50
-16
lines changed

src/cehrbert/data_generators/hf_data_generator/meds_utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def convert_one_patient(
5454
patient: meds_reader.Subject,
5555
conversion: MedsToCehrBertConversion,
5656
prediction_time: datetime = None,
57+
observation_window: int = None,
5758
label: Union[int, float] = None,
5859
) -> CehrBertPatient:
5960
"""
@@ -80,6 +81,9 @@ def convert_one_patient(
8081
The cutoff time for processing events. Events occurring after this time are
8182
ignored.
8283
84+
observation_window: int, optional (default=None)
85+
The observation window for extracting features
86+
8387
label : Union[int, float], optional (default=None)
8488
The prediction label associated with this patient, which could represent a
8589
clinical outcome (e.g., survival or treatment response).
@@ -122,9 +126,7 @@ def convert_one_patient(
122126
)
123127
"""
124128
demographics, patient_blocks = generate_demographics_and_patient_blocks(
125-
conversion=conversion,
126-
patient=patient,
127-
prediction_time=prediction_time,
129+
conversion=conversion, patient=patient, prediction_time=prediction_time, observation_window=observation_window
128130
)
129131

130132
patient_block_dict = collections.defaultdict(list)
@@ -230,9 +232,9 @@ def _meds_to_cehrbert_generator(
230232
)
231233
with meds_reader.SubjectDatabase(path_to_db) as patient_database:
232234
for shard in shards:
233-
for patient_id, prediction_time, label in shard:
235+
for patient_id, prediction_time, observation_window, label in shard:
234236
patient = patient_database[patient_id]
235-
converted_patient = convert_one_patient(patient, conversion, prediction_time, label)
237+
converted_patient = convert_one_patient(patient, conversion, prediction_time, observation_window, label)
236238
# there are patients whose birthdate is none
237239
visits = converted_patient["visits"]
238240
if converted_patient["birth_datetime"] is None:
@@ -262,11 +264,11 @@ def _create_cehrbert_data_from_meds(
262264
subject_id = cohort_row.subject_id
263265
prediction_time = cohort_row.prediction_time
264266
label = int(cohort_row.boolean_value)
265-
batches.append((subject_id, prediction_time, label))
267+
batches.append((subject_id, prediction_time, data_args.observation_window, label))
266268
else:
267269
patient_split = get_subject_split(os.path.expanduser(data_args.data_folder))
268270
for subject_id in patient_split[split]:
269-
batches.append((subject_id, None, None))
271+
batches.append((subject_id, None, None, None))
270272

271273
features = Features(
272274
{

src/cehrbert/data_generators/hf_data_generator/patient_block.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from collections import defaultdict
33
from dataclasses import dataclass
4-
from datetime import datetime
4+
from datetime import datetime, timedelta
55
from typing import Iterable, List, Optional, Tuple
66

77
import meds_reader
@@ -321,19 +321,32 @@ def generate_demographics_and_patient_blocks(
321321
conversion: MedsToCehrBertConversion,
322322
patient: meds_reader.Subject,
323323
prediction_time: datetime = None,
324+
observation_window: int = None,
324325
) -> Tuple[PatientDemographics, List[PatientBlock]]:
325326
if isinstance(conversion, MedsToBertMimic4):
326327
return mimic_meds_generate_demographics_and_patient_blocks(
327-
patient, conversion, prediction_time, conversion.default_visit_id
328+
patient=patient,
329+
conversion=conversion,
330+
prediction_time=prediction_time,
331+
observation_window=observation_window,
332+
default_visit_id=conversion.default_visit_id,
328333
)
329334
elif isinstance(conversion, MedsToCehrbertOMOP):
330-
return omop_meds_generate_demographics_and_patient_blocks(patient, conversion, prediction_time)
335+
return omop_meds_generate_demographics_and_patient_blocks(
336+
patient=patient,
337+
conversion=conversion,
338+
prediction_time=prediction_time,
339+
observation_window=observation_window,
340+
)
331341
else:
332342
raise RuntimeError(f"Unrecognized conversion: {conversion}")
333343

334344

335345
def omop_meds_generate_demographics_and_patient_blocks(
336-
patient: meds_reader.Subject, conversion: MedsToCehrBertConversion, prediction_time: datetime = None
346+
patient: meds_reader.Subject,
347+
conversion: MedsToCehrBertConversion,
348+
prediction_time: datetime = None,
349+
observation_window: int = None,
337350
) -> Tuple[PatientDemographics, List[PatientBlock]]:
338351
disconnect_problem_list_events = getattr(conversion, "disconnect_problem_list_events", False)
339352
birth_datetime = None
@@ -342,6 +355,9 @@ def omop_meds_generate_demographics_and_patient_blocks(
342355
ethnicity = None
343356
visit_events = defaultdict(list)
344357
unlinked_event_mapping = defaultdict(list)
358+
observation_start_window: Optional[datetime] = None
359+
if prediction_time and observation_window:
360+
observation_start_window = prediction_time - timedelta(days=observation_window)
345361
for e in patient.events:
346362
# This indicates demographics features
347363
event_code_uppercase = e.code.upper()
@@ -355,10 +371,11 @@ def omop_meds_generate_demographics_and_patient_blocks(
355371
ethnicity = e.code
356372
elif e.time is not None:
357373
# Skip out of the loop if the events' time stamps are beyond the prediction time
358-
if prediction_time is not None:
359-
if e.time > prediction_time:
360-
break
361-
374+
if prediction_time is not None and e.time > prediction_time:
375+
break
376+
# Skip out of the loop if the events' time stamps are before the observation start window
377+
if observation_start_window is not None and e.time < observation_start_window:
378+
break
362379
if getattr(e, "visit_id", None):
363380
visit_id = e.visit_id
364381
visit_events[visit_id].append(e)
@@ -528,6 +545,7 @@ def mimic_meds_generate_demographics_and_patient_blocks(
528545
patient: meds_reader.Subject,
529546
conversion: MedsToCehrBertConversion,
530547
prediction_time: datetime = None,
548+
observation_window: int = None,
531549
default_visit_id: int = 1,
532550
) -> Tuple[PatientDemographics, List[PatientBlock]]:
533551
birth_datetime = None
@@ -539,13 +557,21 @@ def mimic_meds_generate_demographics_and_patient_blocks(
539557
current_date = None
540558
events_for_current_date = []
541559
patient_blocks = []
542-
for e in patient.events:
560+
observation_start_window: Optional[datetime] = None
561+
if prediction_time and observation_window:
562+
observation_start_window = prediction_time - timedelta(days=observation_window)
543563

564+
for e in patient.events:
544565
# Skip out of the loop if the events' time stamps are beyond the prediction time
545566
if prediction_time is not None and e.time is not None:
546567
if e.time > prediction_time:
547568
break
548569

570+
# Skip out of the loop if the events' time stamps are before observation start window
571+
if observation_start_window is not None and e.time is not None:
572+
if e.time < observation_start_window:
573+
break
574+
549575
# This indicates demographics features
550576
event_code_uppercase = e.code.upper()
551577
if event_code_uppercase.startswith(birth_code):

src/cehrbert/runners/hf_runner_argument_dataclass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ class DataTrainingArguments:
150150
"This depends on how the EHR records are generated, default to False"
151151
},
152152
)
153+
observation_window: Optional[int] = dataclasses.field(
154+
default=None,
155+
metadata={
156+
"help": "The observation window to use in the preprocessing.",
157+
},
158+
)
153159
streaming: Optional[bool] = dataclasses.field(
154160
default=False,
155161
metadata={"help": "The boolean indicator to indicate whether the data should be streamed"},

0 commit comments

Comments
 (0)