This repository is the official implementation of the paper *LLM-CXR: Instruction-Finetuned LLM for CXR Image Understanding and Generation* (arXiv).
For more generation examples, see the paper on arXiv.
We use conda to manage the environment. Create a new environment from the `environment.yaml` file:

```shell
conda env create --file environment.yaml
```

To activate the conda environment, run:

```shell
conda activate llm-cxr
```

We provide the checkpoints used in the paper to generate and evaluate results. With these checkpoints, you can interactively experience and reproduce LLM-CXR in the Gradio environment without training. See the paper for details on how the uploaded checkpoints were trained.
| Model | Link |
|---|---|
| LLM | link |
| VQ-GAN | link |
Unzip the downloaded `llmcxr_mimic-cxr-256-txvloss-medvqa-stage1_2.tar` file and place the resulting `llmcxr_mimic-cxr-256-txvloss-medvqa-stage1_2` directory in the `ckpts/` directory.
Unzip the downloaded `vqgan_mimic-cxr-256-txvloss.tar` file and place the resulting `vqgan_mimic-cxr-256-txvloss` directory in the `ckpts/` directory.
Run the shell script below to start the interactive demo server of LLM-CXR. We recommend a GPU with at least 11GB of memory, such as an NVIDIA GeForce GTX 1080 Ti 11GB or better.

```shell
python generate_interactive.py <model_path> <vqgan_config_path> <vqgan_ckpt_path>
```

For example:

```shell
python generate_interactive.py \
    ckpts/llmcxr_mimic-cxr-256-txvloss-medvqa-stage1_2 \
    ckpts/vqgan_mimic-cxr-256-txvloss/2023-09-05T13-56-50_mimic-cxr-256-txvloss-project-compat.yaml \
    ckpts/vqgan_mimic-cxr-256-txvloss/2023-09-05T13-56-50_mimic-cxr-256-txvloss-4e-compat.ckpt
```

You can then access the demo at http://localhost:7860/ in your browser.
To reproduce our results, separate environments must be used for LLM training/inference and for image encoding/decoding. Install the llm-cxr and llm-cxr-taming conda environments using the commands below:

```shell
conda env create --file environment.yaml         # llm-cxr environment
conda env create --file environment_taming.yaml  # llm-cxr-taming environment
```

NOTE: appropriate credentials are required to access the MIMIC-CXR dataset family.
- MIMIC-CXR dataset: Download the entire MIMIC-CXR-JPG dataset from the MIMIC-CXR-JPG dataset page. All downloaded files must be located under `data/mimic-cxr-jpg`. Unzip the metadata files in the `.csv.gz` format at the root of the `data/mimic-cxr-jpg` directory. Then download the `mimic-cxr-reports.zip` file from the MIMIC-CXR dataset page, unzip it, and place the `files/` directory into the `data/mimic-cxr-jpg/reports` directory.
- Instruction following dataset: Download the `databricks-dolly-15k.jsonl` file from here and put it in the `data/` directory.
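As a quick sanity check after downloading, the instruction dataset can be loaded line by line. This is a minimal sketch, not part of the repo; the field names (`instruction`, `context`, `response`, `category`) follow the published databricks-dolly-15k schema:

```python
import json
from pathlib import Path


def load_dolly(path="data/databricks-dolly-15k.jsonl"):
    """Load the Dolly instruction dataset: one JSON record per line."""
    records = []
    with Path(path).open(encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:  # skip blank lines
                records.append(json.loads(line))
    return records
```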
Your final directory structure should look like this:

```text
data/
├── mimic-cxr-jpg/
│   ├── files/
│   │   ├── p10/
│   │   ├── ...
│   │   └── p19/
│   ├── reports/
│   │   └── files/
│   │       ├── p10/
│   │       ├── ...
│   │       └── p19/
│   ├── mimic-cxr-jpg_medvqa_v1/
│   │   ├── p10/
│   │   ├── ...
│   │   └── p19/
│   ├── mimic-cxr-2.0.0-metadata
│   ├── mimic-cxr-2.0.0-split
│   ├── mimic-cxr-2.0.0-selected-pa-ap-earlist-study.pickle
│   └── mimic-cxr-2.0.0-selected-pa-ap.pickle
├── databricks-dolly-15k.jsonl
├── eval_dicom_ids.pickle
├── mimic_cxr_img_list_train.txt
├── mimic_cxr_img_list_test.txt
└── mimic-cxr-256-txvloss_codebook_indices.pickle  # see below to generate this file
```
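Because missing downloads otherwise surface only partway through training, it can help to verify the layout above up front. A hypothetical helper (not part of the repo) that checks for the top-level entries:

```python
from pathlib import Path

# Expected entries under data/, taken from the directory tree above.
REQUIRED_ENTRIES = [
    "mimic-cxr-jpg/files",
    "mimic-cxr-jpg/reports/files",
    "databricks-dolly-15k.jsonl",
    "eval_dicom_ids.pickle",
    "mimic_cxr_img_list_train.txt",
    "mimic_cxr_img_list_test.txt",
]


def missing_entries(data_root="data"):
    """Return the expected files/directories that are not present."""
    root = Path(data_root)
    return [entry for entry in REQUIRED_ENTRIES if not (root / entry).exists()]
```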
Unfortunately, the dependencies of the VQ-GAN library are old and incompatible with the LLM-CXR environment, so the VQ-GAN training code is not available yet. Instead, use the VQ-GAN checkpoints that we trained in advance.
To encode all MIMIC-CXR images with the VQ-GAN, run the shell script below. This creates a `mimic-cxr-256-txvloss_codebook_indices.pickle` file in the `data/` directory, containing the encoded (vector-quantized) representations of all CXR images.
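Once this file has been generated (commands below), it can be inspected with a short sketch. We assume here that the pickle deserializes to a mapping from an image identifier to a sequence of integer codebook indices; verify against what `encode_cxr_all.py` actually writes before relying on this:

```python
import pickle


def summarize_codebook_indices(path):
    """Print basic statistics for a {image_id: [codebook indices]} pickle.

    The mapping layout is an assumption, not a documented interface.
    """
    with open(path, "rb") as f:
        vq = pickle.load(f)
    lengths = [len(indices) for indices in vq.values()]
    print(f"{len(vq)} images; token sequence length "
          f"min={min(lengths)}, max={max(lengths)}")
    return vq
```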
```shell
conda activate llm-cxr-taming
python encode_cxr_all.py <vqgan_config_path> <vqgan_ckpt_path> <path_result> <paths_data_list1> <paths_data_list2> ...
```

For example:

```shell
conda activate llm-cxr-taming
python encode_cxr_all.py \
    ckpts/vqgan_mimic-cxr-256-txvloss/2023-09-05T13-56-50_mimic-cxr-256-txvloss-project-compat.yaml \
    ckpts/vqgan_mimic-cxr-256-txvloss/2023-09-05T13-56-50_mimic-cxr-256-txvloss-4e-compat.ckpt \
    data/mimic-cxr-256-txvloss_codebook_indices.pickle \
    data/mimic_cxr_img_list_train.txt \
    data/mimic_cxr_img_list_test.txt
```

Run the shell script below for stage 1 training:
```shell
conda activate llm-cxr
./train_llmcxr_stage1.sh
```

Run the shell script below for stage 2 training:
Before running, set the environment variable `input_model` in the `train_llmcxr_stage2.sh` file to the checkpoint path of the model trained in stage 1, then:

```shell
conda activate llm-cxr
./train_llmcxr_stage2.sh
```
The saved LLM checkpoint is a DeepSpeed ZeRO checkpoint and must therefore be converted to a `pytorch_model.bin` file before inference or continued training. Convert it using the `zero_to_fp32.py` script that is created alongside the checkpoint:

```shell
conda activate llm-cxr
python zero_to_fp32.py . pytorch_model.bin
```

The current settings are geared towards 8x NVIDIA A100 40GB, but if you change the DeepSpeed settings, you can also train on smaller GPUs such as 2x NVIDIA GeForce RTX 3090 24GB. Please refer to the original code base repository and change the DeepSpeed settings (`config/ds_z3_bf16_config.json`) accordingly. Also adjust the `--num_gpus` argument in the `train_llmcxr_stage*.sh` files to match the number of GPUs.
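As one illustration of such a change, GPU memory use can be reduced by enabling CPU offload in the ZeRO-3 config. The key names below follow DeepSpeed's documented configuration schema, but this is only a sketch; verify the options against your DeepSpeed version and the repo's actual config file:

```python
import json


def enable_cpu_offload(config_path, micro_batch_size=1):
    """Rewrite a DeepSpeed ZeRO-3 config to offload optimizer state and
    parameters to CPU, trading training speed for GPU memory."""
    with open(config_path) as f:
        config = json.load(f)
    zero = config.setdefault("zero_optimization", {})
    zero["offload_optimizer"] = {"device": "cpu", "pin_memory": True}
    zero["offload_param"] = {"device": "cpu", "pin_memory": True}
    config["train_micro_batch_size_per_gpu"] = micro_batch_size
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
    return config
```

For example, `enable_cpu_offload("config/ds_z3_bf16_config.json")` would modify the config referenced above in place.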
To generate inference results for evaluation, run the shell script below. This creates an `eval_inference/` directory in the repository root containing the inference results `dolly__eval_results_0_1.pickle`, which holds the report-to-CXR and CXR-to-report outputs for our evaluation dataset `data/eval_dicom_ids.pickle`.

```shell
conda activate llm-cxr
python generate_eval.py <model_path> <cxr_vq_path> <output_root>
```

For example:

```shell
conda activate llm-cxr
python generate_eval.py ckpts/llmcxr_mimic-cxr-256-txvloss-medvqa-stage1_2 data/mimic-cxr-256-txvloss_codebook_indices.pickle eval_inference
```

To decode the inference results, run the shell script below. This creates an `eval_inference_decoded/` directory in the repository root. The `generated_imgs_jpg/` directory contains images generated from reports, and the `generated_reports.txt` file contains reports generated from images; ground-truth and generated reports are interleaved in order.
```shell
conda activate llm-cxr-taming
python decode_cxr_all.py <vqgan_config_path> <vqgan_ckpt_path> <output_root> <infer_result_path>
```

For example:

```shell
conda activate llm-cxr-taming
python decode_cxr_all.py \
    ckpts/vqgan_mimic-cxr-256-txvloss/2023-09-05T13-56-50_mimic-cxr-256-txvloss-project-compat.yaml \
    ckpts/vqgan_mimic-cxr-256-txvloss/2023-09-05T13-56-50_mimic-cxr-256-txvloss-4e-compat.ckpt \
    eval_inference_decoded \
    eval_inference/llm-cxr__eval_results_0_1.pickle
```

We thank the following authors for their great work:
- We were heavily inspired by UniXGen for how we encode images to generate them bidirectionally with transformers.
- Our training pipeline was adapted from Databricks' Dolly.
- We also thank Taming Transformers for providing the VQ-GAN architecture used as our image encoder and decoder.




