
Image generated by DALL·E 3.
Vision-Language Models (VLMs) can support clinicians by analyzing medical images and engaging in natural language interactions to assist in diagnostic and treatment tasks. However, VLMs often exhibit "hallucinatory" behavior, generating textual outputs that are not grounded in the contextual multimodal information. This challenge is particularly pronounced in the medical domain, where VLM outputs must not only be accurate in single interactions but also remain consistent with clinical reasoning and diagnostic pathways throughout multi-turn conversations. To this end, we propose a new alignment algorithm that uses symbolic representations of clinical reasoning to ground VLMs in medical knowledge. These representations are used to (i) generate GPT-4-guided visual instruction tuning data at scale, simulating clinician-VLM conversations with demonstrations of clinical reasoning, and (ii) create an automatic reward function that evaluates the clinical validity of VLM generations throughout clinician-VLM interactions. Our algorithm eliminates the need for human involvement in training data generation or reward model construction, reducing costs compared to standard reinforcement learning with human feedback (RLHF). We apply this alignment algorithm to develop Dr-LLaVA, a conversational VLM finetuned for analyzing bone marrow pathology slides, which demonstrates strong performance in multi-turn medical conversations.
Dr-LLaVA was trained on 4 A100 GPUs with 80GB memory. To train on fewer GPUs, reduce per_device_train_batch_size and increase gradient_accumulation_steps accordingly so that the global batch size stays the same: per_device_train_batch_size x gradient_accumulation_steps x num_gpus.
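For example, the following quick arithmetic sketch (illustrative numbers only, not the repository's defaults) shows how the three factors trade off when the GPU count changes:

```python
# Illustrative numbers only (not the repository's defaults): keep the global
# batch size fixed when moving from 4 GPUs to 2 GPUs.
per_device_train_batch_size = 16
gradient_accumulation_steps = 2
num_gpus = 4
global_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_gpus  # 16 * 2 * 4 = 128

# On 2 GPUs, keep per_device_train_batch_size at 16 but double
# gradient_accumulation_steps to 4 so the product stays at 128.
assert 16 * 4 * 2 == global_batch_size
```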
Dr-LLaVA has four steps:
- Curate the Dataset and Initialize the Policy Model with Supervised Fine-tuning
- Construct Symbolic Representations of Clinical Reasoning, and use them to generate GPT-4-guided visual instruction tuning data simulating clinician-VLM conversations with demonstrations of clinical reasoning
- Create an Automatic Reward Function that evaluates the clinical validity of VLM outputs during clinician-VLM interactions
- Train the RL Model with PPO
Refer to llava_setup
for instructions on setting up the customized LLaVA package.
Additionally, run the following commands to ensure the versions of the essential packages are correct:
pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
pip install deepspeed==0.9.3
pip install peft==0.4.0
pip install transformers==4.31.0
pip install bitsandbytes==0.41.0
pip install datasets
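After installation, a quick way to confirm that the pinned versions were picked up (a minimal, generic check, nothing repo-specific):

```python
# Print the installed versions of the pinned packages and confirm CUDA is visible.
from importlib.metadata import version
import torch

for pkg in ["torch", "torchvision", "torchaudio", "deepspeed", "peft", "transformers", "bitsandbytes"]:
    print(pkg, version(pkg))
print("CUDA available:", torch.cuda.is_available())
```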
Note: Install PyTorch 2.0.1 following the guidelines here. The flash-attention implementation in the latest PyTorch stable release (2.1.0) may lead to buggy results. The codebase is tested with torch==2.0.1+cu118.
The training data consists of multi-turn conversations with clinical grounding: a medical image with a known diagnosis also carries the morphological features that clinicians identify before confirming that diagnosis.
For example:
- Medical image ->
- Question about image description ->
- Question about image quality evaluation ->
- Question about morphological feature 1 ->
- Inference on feature 1 ->
- ... ->
- Question about morphological feature n ->
- Inference on feature n ->
- Diagnosis.
Starting with labeled medical images, we use symbolic representations of clinical reasoning and GPT models to generate realistic conversations between a VLM and a clinician about the visual content of each image. These multi-turn conversations reflect various styles of clinician-VLM interactions, demonstrating accurate clinical reasoning.
Refer to LLaVA's instruction tuning data here to prepare the data in the correct format.
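For orientation, a single multi-turn record in LLaVA's instruction-tuning JSON format looks roughly like the sketch below; the id, image path, and question/answer wording are hypothetical placeholders, so check LLaVA's documentation for the exact schema.

```python
import json

# A hypothetical multi-turn record in LLaVA-style instruction-tuning format.
record = {
    "id": "heme_0001",                 # placeholder sample id
    "image": "images/heme_0001.png",   # placeholder image path
    "conversations": [
        {"from": "human", "value": "<image>\nCan you describe this bone marrow image?"},
        {"from": "gpt", "value": "The image shows a bone marrow aspirate smear with ..."},
        {"from": "human", "value": "Is the image quality adequate for diagnosis?"},
        {"from": "gpt", "value": "Yes, the smear is adequate for evaluation."},
        {"from": "human", "value": "What is the most likely diagnosis?"},
        {"from": "gpt", "value": "The findings are most consistent with ..."},
    ],
}

with open("train_data.json", "w") as f:
    json.dump([record], f, indent=2)
```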
Note: The RL component of the model will not function unless supervised fine-tuning is performed first. The method is designed to address medical knowledge hallucination, which frequently arises with supervised fine-tuning alone, when the model attempts to produce diagnoses but remains inconsistent across a conversation. The RL component will not operate in a zero-shot manner, where a model that has not been fine-tuned does not attempt to perform medical diagnosis at all.
After curating the dataset and storing the training and test data in the LLaVA JSON format, download the 7b SFT model checkpoint from LLaVA-RLHF-7b-v1.5-224. For supervised fine-tuning, run the following script to initialize the policy model:
cd RLHF/
bash scripts/7b-v1.5-224/initialize_policy_model.sh
Given a question about a medical image, there is a limited set of valid responses, including an option indicating insufficient information. These responses can be encoded as categorical values and projected onto a logic tree, similar to the example below for blood malignancies.
Refer to our manuscript for details on constructing the logic for hematology image diagnosis. For more examples, see the Example_Clinical_Logics.md file.
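As a rough illustration of what such a symbolic representation can look like, here is a deliberately simplified, hypothetical logic tree (not the rules used in the paper; see the manuscript and Example_Clinical_Logics.md for the actual clinical logic):

```python
# A deliberately simplified, hypothetical clinical logic tree for illustration.
# Each question has a small set of valid categorical answers, and a diagnosis
# is only valid if it is consistent with the answers that precede it.
CLINICAL_LOGIC = {
    "image_quality": ["adequate", "inadequate"],
    "blast_proportion": ["normal", "elevated", "cannot assess"],
    "diagnosis": {
        ("adequate", "normal"): "no evidence of acute leukemia",
        ("adequate", "elevated"): "suspicious for acute leukemia",
        ("inadequate", "cannot assess"): "insufficient information",
    },
}

def consistent(image_quality: str, blast_proportion: str, diagnosis: str) -> bool:
    """Check whether a final diagnosis is reachable from the preceding answers."""
    return CLINICAL_LOGIC["diagnosis"].get((image_quality, blast_proportion)) == diagnosis
```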
Afterward, augment the question-answer pairs using available large language models to increase their diversity.
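One simple way to do this is to ask an LLM to paraphrase each question while keeping its clinical meaning fixed. The sketch below assumes a hypothetical query_llm helper wrapping whichever LLM API you use:

```python
from typing import List

def query_llm(prompt: str) -> str:
    """Hypothetical wrapper around your LLM API of choice (e.g., GPT-4)."""
    raise NotImplementedError

def augment_question(question: str, n_variants: int = 3) -> List[str]:
    # Ask the LLM for paraphrases that preserve the clinical meaning exactly.
    prompt = (
        f"Rewrite the following question to a pathology assistant in {n_variants} "
        f"different ways, one per line, preserving its clinical meaning exactly:\n{question}"
    )
    return query_llm(prompt).splitlines()[:n_variants]
```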
To build your reward model, refer to the class RewardModel_HEME
in the file RLHF/models/reward_model.py
at line 444. The RewardModel_Custom
class provides a template for creating a reward model based on clinical logic. Customize the logic and rules according to your requirements.
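To make the idea concrete, here is a minimal, hypothetical sketch of a rule-based reward in the spirit of RewardModel_Custom; it reuses the toy CLINICAL_LOGIC above and is not the RewardModel_HEME implementation:

```python
from typing import List, Optional

def parse_answer(answer: str, valid_options: List[str]) -> Optional[str]:
    """Toy matcher: map a free-text answer onto one of the valid categorical options."""
    # Check longer options first so "inadequate" is not mistaken for "adequate".
    for option in sorted(valid_options, key=len, reverse=True):
        if option in answer.lower():
            return option
    return None

def conversation_reward(answers: List[str]) -> float:
    """Score a (quality, blasts, diagnosis) answer triplet against the toy logic tree."""
    quality = parse_answer(answers[0], CLINICAL_LOGIC["image_quality"])
    blasts = parse_answer(answers[1], CLINICAL_LOGIC["blast_proportion"])
    diagnosis = parse_answer(answers[2], list(CLINICAL_LOGIC["diagnosis"].values()))

    # Partial credit for each turn that lands on a valid categorical value.
    reward = 0.25 * sum(x is not None for x in (quality, blasts, diagnosis))
    # Bonus when the final diagnosis is consistent with the preceding answers.
    if None not in (quality, blasts, diagnosis) and consistent(quality, blasts, diagnosis):
        reward += 1.0
    return reward
```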
Run:
cd RLHF/
bash scripts/7b-v1.5-224/train_rl_model.sh
If you find this repo useful for your research, please consider citing our papers:
Dr-LLaVA:
@article{sun2024dr,
title={Dr-LLaVA: Visual Instruction Tuning with Symbolic Clinical Grounding},
author={Sun, Shenghuan and Goldgof, Gregory M and Schubert, Alexander and Sun, Zhiqing and Hartvigsen, Thomas and Butte, Atul J and Alaa, Ahmed},
journal={arXiv preprint arXiv:2405.19567},
year={2024}
}
LLaVA-RLHF:
@article{sun2023aligning,
title={Aligning large multimodal models with factually augmented rlhf},
author={Sun, Zhiqing and Shen, Sheng and Cao, Shengcao and Liu, Haotian and Li, Chunyuan and Shen, Yikang and Gan, Chuang and Gui, Liang-Yan and Wang, Yu-Xiong and Yang, Yiming and others},
journal={arXiv preprint arXiv:2309.14525},
year={2023}
}
LLaVA:
@article{liu2023llava,
title={Visual Instruction Tuning},
author={Liu, Haotian and Li, Chunyuan and Wu, Qingyang and Lee, Yong Jae},
journal={arXiv preprint arXiv:2304.08485},
year={2023}
}
SALMON:
@article{sun2023salmon,
title={SALMON: Self-Alignment with Principle-Following Reward Models},
author={Sun, Zhiqing and Shen, Yikang and Zhang, Hongxin and Zhou, Qinhong and Chen, Zhenfang and Cox, David and Yang, Yiming and Gan, Chuang},
journal={arXiv preprint arXiv:2310.05910},
year={2023}
}
We thank the Meta LLaMA team, Stanford Alpaca team, Vicuna team, LLaVA team, QLoRA team, Hugging Face PEFT team, and AlpacaFarm team for their open-source efforts in democratizing large language models.