Skip to content

Commit e447ddf

Browse files
committed
Fixed the case where input_ids is empty (length 0)
1 parent cfc93c3 commit e447ddf

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py

Lines changed: 11 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -301,8 +301,17 @@ def __call__(self, examples):
301301
current_attention_mask.extend([1] + np.ones_like(input_ids).tolist() + [0])
302302
current_concept_values.extend([-1] + example["concept_values"] + [-1])
303303
current_concept_value_masks.extend([0] + example["concept_value_masks"] + [0])
304-
current_ages.extend([example["ages"][0]] + example["ages"] + [0])
305-
current_dates.extend([example["dates"][0]] + example["dates"] + [0])
304+
305+
if len(example["ages"]) > 0:
306+
current_ages.extend([example["ages"][0]] + example["ages"] + [0])
307+
else:
308+
current_ages.extend([0, 0])
309+
310+
if len(example["dates"]) > 0:
311+
current_dates.extend([example["dates"][0]] + example["dates"] + [0])
312+
else:
313+
current_dates.extend([0, 0])
314+
306315
current_visit_concept_orders.extend(
307316
[max(0, example["visit_concept_orders"][0] - 1)]
308317
+ example["visit_concept_orders"]

0 commit comments

Comments
 (0)