11import collections
22import random
3- from typing import Any , Dict , Tuple
3+ from typing import Any , Dict , Optional , Tuple
44
55import numpy as np
66import torch
@@ -52,7 +52,12 @@ def __call__(self, examples):
5252 # Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
5353 batch_input_ids = [self ._convert_to_tensor (example ["input_ids" ]) for example in examples ]
5454 batch_attention_mask = [
55- torch .ones_like (self ._convert_to_tensor (example ["input_ids" ]), dtype = torch .float ) for example in examples
55+ (
56+ self ._convert_to_tensor (example ["attention_mask" ]).to (torch .float )
57+ if "attention_mask" in example
58+ else torch .ones_like (self ._convert_to_tensor (example ["input_ids" ]), dtype = torch .float )
59+ )
60+ for example in examples
5661 ]
5762 batch_ages = [self ._convert_to_tensor (example ["ages" ]) for example in examples ]
5863 batch_dates = [self ._convert_to_tensor (example ["dates" ]) for example in examples ]
@@ -79,35 +84,42 @@ def __call__(self, examples):
7984 batch ["concept_value_masks" ] = pad_sequence (batch_concept_value_masks , batch_first = True , padding_value = 0.0 )
8085 batch ["visit_segments" ] = pad_sequence (batch_visit_segments , batch_first = True , padding_value = 0 )
8186
82- # Prepend the CLS token and their associated values to the corresponding time series features
83- batch ["input_ids" ] = torch .cat (
84- [
85- torch .full ((batch_size , 1 ), self .tokenizer .cls_token_index ),
86- batch ["input_ids" ],
87- ],
88- dim = 1 ,
89- )
90- # The attention_mask is set to 1 to enable attention for the CLS token
91- batch ["attention_mask" ] = torch .cat ([torch .full ((batch_size , 1 ), 1.0 ), batch ["attention_mask" ]], dim = 1 )
92- # Set the age of the CLS token to the starting age
93- batch ["ages" ] = torch .cat ([batch ["ages" ][:, 0 :1 ], batch ["ages" ]], dim = 1 )
94- # Set the age of the CLS token to the starting date
95- batch ["dates" ] = torch .cat ([batch ["dates" ][:, 0 :1 ], batch ["dates" ]], dim = 1 )
96- # Set the visit_concept_order of the CLS token to the first visit_concept_order in the sequence subtract by 1
97- visit_concept_orders_first = batch ["visit_concept_orders" ][:, 0 :1 ] - 1
98- visit_concept_orders_first = torch .maximum (
99- visit_concept_orders_first , torch .zeros_like (visit_concept_orders_first )
100- )
101- batch ["visit_concept_orders" ] = torch .cat ([visit_concept_orders_first , batch ["visit_concept_orders" ]], dim = 1 )
102- # Set the concept_value of the CLS token to a default value -1.0.
103- batch ["concept_values" ] = torch .cat ([torch .full ((batch_size , 1 ), - 1.0 ), batch ["concept_values" ]], dim = 1 )
104- # Set the concept_value of the CLS token to a default value 0.0 indicating that
105- # there is no value associated with this token
106- batch ["concept_value_masks" ] = torch .cat (
107- [torch .full ((batch_size , 1 ), 0.0 ), batch ["concept_value_masks" ]], dim = 1
108- )
109- # Set the visit_segments of the CLS token to a default value 0 because this doesn't belong to a visit
110- batch ["visit_segments" ] = torch .cat ([torch .full ((batch_size , 1 ), 0 ), batch ["visit_segments" ]], dim = 1 )
87+ if not getattr (self , "sample_packing" , False ):
88+ # Prepend the CLS token and their associated values to the corresponding time series features
89+ batch ["input_ids" ] = torch .cat (
90+ [
91+ torch .full ((batch_size , 1 ), self .tokenizer .cls_token_index ),
92+ batch ["input_ids" ],
93+ ],
94+ dim = 1 ,
95+ )
96+ # The attention_mask is set to 1 to enable attention for the CLS token
97+ batch ["attention_mask" ] = torch .cat ([torch .full ((batch_size , 1 ), 1.0 ), batch ["attention_mask" ]], dim = 1 )
98+ # Set the age of the CLS token to the starting age
99+ batch ["ages" ] = torch .cat ([batch ["ages" ][:, 0 :1 ], batch ["ages" ]], dim = 1 )
100+ # Set the age of the CLS token to the starting date
101+ batch ["dates" ] = torch .cat ([batch ["dates" ][:, 0 :1 ], batch ["dates" ]], dim = 1 )
102+ # Set the visit_concept_order of the CLS token to the first visit_concept_order in the sequence subtract by 1
103+ visit_concept_orders_first = batch ["visit_concept_orders" ][:, 0 :1 ] - 1
104+ visit_concept_orders_first = torch .maximum (
105+ visit_concept_orders_first , torch .zeros_like (visit_concept_orders_first )
106+ )
107+ batch ["visit_concept_orders" ] = torch .cat (
108+ [visit_concept_orders_first , batch ["visit_concept_orders" ]], dim = 1
109+ )
110+ # Set the concept_value of the CLS token to a default value -1.0.
111+ batch ["concept_values" ] = torch .cat ([torch .full ((batch_size , 1 ), - 1.0 ), batch ["concept_values" ]], dim = 1 )
112+ # Set the concept_value of the CLS token to a default value 0.0 indicating that
113+ # there is no value associated with this token
114+ batch ["concept_value_masks" ] = torch .cat (
115+ [torch .full ((batch_size , 1 ), 0.0 ), batch ["concept_value_masks" ]], dim = 1
116+ )
117+ # Set the visit_segments of the CLS token to a default value 0 because this doesn't belong to a visit
118+ batch ["visit_segments" ] = torch .cat ([torch .full ((batch_size , 1 ), 0 ), batch ["visit_segments" ]], dim = 1 )
119+ else :
120+ assert (
121+ batch ["attention_mask" ].shape [0 ] == 1
122+ ), f"batch['attention_mask'].shape[0] must be 0 in sample packing"
111123
112124 # This is the most crucial logic for generating the training labels
113125 if self .is_pretraining :
@@ -125,29 +137,46 @@ def __call__(self, examples):
125137
126138 batch ["input_ids" ], batch ["labels" ] = self .torch_mask_tokens (batch ["input_ids" ], batch ["labels" ])
127139
140+ bz = len (examples )
128141 if "person_id" in examples [0 ]:
129- batch ["person_id" ] = torch .cat (
130- [self ._convert_to_tensor (example ["person_id" ]).reshape (- 1 , 1 ) for example in examples ],
131- dim = 0 ,
132- ).to (torch .float )
142+ batch ["person_id" ] = (
143+ torch .cat (
144+ [self ._convert_to_tensor (example ["person_id" ]).reshape (- 1 , 1 ) for example in examples ],
145+ dim = 0 ,
146+ )
147+ .to (torch .int32 )
148+ .reshape (bz , - 1 )
149+ )
133150
134151 if "index_date" in examples [0 ]:
135- batch ["index_date" ] = torch .cat (
136- [self ._convert_to_tensor (example ["index_date" ]).reshape (- 1 , 1 ) for example in examples ],
137- dim = 0 ,
138- ).to (torch .float32 )
152+ batch ["index_date" ] = (
153+ torch .cat (
154+ [self ._convert_to_tensor (example ["index_date" ]).reshape (- 1 , 1 ) for example in examples ],
155+ dim = 0 ,
156+ )
157+ .to (torch .float64 )
158+ .reshape (bz , - 1 )
159+ )
139160
140161 if "age_at_index" in examples [0 ]:
141- batch ["age_at_index" ] = torch .cat (
142- [self ._convert_to_tensor (example ["age_at_index" ]).reshape (- 1 , 1 ) for example in examples ],
143- dim = 0 ,
144- ).to (torch .float )
162+ batch ["age_at_index" ] = (
163+ torch .cat (
164+ [self ._convert_to_tensor (example ["age_at_index" ]).reshape (- 1 , 1 ) for example in examples ],
165+ dim = 0 ,
166+ )
167+ .to (torch .float32 )
168+ .reshape (bz , - 1 )
169+ )
145170
146171 if "classifier_label" in examples [0 ]:
147- batch ["classifier_label" ] = torch .cat (
148- [self ._convert_to_tensor (example ["classifier_label" ]).reshape (- 1 , 1 ) for example in examples ],
149- dim = 0 ,
150- ).to (torch .float )
172+ batch ["classifier_label" ] = (
173+ torch .cat (
174+ [self ._convert_to_tensor (example ["classifier_label" ]).reshape (- 1 , 1 ) for example in examples ],
175+ dim = 0 ,
176+ )
177+ .to (torch .float )
178+ .reshape (bz , - 1 )
179+ )
151180
152181 return batch
153182
@@ -173,14 +202,19 @@ def torch_mask_tokens(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple
173202 # The rest of the time (10% of the time) we keep the masked input tokens unchanged
174203 return inputs , labels
175204
176- def generate_start_end_index (self , record : Dict [str , Any ]) -> Dict [str , Any ]:
205+ def generate_start_end_index (
206+ self , record : Dict [str , Any ], max_length_allowed : Optional [int ] = None
207+ ) -> Dict [str , Any ]:
177208 """
178209 Adapted from https://github.com/OHDSI/Apollo/blob/main/data_loading/data_transformer.py.
179210
180211 Adding the start and end indices to extract a portion of the patient sequence
181212 """
213+ sample_packing = getattr (self , "sample_packing" , False )
214+ max_length_allowed = self .max_length if max_length_allowed is None else max_length_allowed
182215 seq_length = len (record ["input_ids" ])
183- new_max_length = self .max_length - 1 # Subtract one for the [CLS] token
216+ # Subtract one for the [CLS] token
217+ new_max_length = max_length_allowed if sample_packing else max_length_allowed - 1
184218
185219 # Return the record directly if the actual sequence length is less than the max sequence
186220 if seq_length <= new_max_length :
@@ -229,3 +263,134 @@ def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
229263 else :
230264 new_record [k ] = v
231265 return new_record
266+
267+
class SamplePackingCehrBertDataCollator(CehrBertDataCollator):
    """Collator that packs several patient sequences into one long sequence.

    Instead of padding every example up to the batch maximum, consecutive
    examples are concatenated — each wrapped as ``[CLS] <tokens> [PAD]`` —
    until the configured token budget (``max_tokens``) would be exceeded.
    The resulting packed "examples" are then handed to the parent collator
    for tensorization, padding, and label generation.
    """

    def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
        """Create the packing collator.

        Args:
            max_tokens: token budget for one packed sequence (one batch row).
            max_position_embeddings: hard cap on a single example's length;
                longer examples are truncated before packing.
            *args, **kwargs: forwarded to ``CehrBertDataCollator``.
        """
        self.max_tokens_per_batch = max_tokens
        self.max_position_embeddings = max_position_embeddings
        # Must be set before super().__init__ so the parent's code paths that
        # check `sample_packing` (CLS prepend, length reservation) see it.
        self.sample_packing = True
        super(SamplePackingCehrBertDataCollator, self).__init__(*args, **kwargs)

    @staticmethod
    def _empty_buffers() -> Dict[str, list]:
        """Return fresh, empty accumulation buffers for one packed sequence."""
        return {
            key: []
            for key in (
                # Per-token time-series inputs
                "input_ids",
                "attention_mask",
                "concept_values",
                "concept_value_masks",
                "ages",
                "dates",
                "visit_concept_orders",
                "visit_segments",
                # Per-example demographics / classification inputs
                "person_id",
                "index_date",
                "age_at_index",
                "classifier_label",
            )
        }

    def _append_example(self, buffers: Dict[str, list], example: Dict[str, Any]) -> None:
        """Append one example, wrapped as ``[CLS] ... [PAD]``, to the buffers."""
        input_ids = example["input_ids"]
        buffers["input_ids"].extend(
            [self.tokenizer.cls_token_index] + input_ids + [self.tokenizer.pad_token_index]
        )
        # CLS attends (1); the trailing PAD separator is masked out (0).
        buffers["attention_mask"].extend([1] + np.ones_like(input_ids).tolist() + [0])
        # -1 / 0 mark "no value" for CLS and the separator token.
        buffers["concept_values"].extend([-1] + example["concept_values"] + [-1])
        buffers["concept_value_masks"].extend([0] + example["concept_value_masks"] + [0])
        # CLS inherits the example's starting age/date; separator gets 0.
        buffers["ages"].extend([example["ages"][0]] + example["ages"] + [0])
        buffers["dates"].extend([example["dates"][0]] + example["dates"] + [0])
        # CLS gets the first visit order minus one (floored at 0); the
        # separator repeats the last visit order.
        visit_orders = example["visit_concept_orders"]
        buffers["visit_concept_orders"].extend(
            [max(0, visit_orders[0] - 1)] + visit_orders + [visit_orders[-1]]
        )
        # 0 segment: CLS/separator belong to no visit.
        buffers["visit_segments"].extend([0] + example["visit_segments"] + [0])

        # Per-example (not per-token) fields, only when present.
        for key in ("person_id", "index_date", "age_at_index", "classifier_label"):
            if key in example:
                buffers[key].append(example[key])

    @staticmethod
    def _pack(buffers: Dict[str, list]) -> Dict[str, Any]:
        """Build one packed example dict from the accumulated buffers."""
        packed = {
            "input_ids": buffers["input_ids"],
            "attention_mask": buffers["attention_mask"],
            "ages": buffers["ages"],
            "dates": buffers["dates"],
            "visit_concept_orders": buffers["visit_concept_orders"],
            "concept_values": buffers["concept_values"],
            "concept_value_masks": buffers["concept_value_masks"],
            "visit_segments": buffers["visit_segments"],
        }
        if buffers["classifier_label"]:
            # NOTE(review): demographics are attached only when classifier
            # labels were collected (mirrors original behavior). Examples that
            # carry person_id/index_date but no label are silently dropped
            # from the packed output here — confirm this is intended.
            packed.update(
                {
                    "person_id": buffers["person_id"],
                    "index_date": buffers["index_date"],
                    "age_at_index": buffers["age_at_index"],
                    "classifier_label": buffers["classifier_label"],
                }
            )
        return packed

    def __call__(self, examples):
        """Pack ``examples`` greedily by token budget, then delegate to the parent."""
        flattened_examples = []
        buffers = self._empty_buffers()

        for example in examples:
            # If a single example exceeds the model's capacity, truncate it.
            if len(example["input_ids"]) > self.max_position_embeddings:
                example = self.generate_start_end_index(example, self.max_position_embeddings)

            input_ids = example["input_ids"]
            # +2 reserves room for the [CLS] prefix and trailing [PAD]
            # separator added by _append_example. Flush only when the buffer
            # is non-empty so an oversized example still forms its own pack.
            if buffers["input_ids"] and len(buffers["input_ids"]) + len(input_ids) + 2 > self.max_tokens_per_batch:
                flattened_examples.append(self._pack(buffers))
                buffers = self._empty_buffers()

            self._append_example(buffers, example)

        # Flush the final, partially filled pack.
        if buffers["input_ids"]:
            flattened_examples.append(self._pack(buffers))

        return super().__call__(flattened_examples)
0 commit comments