105 changes: 77 additions & 28 deletions README.md
# CEHR-BERT

[![PyPI - Version](https://img.shields.io/pypi/v/cehrbert)](https://pypi.org/project/cehrbert/)
![Python](https://img.shields.io/badge/-Python_3.10-blue?logo=python&logoColor=white)
[![tests](https://github.com/cumc-dbmi/cehrbert/actions/workflows/tests.yml/badge.svg)](https://github.com/cumc-dbmi/cehrbert/actions/workflows/tests.yml)
[![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](https://github.com/cumc-dbmi/cehrbert/blob/main/LICENSE)
[![contributors](https://img.shields.io/github/contributors/cumc-dbmi/cehrbert.svg)](https://github.com/cumc-dbmi/cehrbert/graphs/contributors)


CEHR-BERT is a large language model developed for structured EHR data; the work was published
at https://proceedings.mlr.press/v158/pang21a.html. CEHR-BERT currently only supports structured EHR data in the
OMOP format, a common data model used to support observational studies and managed by the Observational Health
Data Sciences and Informatics (OHDSI) community.

Build the project
```console
pip install -e .[dev]
```


## Instructions for Use with [MEDS](https://github.com/Medical-Event-Data-Standard/meds)

Step 1. Convert MEDS to the [meds_reader](https://github.com/som-shahlab/meds_reader) database
---------------------------
If you don't have a MEDS dataset, you can convert an OMOP dataset to MEDS
using [meds_etl](https://github.com/Medical-Event-Data-Standard/meds_etl).
We have prepared a Synthea dataset with 1M patients for testing, which you can download.
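
If you are converting your own OMOP instance, a minimal sketch might look like the following, assuming meds_etl exposes the `meds_etl_omop` entry point (the exact command name and flags are assumptions; check the meds_etl documentation):
```console
# Hypothetical sketch: convert an OMOP folder into a MEDS dataset
meds_etl_omop ~/Documents/omop_test/ synthea_meds
```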

Convert MEDS to the meds_reader database to get the patient-level data
```console
meds_reader_convert synthea_meds synthea_meds_reader --num_threads 4
```

Step 2. Pretrain CEHR-BERT using the meds_reader database
---------------------------
```console
mkdir test_dataset_prepared;
mkdir test_synthea_results;
python -m cehrbert.runners.hf_cehrbert_pretrain_runner \
sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml
```

## Instructions for Use with OMOP


Step 1. Download OMOP tables as parquet files
---------------------------
We created a Spark app to download OMOP tables from SQL Server as parquet files. You need to adjust the properties
in `db_properties.ini` to match your database setup. Download [jtds-1.3.1.jar](https://mvnrepository.com/artifact/net.sourceforge.jtds/jtds/1.3.1) into the Spark jars folder in the Python environment.
```console
cp jtds-1.3.1.jar .venv/lib/python3.10/site-packages/pyspark/jars/
```
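
For reference, a hypothetical `db_properties.ini` is sketched below; the section and key names are illustrative assumptions, not the actual schema, so match them to the sample file shipped with the repository:
```ini
; Hypothetical sketch -- section and key names are assumptions, not the actual schema
[sql_server]
base_url = jdbc:jtds:sqlserver://your-server:1433/your_omop_database
driver = net.sourceforge.jtds.jdbc.Driver
user = your_username
password = your_password
```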

We use Spark as the data processing engine to generate the pretraining data.
For that, we need to set up the relevant Spark environment variables.
```bash
# the omop derived tables need to be built using pyspark
export SPARK_WORKER_INSTANCES="1"
export SPARK_WORKER_CORES="16"
export SPARK_EXECUTOR_CORES="2"
export SPARK_DRIVER_MEMORY="12g"
export SPARK_EXECUTOR_MEMORY="12g"
```
Download the OMOP tables as parquet files
```console
python -u -m cehrbert.tools.download_omop_tables -c db_properties.ini \
-tc person visit_occurrence condition_occurrence procedure_occurrence \
drug_exposure measurement observation_period \
concept concept_relationship concept_ancestor \
-o ~/Documents/omop_test/
```

We have prepared a Synthea dataset with 1M patients for testing; you can download it
at [omop_synthea.tar.gz](https://drive.google.com/file/d/1k7-cZACaDNw8A1JRI37mfM)
```console
tar -xvf omop_synthea.tar.gz -C ~/Documents/omop_test/
```

Step 2. Generate training data for CEHR-BERT using cehrbert_data
---------------------------
We order the patient events in chronological order and put all data points in a single sequence. We insert artificial
tokens VS (visit start) and VE (visit end) at the start and the end of each visit. In addition, we insert artificial
time tokens (ATT) between visits to indicate the time intervals between them. This approach allows us to apply BERT to
structured EHR as-is.
The sequence can be seen conceptually as [VS] [V1] [VE] [ATT] [VS] [V2] [VE], where [V1] and [V2] represent the lists
of concepts associated with those visits.
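
For illustration, a patient with two visits about two weeks apart might be serialized as the token sequence below; the concept IDs and the exact ATT vocabulary shown here are illustrative, since the real tokens depend on the tokenizer configuration:
```
[VS] 320128 201826 [VE] W2 [VS] 9201 [VE]
```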

Set up the pyspark environment variables if you haven't done so.
```bash
# the omop derived tables need to be built using pyspark
export SPARK_WORKER_INSTANCES="1"
export SPARK_WORKER_CORES="16"
export SPARK_EXECUTOR_CORES="2"
export SPARK_DRIVER_MEMORY="12g"
export SPARK_EXECUTOR_MEMORY="12g"
```
Generate the pretraining data using the following command
```bash
sh src/cehrbert/scripts/create_cehrbert_pretraining_data.sh \
  --input_folder $OMOP_DIR \
  --output_folder $CEHR_BERT_DATA_DIR \
  --start_date "1985-01-01"
```

Step 3. Pre-train CEHR-BERT
---------------------------
If you don't have your own OMOP instance, we have provided a sample of patient sequence data generated using Synthea
at `sample/patient_sequence` in the repo. CEHR-BERT expects the data folder to be named `patient_sequence`.

```console
mkdir test_dataset_prepared;
mkdir test_results;
python -m cehrbert.runners.hf_cehrbert_pretrain_runner \
sample_configs/hf_cehrbert_pretrain_runner_config.yaml
```

If your dataset is large, you can add `--use_dask` to the command above.
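
For example, a sketch of the pretraining command with the flag appended (the flag itself comes from the note above; whether it can be combined with other runner arguments is an assumption):
```console
python -m cehrbert.runners.hf_cehrbert_pretrain_runner \
sample_configs/hf_cehrbert_pretrain_runner_config.yaml --use_dask
```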

Step 4. Generate the HF readmission prediction task
---------------------------
If you don't have your own OMOP instance, we have provided a sample of patient sequence data generated using Synthea
at `sample/hf_readmissioon` in the repo. Set up the pyspark environment variables if you haven't done so.
```bash
# the omop derived tables need to be built using pyspark
export SPARK_WORKER_INSTANCES="1"
export SPARK_WORKER_CORES="16"
export SPARK_EXECUTOR_CORES="2"
export SPARK_DRIVER_MEMORY="12g"
export SPARK_EXECUTOR_MEMORY="12g"
```
Generate the HF readmission prediction task
```console
python -u -m cehrbert.prediction_cohorts.hf_readmission \
-c hf_readmission -i ~/Documents/omop_test/ -o ~/Documents/omop_test/cehr-bert \
-dl 1985-01-01 -du 2020-12-31 \
-l 18 -u 100 -ow 360 -ps 0 -pw 30 \
--is_new_patient_representation
```

Step 5. Fine-tune CEHR-BERT
---------------------------
```console
mkdir test_finetune_results;
python -m cehrbert.runners.hf_cehrbert_finetune_runner \
sample_configs/hf_cehrbert_finetuning_runner_config.yaml
```

## Contact us
3 changes: 3 additions & 0 deletions sample_configs/hf_cehrbert_finetuning_runner_config.yaml
# Please point this to your model folder
model_name_or_path: "test_results"
# Please point this to your model folder
tokenizer_name_or_path: "test_results"

data_folder: "sample_data/finetune/full"
max_position_embeddings: 512
dataloader_num_workers: 4
dataloader_prefetch_factor: 2

# Please point this to your finetuned model folder
output_dir: "test_finetune_results"
evaluation_strategy: "epoch"
save_strategy: "epoch"
4 changes: 4 additions & 0 deletions sample_configs/hf_cehrbert_pretrain_runner_config.yaml
# Please point this to your output model folder
model_name_or_path: "test_results"
# Please point this to your output model folder
tokenizer_name_or_path: "test_results"

data_folder: "sample_data/pretrain"
max_position_embeddings: 512
dataloader_num_workers: 4
dataloader_prefetch_factor: 4

# Please point this to your output model folder
output_dir: "test_results"

evaluation_strategy: "epoch"
save_strategy: "epoch"
learning_rate: 0.00005
4 changes: 4 additions & 0 deletions sample_configs/hf_cehrbert_pretrain_runner_meds_config.yaml
# Please point this to your model folder
model_name_or_path: "test_synthea_results"
# Please point this to your model folder
tokenizer_name_or_path: "test_synthea_results"

# Please point this to the meds_reader database folder, because the MEDS data is used as the input
data_folder: "synthea_meds_reader"
dataset_prepared_path: "test_dataset_prepared"
validation_split_percentage: 0.05
max_position_embeddings: 512
dataloader_num_workers: 4
dataloader_prefetch_factor: 4

# Please point this to your model folder
output_dir: "test_synthea_results"
evaluation_strategy: "epoch"
save_strategy: "epoch"
103 changes: 103 additions & 0 deletions src/cehrbert/scripts/create_cehrbert_pretraining_data.sh
#!/bin/bash

# Function to display usage
usage() {
    echo "Usage: $0 --input_folder INPUT_FOLDER --output_folder OUTPUT_FOLDER --start_date START_DATE"
    echo ""
    echo "Required Arguments:"
    echo "  --input_folder PATH    Input folder path"
    echo "  --output_folder PATH   Output folder path"
    echo "  --start_date DATE      Start date"
    echo ""
    echo "Example:"
    echo "  $0 --input_folder /path/to/input --output_folder /path/to/output --start_date 1985-01-01"
    exit 1
}

# Check if no arguments were provided
if [ $# -eq 0 ]; then
    usage
fi

# Initialize variables
INPUT_FOLDER=""
OUTPUT_FOLDER=""
START_DATE=""

# Domain tables (fixed list)
DOMAIN_TABLES=("condition_occurrence" "procedure_occurrence" "drug_exposure")

# Parse command line arguments
ARGS=$(getopt -o "" --long input_folder:,output_folder:,start_date:,help -n "$0" -- "$@")

if [ $? -ne 0 ]; then
    usage
fi

eval set -- "$ARGS"

while true; do
    case "$1" in
        --input_folder)
            INPUT_FOLDER="$2"
            shift 2
            ;;
        --output_folder)
            OUTPUT_FOLDER="$2"
            shift 2
            ;;
        --start_date)
            START_DATE="$2"
            shift 2
            ;;
        --help)
            usage
            ;;
        --)
            shift
            break
            ;;
        *)
            echo "Internal error!"
            exit 1
            ;;
    esac
done

# Validate required arguments
if [ -z "$INPUT_FOLDER" ] || [ -z "$OUTPUT_FOLDER" ] || [ -z "$START_DATE" ]; then
    echo "Error: Missing required arguments"
    usage
fi

# Create output folder if it doesn't exist
mkdir -p "$OUTPUT_FOLDER"

# Step 1: Generate included concept list
CONCEPT_LIST_CMD="python -u -m cehrbert_data.apps.generate_included_concept_list \
    -i \"$INPUT_FOLDER\" \
    -o \"$OUTPUT_FOLDER\" \
    --min_num_of_patients 100 \
    --ehr_table_list ${DOMAIN_TABLES[@]}"

echo "Running concept list generation:"
echo "$CONCEPT_LIST_CMD"
eval "$CONCEPT_LIST_CMD"

# Step 2: Generate training data
TRAINING_DATA_CMD="python -m cehrbert_data.apps.generate_training_data \
    --input_folder \"$INPUT_FOLDER\" \
    --output_folder \"$OUTPUT_FOLDER\" \
    -d $START_DATE \
    --att_type day \
    --inpatient_att_type day \
    -iv \
    -ip \
    --include_concept_list \
    --include_death \
    --gpt_patient_sequence \
    --domain_table_list ${DOMAIN_TABLES[@]}"

echo "Running training data generation:"
echo "$TRAINING_DATA_CMD"
eval "$TRAINING_DATA_CMD"