Skip to content

Commit f2b6989

Browse files
committed
do not mark timestamps as UTC time
1 parent fb5d8ab commit f2b6989

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

src/cehrbert/cehrbert_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def construct_time_sequence(
2222
else:
2323
year_str = "1985"
2424

25-
datetime_cursor = datetime(int(year_str), month=1, day=1, hour=0, minute=0, second=0).replace(tzinfo=timezone.utc)
25+
datetime_cursor = datetime(int(year_str), month=1, day=1, hour=0, minute=0, second=0)
2626
epoch_times = []
2727
for concept_id in concept_ids:
2828
if is_att_token(concept_id):

src/cehrbert/data_generators/hf_data_generator/hf_dataset_mapping.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def _update_cehrbert_record(
307307
cehrbert_record["concept_values"].append(concept_value)
308308
cehrbert_record["units"].append(unit)
309309
cehrbert_record["mlm_skip_values"].append(mlm_skip_value)
310-
cehrbert_record["epoch_times"].append(time.replace(tzinfo=datetime.timezone.utc).timestamp())
310+
cehrbert_record["epoch_times"].append(time.timestamp())
311311

312312
def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
313313

@@ -532,7 +532,7 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
532532
cehrbert_record["num_of_visits"] = len(visits)
533533

534534
if record.get("index_date", None) is not None:
535-
cehrbert_record["index_date"] = record["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
535+
cehrbert_record["index_date"] = record["index_date"].timestamp()
536536
if "label" in record:
537537
cehrbert_record["label"] = record["label"]
538538
if "age_at_index" in record:
@@ -690,9 +690,9 @@ def transform(self, record: Dict[str, Any]) -> Dict[str, Any]:
690690
prediction_start_end_times = [
691691
(
692692
self._calculate_prediction_start_time(
693-
prediction_time_label_map["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp()
693+
prediction_time_label_map["index_date"].timestamp()
694694
),
695-
prediction_time_label_map["index_date"].replace(tzinfo=datetime.timezone.utc).timestamp(),
695+
prediction_time_label_map["index_date"].timestamp(),
696696
prediction_time_label_map["label"],
697697
)
698698
for prediction_time_label_map in prediction_times

0 commit comments

Comments
 (0)