Skip to content

Commit f2df8ce

Browse files
authored
implemented sample packing for cehrbert (#104)
* implemented sample packing for cehrbert * clamped concept_values between -10 and 10; fixed a bug when generating time_embeddings * added compute_cehrbert_features.py for extracting features * removed the sort patient sequence mapping * updated sample_packing_sampler.py * updated the index_date and age_at_index data types in hf_dataset_collator.py * Synchronized bert flash attention layer with the eager implementation * removed torch_dtype when loading the model * put back upad_input in src/cehrbert/models/hf_models/hf_cehrbert.py * removed squeeze in compute_cehrbert_features * switched to CehrBertForPreTraining for computing features * added an option to get the features by averaging over the entire sequence for each sample * added train_with_cehrbert_features * added device to start_indices and end_indices * switched to a vectorized implementation * corrected the logic for extracting CLS embeddings in sample packing * implemented sample packing for cehrbert sample packing * fixed the integration test * adjusted finetuning model for sample packing * do not use sample packing for running predictions * updated transformers version * updated the integration tests * fixed a data loading bug in streaming
1 parent 27c0e4f commit f2df8ce

18 files changed

+1502
-535
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ dependencies = [
5252
"tqdm>=4.66.1",
5353
"torch==2.4.0",
5454
"tokenizers>=0.19.0",
55-
"transformers>=4.41.0",
55+
"transformers>=4.41.0, <= 4.45.0",
5656
"accelerate>=0.31.0",
5757
"Werkzeug==3.0.1",
5858
"wandb>=0.17.8",

src/cehrbert/data_generators/hf_data_generator/hf_dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def create_cehrbert_pretraining_dataset(
4848
mapping_functions = [HFTokenizationMapping(concept_tokenizer, True)]
4949
else:
5050
mapping_functions = [
51-
SortPatientSequenceMapping(),
51+
# SortPatientSequenceMapping(),
5252
HFTokenizationMapping(concept_tokenizer, True),
5353
]
5454

@@ -91,7 +91,7 @@ def create_cehrbert_finetuning_dataset(
9191
else:
9292
mapping_functions = [
9393
HFFineTuningMapping(),
94-
SortPatientSequenceMapping(),
94+
# SortPatientSequenceMapping(),
9595
HFTokenizationMapping(concept_tokenizer, False),
9696
]
9797

@@ -166,6 +166,8 @@ def apply_cehrbert_dataset_mapping(
166166
)
167167
if mapping_function.remove_columns():
168168
dataset = dataset.remove_columns(mapping_function.remove_columns())
169-
if cache_file_collector:
170-
cache_file_collector.add_cache_files(dataset)
169+
170+
if cache_file_collector:
171+
cache_file_collector.add_cache_files(dataset)
172+
171173
return dataset

src/cehrbert/data_generators/hf_data_generator/hf_dataset_collator.py

Lines changed: 214 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import collections
22
import random
3-
from typing import Any, Dict, Tuple
3+
from typing import Any, Dict, Optional, Tuple
44

55
import numpy as np
66
import torch
@@ -52,7 +52,12 @@ def __call__(self, examples):
5252
# Assume that each example in the batch is a dictionary with 'input_ids' and 'attention_mask'
5353
batch_input_ids = [self._convert_to_tensor(example["input_ids"]) for example in examples]
5454
batch_attention_mask = [
55-
torch.ones_like(self._convert_to_tensor(example["input_ids"]), dtype=torch.float) for example in examples
55+
(
56+
self._convert_to_tensor(example["attention_mask"]).to(torch.float)
57+
if "attention_mask" in example
58+
else torch.ones_like(self._convert_to_tensor(example["input_ids"]), dtype=torch.float)
59+
)
60+
for example in examples
5661
]
5762
batch_ages = [self._convert_to_tensor(example["ages"]) for example in examples]
5863
batch_dates = [self._convert_to_tensor(example["dates"]) for example in examples]
@@ -79,35 +84,42 @@ def __call__(self, examples):
7984
batch["concept_value_masks"] = pad_sequence(batch_concept_value_masks, batch_first=True, padding_value=0.0)
8085
batch["visit_segments"] = pad_sequence(batch_visit_segments, batch_first=True, padding_value=0)
8186

82-
# Prepend the CLS token and their associated values to the corresponding time series features
83-
batch["input_ids"] = torch.cat(
84-
[
85-
torch.full((batch_size, 1), self.tokenizer.cls_token_index),
86-
batch["input_ids"],
87-
],
88-
dim=1,
89-
)
90-
# The attention_mask is set to 1 to enable attention for the CLS token
91-
batch["attention_mask"] = torch.cat([torch.full((batch_size, 1), 1.0), batch["attention_mask"]], dim=1)
92-
# Set the age of the CLS token to the starting age
93-
batch["ages"] = torch.cat([batch["ages"][:, 0:1], batch["ages"]], dim=1)
94-
# Set the age of the CLS token to the starting date
95-
batch["dates"] = torch.cat([batch["dates"][:, 0:1], batch["dates"]], dim=1)
96-
# Set the visit_concept_order of the CLS token to the first visit_concept_order in the sequence subtract by 1
97-
visit_concept_orders_first = batch["visit_concept_orders"][:, 0:1] - 1
98-
visit_concept_orders_first = torch.maximum(
99-
visit_concept_orders_first, torch.zeros_like(visit_concept_orders_first)
100-
)
101-
batch["visit_concept_orders"] = torch.cat([visit_concept_orders_first, batch["visit_concept_orders"]], dim=1)
102-
# Set the concept_value of the CLS token to a default value -1.0.
103-
batch["concept_values"] = torch.cat([torch.full((batch_size, 1), -1.0), batch["concept_values"]], dim=1)
104-
# Set the concept_value of the CLS token to a default value 0.0 indicating that
105-
# there is no value associated with this token
106-
batch["concept_value_masks"] = torch.cat(
107-
[torch.full((batch_size, 1), 0.0), batch["concept_value_masks"]], dim=1
108-
)
109-
# Set the visit_segments of the CLS token to a default value 0 because this doesn't belong to a visit
110-
batch["visit_segments"] = torch.cat([torch.full((batch_size, 1), 0), batch["visit_segments"]], dim=1)
87+
if not getattr(self, "sample_packing", False):
88+
# Prepend the CLS token and their associated values to the corresponding time series features
89+
batch["input_ids"] = torch.cat(
90+
[
91+
torch.full((batch_size, 1), self.tokenizer.cls_token_index),
92+
batch["input_ids"],
93+
],
94+
dim=1,
95+
)
96+
# The attention_mask is set to 1 to enable attention for the CLS token
97+
batch["attention_mask"] = torch.cat([torch.full((batch_size, 1), 1.0), batch["attention_mask"]], dim=1)
98+
# Set the age of the CLS token to the starting age
99+
batch["ages"] = torch.cat([batch["ages"][:, 0:1], batch["ages"]], dim=1)
100+
# Set the age of the CLS token to the starting date
101+
batch["dates"] = torch.cat([batch["dates"][:, 0:1], batch["dates"]], dim=1)
102+
# Set the visit_concept_order of the CLS token to the first visit_concept_order in the sequence subtract by 1
103+
visit_concept_orders_first = batch["visit_concept_orders"][:, 0:1] - 1
104+
visit_concept_orders_first = torch.maximum(
105+
visit_concept_orders_first, torch.zeros_like(visit_concept_orders_first)
106+
)
107+
batch["visit_concept_orders"] = torch.cat(
108+
[visit_concept_orders_first, batch["visit_concept_orders"]], dim=1
109+
)
110+
# Set the concept_value of the CLS token to a default value -1.0.
111+
batch["concept_values"] = torch.cat([torch.full((batch_size, 1), -1.0), batch["concept_values"]], dim=1)
112+
# Set the concept_value of the CLS token to a default value 0.0 indicating that
113+
# there is no value associated with this token
114+
batch["concept_value_masks"] = torch.cat(
115+
[torch.full((batch_size, 1), 0.0), batch["concept_value_masks"]], dim=1
116+
)
117+
# Set the visit_segments of the CLS token to a default value 0 because this doesn't belong to a visit
118+
batch["visit_segments"] = torch.cat([torch.full((batch_size, 1), 0), batch["visit_segments"]], dim=1)
119+
else:
120+
assert (
121+
batch["attention_mask"].shape[0] == 1
122+
), f"batch['attention_mask'].shape[0] must be 0 in sample packing"
111123

112124
# This is the most crucial logic for generating the training labels
113125
if self.is_pretraining:
@@ -125,29 +137,46 @@ def __call__(self, examples):
125137

126138
batch["input_ids"], batch["labels"] = self.torch_mask_tokens(batch["input_ids"], batch["labels"])
127139

140+
bz = len(examples)
128141
if "person_id" in examples[0]:
129-
batch["person_id"] = torch.cat(
130-
[self._convert_to_tensor(example["person_id"]).reshape(-1, 1) for example in examples],
131-
dim=0,
132-
).to(torch.float)
142+
batch["person_id"] = (
143+
torch.cat(
144+
[self._convert_to_tensor(example["person_id"]).reshape(-1, 1) for example in examples],
145+
dim=0,
146+
)
147+
.to(torch.int32)
148+
.reshape(bz, -1)
149+
)
133150

134151
if "index_date" in examples[0]:
135-
batch["index_date"] = torch.cat(
136-
[self._convert_to_tensor(example["index_date"]).reshape(-1, 1) for example in examples],
137-
dim=0,
138-
).to(torch.float32)
152+
batch["index_date"] = (
153+
torch.cat(
154+
[self._convert_to_tensor(example["index_date"]).reshape(-1, 1) for example in examples],
155+
dim=0,
156+
)
157+
.to(torch.float64)
158+
.reshape(bz, -1)
159+
)
139160

140161
if "age_at_index" in examples[0]:
141-
batch["age_at_index"] = torch.cat(
142-
[self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1) for example in examples],
143-
dim=0,
144-
).to(torch.float)
162+
batch["age_at_index"] = (
163+
torch.cat(
164+
[self._convert_to_tensor(example["age_at_index"]).reshape(-1, 1) for example in examples],
165+
dim=0,
166+
)
167+
.to(torch.float32)
168+
.reshape(bz, -1)
169+
)
145170

146171
if "classifier_label" in examples[0]:
147-
batch["classifier_label"] = torch.cat(
148-
[self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1) for example in examples],
149-
dim=0,
150-
).to(torch.float)
172+
batch["classifier_label"] = (
173+
torch.cat(
174+
[self._convert_to_tensor(example["classifier_label"]).reshape(-1, 1) for example in examples],
175+
dim=0,
176+
)
177+
.to(torch.float)
178+
.reshape(bz, -1)
179+
)
151180

152181
return batch
153182

@@ -173,14 +202,19 @@ def torch_mask_tokens(self, inputs: torch.Tensor, labels: torch.Tensor) -> Tuple
173202
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
174203
return inputs, labels
175204

176-
def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
205+
def generate_start_end_index(
206+
self, record: Dict[str, Any], max_length_allowed: Optional[int] = None
207+
) -> Dict[str, Any]:
177208
"""
178209
Adapted from https://github.com/OHDSI/Apollo/blob/main/data_loading/data_transformer.py.
179210
180211
Adding the start and end indices to extract a portion of the patient sequence
181212
"""
213+
sample_packing = getattr(self, "sample_packing", False)
214+
max_length_allowed = self.max_length if max_length_allowed is None else max_length_allowed
182215
seq_length = len(record["input_ids"])
183-
new_max_length = self.max_length - 1 # Subtract one for the [CLS] token
216+
# Subtract one for the [CLS] token
217+
new_max_length = max_length_allowed if sample_packing else max_length_allowed - 1
184218

185219
# Return the record directly if the actual sequence length is less than the max sequence
186220
if seq_length <= new_max_length:
@@ -229,3 +263,134 @@ def generate_start_end_index(self, record: Dict[str, Any]) -> Dict[str, Any]:
229263
else:
230264
new_record[k] = v
231265
return new_record
266+
267+
268+
class SamplePackingCehrBertDataCollator(CehrBertDataCollator):
269+
def __init__(self, max_tokens, max_position_embeddings, *args, **kwargs):
270+
self.max_tokens_per_batch = max_tokens
271+
self.max_position_embeddings = max_position_embeddings
272+
self.sample_packing = True
273+
super(SamplePackingCehrBertDataCollator, self).__init__(*args, **kwargs)
274+
275+
def __call__(self, examples):
276+
flattened_examples = []
277+
278+
# Main inputs
279+
current_input_ids = []
280+
current_attention_mask = []
281+
current_concept_values = []
282+
current_concept_value_masks = []
283+
current_ages = []
284+
current_dates = []
285+
current_visit_concept_orders = []
286+
current_visit_segments = []
287+
288+
# Demographics
289+
current_person_ids = []
290+
current_index_dates = []
291+
292+
# Binary classification inputs
293+
current_age_at_indexes = []
294+
current_labels = []
295+
296+
for idx, example in enumerate(examples):
297+
# If the sample length exceeds the model's capacity, truncate this example
298+
if len(example["input_ids"]) > self.max_position_embeddings:
299+
example = self.generate_start_end_index(example, self.max_position_embeddings)
300+
301+
input_ids = example["input_ids"]
302+
# We add the flattened example to the list either when the example exceeds the total max tokens
303+
# we add the length by two because we need to add two more tokens [CLS] .... [PAD]
304+
if len(current_input_ids) + len(input_ids) + 2 > self.max_tokens_per_batch and current_input_ids:
305+
packed_example = {
306+
"input_ids": current_input_ids,
307+
"attention_mask": current_attention_mask,
308+
"ages": current_ages,
309+
"dates": current_dates,
310+
"visit_concept_orders": current_visit_concept_orders,
311+
"concept_values": current_concept_values,
312+
"concept_value_masks": current_concept_value_masks,
313+
"visit_segments": current_visit_segments,
314+
}
315+
316+
if current_labels:
317+
packed_example.update(
318+
{
319+
"person_id": current_person_ids,
320+
"index_date": current_index_dates,
321+
"age_at_index": current_age_at_indexes,
322+
"classifier_label": current_labels,
323+
}
324+
)
325+
326+
flattened_examples.append(packed_example)
327+
328+
# Main inputs
329+
current_input_ids = []
330+
current_attention_mask = []
331+
current_concept_values = []
332+
current_concept_value_masks = []
333+
current_ages = []
334+
current_dates = []
335+
current_visit_concept_orders = []
336+
current_visit_segments = []
337+
338+
# Demographics
339+
current_person_ids = []
340+
current_index_dates = []
341+
342+
# Binary classification inputs
343+
current_age_at_indexes = []
344+
current_labels = []
345+
346+
current_input_ids.extend([self.tokenizer.cls_token_index] + input_ids + [self.tokenizer.pad_token_index])
347+
current_attention_mask.extend([1] + np.ones_like(input_ids).tolist() + [0])
348+
current_concept_values.extend([-1] + example["concept_values"] + [-1])
349+
current_concept_value_masks.extend([0] + example["concept_value_masks"] + [0])
350+
current_ages.extend([example["ages"][0]] + example["ages"] + [0])
351+
current_dates.extend([example["dates"][0]] + example["dates"] + [0])
352+
current_visit_concept_orders.extend(
353+
[max(0, example["visit_concept_orders"][0] - 1)]
354+
+ example["visit_concept_orders"]
355+
+ [example["visit_concept_orders"][-1]]
356+
)
357+
current_visit_segments.extend([0] + example["visit_segments"] + [0])
358+
359+
if "person_id" in example:
360+
current_person_ids.append(example["person_id"])
361+
362+
if "index_date" in example:
363+
current_index_dates.append(example["index_date"])
364+
365+
if "age_at_index" in example:
366+
current_age_at_indexes.append(example["age_at_index"])
367+
368+
if "classifier_label" in example:
369+
current_labels.append(example["classifier_label"])
370+
371+
# The final batch needs to be added
372+
if current_input_ids:
373+
packed_example = {
374+
"input_ids": current_input_ids,
375+
"attention_mask": current_attention_mask,
376+
"ages": current_ages,
377+
"dates": current_dates,
378+
"visit_concept_orders": current_visit_concept_orders,
379+
"concept_values": current_concept_values,
380+
"concept_value_masks": current_concept_value_masks,
381+
"visit_segments": current_visit_segments,
382+
}
383+
384+
if current_labels:
385+
packed_example.update(
386+
{
387+
"person_id": current_person_ids,
388+
"index_date": current_index_dates,
389+
"age_at_index": current_age_at_indexes,
390+
"classifier_label": current_labels,
391+
}
392+
)
393+
394+
flattened_examples.append(packed_example)
395+
396+
return super().__call__(flattened_examples)

0 commit comments

Comments
 (0)