Commit 63d2104

updated the data collator sample packing

1 parent 0b364e1 commit 63d2104


src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py

Lines changed: 26 additions & 71 deletions
@@ -118,7 +118,7 @@ def __call__(self, examples):
             batch["visit_segments"] = torch.cat([torch.full((batch_size, 1), 0), batch["visit_segments"]], dim=1)
         else:
             assert (
-                batch["attention_mask"].shape[0] == 1
+                batch["attention_mask"].shape[0] == 1
             ), f"batch['attention_mask'].shape[0] must be 1 in sample packing"

         # This is the most crucial logic for generating the training labels
@@ -273,8 +273,6 @@ def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
         super(SamplePackingCehrBertDataCollator, self).__init__(*args, **kwargs)

     def __call__(self, examples):
-        flattened_examples = []
-
         # Main inputs
         current_input_ids = []
         current_attention_mask = []
@@ -299,50 +297,6 @@ def __call__(self, examples):
             example = self.generate_start_end_index(example, self.max_position_embeddings)

             input_ids = example["input_ids"]
-            # We add the flattened example to the list either when the example exceeds the total max tokens
-            # we add the length by two because we need to add two more tokens [CLS] .... [PAD]
-            if len(current_input_ids) + len(input_ids) + 2 > self.max_tokens_per_batch and current_input_ids:
-                packed_example = {
-                    "input_ids": current_input_ids,
-                    "attention_mask": current_attention_mask,
-                    "ages": current_ages,
-                    "dates": current_dates,
-                    "visit_concept_orders": current_visit_concept_orders,
-                    "concept_values": current_concept_values,
-                    "concept_value_masks": current_concept_value_masks,
-                    "visit_segments": current_visit_segments,
-                }
-
-                if current_labels:
-                    packed_example.update(
-                        {
-                            "person_id": current_person_ids,
-                            "index_date": current_index_dates,
-                            "age_at_index": current_age_at_indexes,
-                            "classifier_label": current_labels,
-                        }
-                    )
-
-                flattened_examples.append(packed_example)
-
-                # Main inputs
-                current_input_ids = []
-                current_attention_mask = []
-                current_concept_values = []
-                current_concept_value_masks = []
-                current_ages = []
-                current_dates = []
-                current_visit_concept_orders = []
-                current_visit_segments = []
-
-                # Demographics
-                current_person_ids = []
-                current_index_dates = []
-
-                # Binary classification inputs
-                current_age_at_indexes = []
-                current_labels = []
-
             current_input_ids.extend([self.tokenizer.cls_token_index] + input_ids + [self.tokenizer.pad_token_index])
             current_attention_mask.extend([1] + np.ones_like(input_ids).tolist() + [0])
             current_concept_values.extend([-1] + example["concept_values"] + [-1])
@@ -368,29 +322,30 @@ def __call__(self, examples):
             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])

-        # The final batch needs to be added
-        if current_input_ids:
-            packed_example = {
-                "input_ids": current_input_ids,
-                "attention_mask": current_attention_mask,
-                "ages": current_ages,
-                "dates": current_dates,
-                "visit_concept_orders": current_visit_concept_orders,
-                "concept_values": current_concept_values,
-                "concept_value_masks": current_concept_value_masks,
-                "visit_segments": current_visit_segments,
-            }
-
-            if current_labels:
-                packed_example.update(
-                    {
-                        "person_id": current_person_ids,
-                        "index_date": current_index_dates,
-                        "age_at_index": current_age_at_indexes,
-                        "classifier_label": current_labels,
-                    }
-                )
+        assert len(current_input_ids) <= self.max_tokens_per_batch, (
+            "len(current_input_ids) must be less than or equal to self.max_tokens_per_batch, "
+            f"but received {len(current_input_ids)} instead"
+        )

-            flattened_examples.append(packed_example)
+        packed_example = {
+            "input_ids": current_input_ids,
+            "attention_mask": current_attention_mask,
+            "ages": current_ages,
+            "dates": current_dates,
+            "visit_concept_orders": current_visit_concept_orders,
+            "concept_values": current_concept_values,
+            "concept_value_masks": current_concept_value_masks,
+            "visit_segments": current_visit_segments,
+        }
+
+        if current_labels:
+            packed_example.update(
+                {
+                    "person_id": current_person_ids,
+                    "index_date": current_index_dates,
+                    "age_at_index": current_age_at_indexes,
+                    "classifier_label": current_labels,
+                }
+            )

-        return super().__call__(flattened_examples)
+        return super().__call__([packed_example])

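For context, below is a minimal sketch of the packing behavior after this change, not the actual SamplePackingCehrBertDataCollator: each example contributes [CLS] + input_ids + [PAD] to a single packed sequence, and the collator now asserts that the packed length fits within max_tokens_per_batch instead of splitting overflow into multiple packed examples. The token indices and the token budget below are hypothetical stand-ins for the tokenizer and collator attributes used in the diff.

# Sketch of the post-commit packing logic (simplified to input_ids and
# attention_mask; the real collator also packs ages, dates, segments, etc.).
import numpy as np

CLS_TOKEN_INDEX = 101       # placeholder for self.tokenizer.cls_token_index
PAD_TOKEN_INDEX = 0         # placeholder for self.tokenizer.pad_token_index
MAX_TOKENS_PER_BATCH = 16   # placeholder for self.max_tokens_per_batch

def pack(examples):
    current_input_ids, current_attention_mask = [], []
    for example in examples:
        input_ids = example["input_ids"]
        # Two extra tokens per example: [CLS] before, [PAD] after.
        current_input_ids.extend([CLS_TOKEN_INDEX] + input_ids + [PAD_TOKEN_INDEX])
        # [CLS] attends (1); the trailing [PAD] separator does not (0).
        current_attention_mask.extend([1] + np.ones_like(input_ids).tolist() + [0])
    # Mirrors the assert this commit introduces: no overflow splitting,
    # so the caller must keep the whole set under the token budget.
    assert len(current_input_ids) <= MAX_TOKENS_PER_BATCH, (
        "len(current_input_ids) must be less than or equal to MAX_TOKENS_PER_BATCH, "
        f"but received {len(current_input_ids)} instead"
    )
    return {"input_ids": current_input_ids, "attention_mask": current_attention_mask}

packed = pack([{"input_ids": [7, 8, 9]}, {"input_ids": [4, 5]}])
print(packed["input_ids"])       # [101, 7, 8, 9, 0, 101, 4, 5, 0]
print(packed["attention_mask"])  # [1, 1, 1, 1, 0, 1, 1, 1, 0]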