
Commit e667e0a

committed
added compute_cehrbert_features.py for extracting features
1 parent 5bb1850 commit e667e0a

File tree

2 files changed: +335 −0 lines changed

src/cehrbert/linear_prob/__init__.py

Whitespace-only changes.
src/cehrbert/linear_prob/compute_cehrbert_features.py

Lines changed: 335 additions & 0 deletions
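# Extract patient-level CEHR-BERT features for linear probing: the hidden state at each
# sequence's CLS token is exported, together with labels, age, and demographics, as
# parquet files under <output_dir>/<split>/features/.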
import glob
import os
import uuid
from datetime import datetime
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict, concatenate_datasets, load_from_disk
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import TrainingArguments
from transformers.utils import is_flash_attn_2_available, logging

from cehrbert.data_generators.hf_data_generator.cache_util import CacheFileCollector
from cehrbert.data_generators.hf_data_generator.hf_dataset import create_cehrbert_finetuning_dataset
from cehrbert.data_generators.hf_data_generator.hf_dataset_collator import (
    CehrBertDataCollator,
    SamplePackingCehrBertDataCollator,
)
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping
from cehrbert.data_generators.hf_data_generator.meds_utils import create_dataset_from_meds_reader
from cehrbert.data_generators.hf_data_generator.sample_packing_sampler import SamplePackingBatchSampler
from cehrbert.models.hf_models.hf_cehrbert import CehrBert
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
from cehrbert.runners.runner_util import (
    convert_dataset_to_iterable_dataset,
    generate_prepared_ds_path,
    get_meds_extension_path,
    load_parquet_as_dataset,
    parse_runner_args,
)

LOG = logging.get_logger("transformers")

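# Build a DatasetDict with train/validation/test splits, either by converting a MEDS
# source to the CEHR-BERT format (cached under the prepared dataset path) or by loading
# parquet data and splitting it with the configured validation/test ratios.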
def prepare_finetune_dataset(
    data_args: DataTrainingArguments,
    training_args: TrainingArguments,
    cache_file_collector: CacheFileCollector,
):
    # If the data is in the MEDS format, we need to convert it to the CEHR-BERT format
    if data_args.is_data_in_meds:
        meds_extension_path = get_meds_extension_path(
            data_folder=os.path.expanduser(data_args.cohort_folder),
            dataset_prepared_path=os.path.expanduser(data_args.dataset_prepared_path),
        )
        try:
            LOG.info(f"Trying to load the MEDS extension from disk at {meds_extension_path}...")
            dataset = load_from_disk(meds_extension_path)
            if data_args.streaming:
                dataset = convert_dataset_to_iterable_dataset(dataset, num_shards=training_args.dataloader_num_workers)
        except Exception as e:
            LOG.exception(e)
            dataset = create_dataset_from_meds_reader(
                data_args,
                dataset_mappings=[MedToCehrBertDatasetMapping(data_args=data_args, is_pretraining=False)],
                cache_file_collector=cache_file_collector,
            )
            if not data_args.streaming:
                dataset.save_to_disk(str(meds_extension_path))
                stats = dataset.cleanup_cache_files()
                LOG.info(
                    "Clean up the cached files for the cehrbert dataset transformed from the MEDS: %s",
                    stats,
                )
                # Clean up the files created from the data generator
                cache_file_collector.remove_cache_files()
                dataset = load_from_disk(str(meds_extension_path))

        train_set = dataset["train"]
        validation_set = dataset["validation"]
        test_set = dataset["test"]
    else:
        dataset = load_parquet_as_dataset(os.path.expanduser(data_args.data_folder))
        test_set = None
        if data_args.test_data_folder:
            test_set = load_parquet_as_dataset(data_args.test_data_folder)
        # Split the dataset into train/val
        train_val = dataset.train_test_split(
            test_size=data_args.validation_split_percentage,
            seed=training_args.seed,
        )
        train_set = train_val["train"]
        validation_set = train_val["test"]
        if not test_set:
            test_valid = validation_set.train_test_split(test_size=data_args.test_eval_ratio, seed=training_args.seed)
            validation_set = test_valid["train"]
            test_set = test_valid["test"]

    # Organize them into a single DatasetDict
    return DatasetDict({"train": train_set, "validation": validation_set, "test": test_set})

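# End-to-end flow: load the tokenizer and frozen encoder, prepare (or reload) the
# tokenized dataset, skip examples whose features were already exported, then run the
# encoder over the train+validation and test splits and write CLS features to parquet.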
def main():
    cehrbert_args, data_args, model_args, training_args = parse_runner_args()
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    cehrbert_tokenizer = CehrBertTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
    cehrbert_model = (
        CehrBert.from_pretrained(
            model_args.model_name_or_path,
            attn_implementation=("flash_attention_2" if is_flash_attn_2_available() else "eager"),
            torch_dtype=(torch.bfloat16 if is_flash_attn_2_available() else torch.float32),
        )
        .eval()
        .to(device)
    )
    prepared_ds_path = generate_prepared_ds_path(data_args, model_args, data_folder=data_args.cohort_folder)
    cache_file_collector = CacheFileCollector()
    processed_dataset = None
    if any(prepared_ds_path.glob("*")):
        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
        processed_dataset = load_from_disk(str(prepared_ds_path))
        LOG.info("Prepared dataset loaded from disk...")

    if processed_dataset is None:
        # Organize them into a single DatasetDict
        final_splits = prepare_finetune_dataset(data_args, training_args, cache_file_collector)

        # TODO: temp solution, this column is mixed-typed and causes an issue when transforming the data
        if not data_args.streaming:
            all_columns = final_splits["train"].column_names
            if "visit_concept_ids" in all_columns:
                final_splits = final_splits.remove_columns(["visit_concept_ids"])

        processed_dataset = create_cehrbert_finetuning_dataset(
            dataset=final_splits,
            concept_tokenizer=cehrbert_tokenizer,
            data_args=data_args,
            cache_file_collector=cache_file_collector,
        )
        if not data_args.streaming:
            processed_dataset.save_to_disk(prepared_ds_path)
            processed_dataset.cleanup_cache_files()

        # Remove all the cached files if processed_dataset.cleanup_cache_files() did not remove them already
        cache_file_collector.remove_cache_files()

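    # Resume support: any (subject_id, prediction_time) pair that already has a feature
    # row on disk is filtered out before inference.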
    # Getting the existing features
    feature_folders = glob.glob(os.path.join(training_args.output_dir, "*", "features", "*.parquet"))
    if feature_folders:
        existing_features = pd.concat(
            [pd.read_parquet(f, columns=["subject_id", "prediction_time_posix"]) for f in feature_folders],
            ignore_index=True,
        )
        subject_prediction_tuples = set(
            existing_features.apply(
                lambda row: f"{int(row['subject_id'])}-{int(row['prediction_time_posix'])}",
                axis=1,
            ).tolist()
        )
        processed_dataset = processed_dataset.filter(
            lambda _batch: [
                f"{int(subject)}-{int(time)}" not in subject_prediction_tuples
                for subject, time in zip(_batch["person_id"], _batch["index_date"])
            ],
            num_proc=data_args.preprocessing_num_workers,
            batch_size=data_args.preprocessing_batch_size,
            batched=True,
        )
        LOG.info(
            "The datasets after filtering (train: %s, validation: %s, test: %s)",
            len(processed_dataset["train"]),
            len(processed_dataset["validation"]),
            len(processed_dataset["test"]),
        )

    train_set = concatenate_datasets([processed_dataset["train"], processed_dataset["validation"]])

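    # With sample packing, several sequences share one forward pass (up to
    # max_tokens_per_batch) via a dedicated batch sampler and collator; otherwise the
    # standard per-device evaluation batch size is used.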
    if cehrbert_args.sample_packing:
        per_device_eval_batch_size = 1
        data_collator_fn = partial(
            SamplePackingCehrBertDataCollator,
            cehrbert_args.max_tokens_per_batch,
            cehrbert_model.config.max_position_embeddings,
        )
        train_batch_sampler = SamplePackingBatchSampler(
            lengths=train_set["num_of_concepts"],
            max_tokens_per_batch=cehrbert_args.max_tokens_per_batch,
            max_position_embeddings=cehrbert_model.config.max_position_embeddings,
            drop_last=training_args.dataloader_drop_last,
            seed=training_args.seed,
        )
        test_batch_sampler = SamplePackingBatchSampler(
            lengths=processed_dataset["test"]["num_of_concepts"],
            max_tokens_per_batch=cehrbert_args.max_tokens_per_batch,
            max_position_embeddings=cehrbert_model.config.max_position_embeddings,
            drop_last=training_args.dataloader_drop_last,
            seed=training_args.seed,
        )
    else:
        data_collator_fn = CehrBertDataCollator
        train_batch_sampler = None
        test_batch_sampler = None
        per_device_eval_batch_size = training_args.per_device_eval_batch_size

    # We suppress the additional learning objectives in fine-tuning
    data_collator = data_collator_fn(
        tokenizer=cehrbert_tokenizer,
        max_length=(
            cehrbert_args.max_tokens_per_batch
            if cehrbert_args.sample_packing
            else cehrbert_model.config.max_position_embeddings
        ),
        is_pretraining=False,
    )

    train_loader = DataLoader(
        dataset=train_set,
        batch_size=per_device_eval_batch_size,
        num_workers=training_args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=training_args.dataloader_pin_memory,
        batch_sampler=train_batch_sampler,
    )

    test_dataloader = DataLoader(
        dataset=processed_dataset["test"],
        batch_size=per_device_eval_batch_size,
        num_workers=training_args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=training_args.dataloader_pin_memory,
        batch_sampler=test_batch_sampler,
    )

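    # Demographics are not carried in the tokenized batches, so they are read from the
    # raw cohort parquet files and joined back by (person_id, index_date).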
    # Loading demographics
    LOG.info("Loading demographics as a dictionary")
    demographics_df = pd.concat(
        [
            pd.read_parquet(
                data_dir,
                columns=[
                    "person_id",
                    "index_date",
                    "gender_concept_id",
                    "race_concept_id",
                ],
            )
            for data_dir in [data_args.data_folder, data_args.test_data_folder]
        ]
    )
    demographics_df["index_date"] = demographics_df.index_date.dt.date
    demographics_dict = {
        (row["person_id"], row["index_date"]): {
            "gender_concept_id": row["gender_concept_id"],
            "race_concept_id": row["race_concept_id"],
        }
        for _, row in demographics_df.iterrows()
    }

    data_loaders = [("train", train_loader), ("test", test_dataloader)]

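    # Run inference over each split and write one parquet file of features per batch.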
    for split, data_loader in data_loaders:

        # Ensure prediction folder exists
        feature_output_folder = Path(training_args.output_dir) / split / "features"
        feature_output_folder.mkdir(parents=True, exist_ok=True)

        LOG.info("Generating features for %s set at %s", split, feature_output_folder)

        with torch.no_grad():
            for index, batch in enumerate(tqdm(data_loader, desc="Generating features")):
                prediction_time_ages = batch.pop("age_at_index").numpy().astype(float).squeeze()
                if prediction_time_ages.ndim == 0:
                    prediction_time_ages = np.asarray([prediction_time_ages])

                person_ids = batch.pop("person_id").numpy().astype(int).squeeze()
                if person_ids.ndim == 0:
                    person_ids = np.asarray([person_ids])
                prediction_time_posix = batch.pop("index_date").numpy().squeeze()
                if prediction_time_posix.ndim == 0:
                    prediction_time_posix = np.asarray([prediction_time_posix])
                prediction_time = list(map(datetime.fromtimestamp, prediction_time_posix))
                labels = batch.pop("classifier_label").float().cpu().numpy().astype(bool).squeeze()
                if labels.ndim == 0:
                    labels = np.asarray([labels])

                batch = {k: v.to(device) for k, v in batch.items()}
                # Forward pass
                cehrbert_output = cehrbert_model(**batch, output_attentions=False, output_hidden_states=False)

                # Boolean mask marking the CLS token position(s) in the input
                cls_token_indices = batch["input_ids"] == cehrbert_tokenizer.cls_token_index
                if cehrbert_args.sample_packing:
                    features = (
                        cehrbert_output.last_hidden_state[..., cls_token_indices, :]
                        .cpu()
                        .float()
                        .detach()
                        .numpy()
                        .squeeze(axis=0)
                    )
                else:
                    cls_token_index = torch.argmax(cls_token_indices.to(torch.int), dim=-1)
                    features = cehrbert_output.last_hidden_state[..., cls_token_index, :].cpu().float().detach().numpy()
                assert len(features) == len(labels), "the number of features must match the number of labels"
                # Flatten features or handle them as a list of arrays (one array per row)
                features_list = [feature for feature in features]
                race_concept_ids = []
                gender_concept_ids = []
                for person_id, index_date in zip(person_ids, prediction_time):
                    key = (person_id, index_date.date())
                    if key in demographics_dict:
                        demographics = demographics_dict[key]
                        gender_concept_ids.append(demographics["gender_concept_id"])
                        race_concept_ids.append(demographics["race_concept_id"])
                    else:
                        gender_concept_ids.append(0)
                        race_concept_ids.append(0)

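                # Assemble one row per example: identifiers, label, age at the index
                # date, the CLS feature vector, and the demographic concept ids.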
                features_pd = pd.DataFrame(
                    {
                        "subject_id": person_ids,
                        "prediction_time": prediction_time,
                        "prediction_time_posix": prediction_time_posix,
                        "boolean_value": labels,
                        "age_at_index": prediction_time_ages,
                    }
                )
                # Adding features as a separate column where each row contains a feature array
                features_pd["features"] = features_list
                features_pd["race_concept_id"] = race_concept_ids
                features_pd["gender_concept_id"] = gender_concept_ids
                features_pd.to_parquet(feature_output_folder / f"{uuid.uuid4()}.parquet")


if __name__ == "__main__":
    main()
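Downstream of this commit, the exported parquet files can feed a simple linear probe. The following is a minimal sketch, not part of the commit, assuming scikit-learn is installed and that the train/test feature folders were produced by this script under a hypothetical output_dir; the column names match the DataFrame written above.

import glob

import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

# Hypothetical locations of the feature files written by compute_cehrbert_features.py.
train_files = glob.glob("output_dir/train/features/*.parquet")
test_files = glob.glob("output_dir/test/features/*.parquet")

train_df = pd.concat([pd.read_parquet(f) for f in train_files], ignore_index=True)
test_df = pd.concat([pd.read_parquet(f) for f in test_files], ignore_index=True)

# Stack the per-row feature arrays into 2-D matrices for scikit-learn.
x_train = np.stack(train_df["features"].to_numpy())
x_test = np.stack(test_df["features"].to_numpy())
y_train = train_df["boolean_value"].astype(int).to_numpy()
y_test = test_df["boolean_value"].astype(int).to_numpy()

# Fit a logistic-regression probe on the frozen CEHR-BERT features.
probe = LogisticRegression(max_iter=1000)
probe.fit(x_train, y_train)
print("ROC-AUC:", roc_auc_score(y_test, probe.predict_proba(x_test)[:, 1]))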
