@@ -118,7 +118,7 @@ def __call__(self, examples):
             batch["visit_segments"] = torch.cat([torch.full((batch_size, 1), 0), batch["visit_segments"]], dim=1)
         else:
             assert (
-                batch["attention_mask"].shape[0] == 1
+                batch["attention_mask"].shape[0] == 1
             ), "batch['attention_mask'].shape[0] must be 1 for sample packing"

         # This is the most crucial logic for generating the training labels
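A minimal, self-contained sketch of the zero-column prepend shown in the context line above, assuming `visit_segments` is an integer tensor of shape `(batch_size, seq_len)`; the shapes and values here are purely illustrative:

```python
import torch

batch_size, seq_len = 2, 4
visit_segments = torch.randint(1, 3, (batch_size, seq_len))  # toy segment ids

# Prepend a zero column so position 0 (e.g. a prepended [CLS] slot) gets segment 0.
visit_segments = torch.cat([torch.full((batch_size, 1), 0), visit_segments], dim=1)
print(visit_segments.shape)  # torch.Size([2, 5])
```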
@@ -273,8 +273,6 @@ def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
         super(SamplePackingCehrBertDataCollator, self).__init__(*args, **kwargs)

     def __call__(self, examples):
-        flattened_examples = []
-
         # Main inputs
         current_input_ids = []
         current_attention_mask = []
@@ -299,50 +297,6 @@ def __call__(self, examples):
             example = self.generate_start_end_index(example, self.max_position_embeddings)

             input_ids = example["input_ids"]
-            # We add the packed example to the list when adding the next example would exceed the total max tokens;
-            # the length is increased by two because two more tokens are added around each example: [CLS] ... [PAD]
-            if len(current_input_ids) + len(input_ids) + 2 > self.max_tokens_per_batch and current_input_ids:
-                packed_example = {
-                    "input_ids": current_input_ids,
-                    "attention_mask": current_attention_mask,
-                    "ages": current_ages,
-                    "dates": current_dates,
-                    "visit_concept_orders": current_visit_concept_orders,
-                    "concept_values": current_concept_values,
-                    "concept_value_masks": current_concept_value_masks,
-                    "visit_segments": current_visit_segments,
-                }
-
-                if current_labels:
-                    packed_example.update(
-                        {
-                            "person_id": current_person_ids,
-                            "index_date": current_index_dates,
-                            "age_at_index": current_age_at_indexes,
-                            "classifier_label": current_labels,
-                        }
-                    )
-
-                flattened_examples.append(packed_example)
-
-                # Main inputs
-                current_input_ids = []
-                current_attention_mask = []
-                current_concept_values = []
-                current_concept_value_masks = []
-                current_ages = []
-                current_dates = []
-                current_visit_concept_orders = []
-                current_visit_segments = []
-
-                # Demographics
-                current_person_ids = []
-                current_index_dates = []
-
-                # Binary classification inputs
-                current_age_at_indexes = []
-                current_labels = []
-
             current_input_ids.extend([self.tokenizer.cls_token_index] + input_ids + [self.tokenizer.pad_token_index])
             current_attention_mask.extend([1] + np.ones_like(input_ids).tolist() + [0])
             current_concept_values.extend([-1] + example["concept_values"] + [-1])
@@ -368,29 +322,30 @@ def __call__(self, examples):
             if "classifier_label" in example:
                 current_labels.append(example["classifier_label"])

-        # The final batch needs to be added
-        if current_input_ids:
-            packed_example = {
-                "input_ids": current_input_ids,
-                "attention_mask": current_attention_mask,
-                "ages": current_ages,
-                "dates": current_dates,
-                "visit_concept_orders": current_visit_concept_orders,
-                "concept_values": current_concept_values,
-                "concept_value_masks": current_concept_value_masks,
-                "visit_segments": current_visit_segments,
-            }
-
-            if current_labels:
-                packed_example.update(
-                    {
-                        "person_id": current_person_ids,
-                        "index_date": current_index_dates,
-                        "age_at_index": current_age_at_indexes,
-                        "classifier_label": current_labels,
-                    }
-                )
+        assert len(current_input_ids) <= self.max_tokens_per_batch, (
+            "len(current_input_ids) must be less than or equal to self.max_tokens_per_batch, "
+            f"but received {len(current_input_ids)} instead"
+        )

-            flattened_examples.append(packed_example)
+        packed_example = {
+            "input_ids": current_input_ids,
+            "attention_mask": current_attention_mask,
+            "ages": current_ages,
+            "dates": current_dates,
+            "visit_concept_orders": current_visit_concept_orders,
+            "concept_values": current_concept_values,
+            "concept_value_masks": current_concept_value_masks,
+            "visit_segments": current_visit_segments,
+        }
+
+        if current_labels:
+            packed_example.update(
+                {
+                    "person_id": current_person_ids,
+                    "index_date": current_index_dates,
+                    "age_at_index": current_age_at_indexes,
+                    "classifier_label": current_labels,
+                }
+            )

-        return super().__call__(flattened_examples)
+        return super().__call__([packed_example])
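The updated `__call__` concatenates every example in the mini-batch into one packed sequence delimited by [CLS] ... [PAD], checks the total length against `max_tokens_per_batch`, and hands a one-element list to the parent collator. Below is a minimal sketch of that packing pattern, using hypothetical token indices (`cls_token_index = 1`, `pad_token_index = 0`) and showing only `input_ids` and `attention_mask`; the real collator additionally packs ages, dates, visit concept orders, concept values, concept value masks, and visit segments:

```python
import numpy as np

# Hypothetical token indices; in the collator these come from self.tokenizer.
cls_token_index, pad_token_index = 1, 0
max_tokens_per_batch = 32

examples = [
    {"input_ids": [5, 6, 7]},
    {"input_ids": [8, 9]},
]

current_input_ids, current_attention_mask = [], []
for example in examples:
    input_ids = example["input_ids"]
    # Each example is wrapped as [CLS] ... [PAD], so it contributes len(input_ids) + 2 tokens.
    current_input_ids.extend([cls_token_index] + input_ids + [pad_token_index])
    current_attention_mask.extend([1] + np.ones_like(input_ids).tolist() + [0])

# The whole mini-batch becomes a single packed example.
assert len(current_input_ids) <= max_tokens_per_batch
packed_example = {"input_ids": current_input_ids, "attention_mask": current_attention_mask}

print(packed_example["input_ids"])       # [1, 5, 6, 7, 0, 1, 8, 9, 0]
print(packed_example["attention_mask"])  # [1, 1, 1, 1, 0, 1, 1, 1, 0]
```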