WGLab/CoT-RAG-LLM-Gene-Prioritization-Disease-Diagnosis

CoT–RAG LLM for Disease Diagnosis & Gene Prioritization

This repository contains the official implementation of a CoT–RAG framework that supports both disease diagnosis and gene prioritization across five inference methods. It includes end-to-end inference scripts, retrieval components (FAISS + biomedical embeddings + reranking), and an automated evaluator.

Paper (arXiv): https://arxiv.org/abs/2503.12286


Datasets Description

All publicly available datasets used for evaluation are in the dataset folder, including the synthetic clinical notes we generated with GPT-4 from Phenopacket-Store data and the selected cohort of 5,980 clinical notes used for testing. We also include pubmed_free_text for this paper: 255 literature-derived clinical notes originally compiled in LLM-Gene-Prioritization.

We synthesized the Phenopacket-derived clinical notes with ChatGPT using a specific prompting strategy. If you use these notes in your studies, please cite both the Phenopacket paper and our paper.

Methods

All five methods can be run for disease diagnosis (Top-10 diseases) or gene prioritization (Top-10 genes) by switching the task prompt / extraction function where applicable.

1) Base Model

Script: main_script/RareDxGPT_inference_vllm.py
A vLLM-based script for running base models.

2) CoT (Chain-of-Thought)

Script: main_script/RareDxGPT_inference_CoT.py
Single-pass CoT prompting with a strict Top-10 output format (e.g., POTENTIAL_DISEASES / POTENTIAL_GENES). Uses vLLM for generation.
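As an illustration of what a strict Top-10 output format makes possible, here is a minimal sketch of an extraction step. The header names come from this README; the exact list layout (numbered or bulleted lines after the header) and the function name `extract_top10` are assumptions, not the repository's actual extraction code.

```python
import re

# Section headers named in this README; which task maps to which header is assumed.
HEADERS = {"disease": "POTENTIAL_DISEASES", "gene": "POTENTIAL_GENES"}


def extract_top10(generation: str, task: str = "disease") -> list[str]:
    """Return up to 10 ranked items listed after the task's header."""
    header = HEADERS[task]
    # Use the last occurrence of the header, since the model may mention it
    # earlier while reasoning.
    _, found, tail = generation.rpartition(header)
    if not found:
        return []
    items = []
    for line in tail.splitlines():
        # Accept "1. X", "2) X", "- X" or "* X" entries (hypothetical layouts).
        match = re.match(r"\s*(?:\d+[.)]|[-*])\s*(.+?)\s*$", line)
        if match:
            items.append(match.group(1))
    return items[:10]
```

Switching between disease diagnosis and gene prioritization then only changes the `task` argument, which mirrors the "task prompt / extraction function" switch described above.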

3) RAG (Retrieval-Augmented Generation)

Script: main_script/RareDxGPT_inference_RAG.py
Retrieves knowledge from a FAISS index using PubMedBERT embeddings, optionally reranks with ColBERTv2, and generates a grounded Top-10 list.
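The retrieve, optionally rerank, then generate flow can be sketched without any heavy dependencies. In the sketch below, a brute-force cosine search over toy vectors stands in for the FAISS index and PubMedBERT embeddings, and `rerank` is a hypothetical callable standing in for ColBERTv2; the prompt wording and all function names are assumptions for illustration only.

```python
import math
from typing import Callable, Optional


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(query_vec: list[float], corpus: list[tuple[str, list[float]]], k: int = 3) -> list[str]:
    """Brute-force stand-in for a FAISS search: return the k most similar passages."""
    scored = sorted(corpus, key=lambda item: cosine(query_vec, item[1]), reverse=True)
    return [text for text, _ in scored[:k]]


def build_rag_prompt(note: str, passages: list[str], task: str = "disease") -> str:
    """Combine the note and numbered evidence into one prompt (wording is assumed)."""
    target = "diseases" if task == "disease" else "genes"
    evidence = "\n".join(f"[{i}] {p}" for i, p in enumerate(passages, 1))
    return (
        f"Clinical note:\n{note}\n\nRetrieved evidence:\n{evidence}\n\n"
        f"List the Top-10 most likely {target}."
    )


def rag_prompt(
    note: str,
    query_vec: list[float],
    corpus: list[tuple[str, list[float]]],
    k: int = 3,
    rerank: Optional[Callable[[str, list[str]], list[str]]] = None,
    task: str = "disease",
) -> str:
    passages = retrieve(query_vec, corpus, k)
    if rerank is not None:
        # Optional second-stage reordering of the retrieved passages.
        passages = rerank(note, passages)
    return build_rag_prompt(note, passages, task)
```

The returned prompt would then be passed to the generator; in a real run the vectors would come from the embedding model rather than being supplied by hand.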

4) CoT-driven RAG (CoT → Retrieve → Finalize)

Script: main_script/RareDxGPT_inference_CoT_driven_RAG.py
Runs multi-step reasoning first, uses intermediate reasoning as the retrieval query, then finalizes the ranked Top-10 list with retrieved evidence.

5) RAG-driven CoT (Retrieve → CoT reasoning)

Script: main_script/RareDxGPT_inference_RAG_driven_CoT.py
Retrieves first, injects retrieved evidence into a CoT-style reasoning prompt, then outputs the Top-10 list.
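Methods 4 and 5 differ mainly in the order of the reasoning and retrieval steps. The sketch below contrasts the two orderings using two injected callables, `generate` (prompt to model text) and `retrieve` (query to passages). Both callables, the prompt wording, and the choice of the raw note as the retrieval query in the second function are assumptions, not the scripts' actual implementation.

```python
from typing import Callable

Generate = Callable[[str], str]        # prompt -> model text
Retrieve = Callable[[str], list[str]]  # query -> retrieved passages


def cot_driven_rag(note: str, generate: Generate, retrieve: Retrieve) -> str:
    """CoT -> Retrieve -> Finalize: the intermediate reasoning is the query."""
    reasoning = generate(f"Reason step by step about this note:\n{note}")
    evidence = "\n".join(retrieve(reasoning))
    return generate(
        f"Note:\n{note}\n\nReasoning:\n{reasoning}\n\nEvidence:\n{evidence}\n\n"
        "Give the final ranked Top-10 list."
    )


def rag_driven_cot(note: str, generate: Generate, retrieve: Retrieve) -> str:
    """Retrieve -> CoT: evidence is fetched first, then injected into a CoT prompt."""
    evidence = "\n".join(retrieve(note))  # assumed: the note itself is the query
    return generate(
        f"Note:\n{note}\n\nEvidence:\n{evidence}\n\n"
        "Reason step by step, then give the ranked Top-10 list."
    )
```

Under these assumptions, CoT-driven RAG costs two generation calls per note while RAG-driven CoT costs one.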


Repository layout

.
├── AutoEvaluator/                  # evaluation utilities
├── dataset/                        # (optional) local datasets or dataset scripts
├── main_script/                    # inference entrypoints (five methods)
├── utils/                          # helper utilities (seed, extraction, etc.)
├── requirements.txt
├── LICENSE
└── README.md

Installation

git clone https://github.com/WGLab/CoT-RAG-LLM-Gene-Prioritization-Disease-Diagnosis.git
cd CoT-RAG-LLM-Gene-Prioritization-Disease-Diagnosis

conda create -n cotrag python=3.10 -y
conda activate cotrag

pip install -U pip
pip install -r requirements.txt

Note: vllm and FAISS may require CUDA / platform-specific installation. If faiss-gpu is not available for your platform, use faiss-cpu.
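To see which FAISS build ended up in the environment, a quick check along these lines may help. It relies on the common observation that GPU builds expose `StandardGpuResources` while CPU-only builds do not; treat that as an assumption and verify it against your installed version. The helper name `faiss_backend` is hypothetical.

```python
import importlib


def faiss_backend() -> str:
    """Return "gpu", "cpu", or "missing" depending on the installed FAISS build."""
    try:
        faiss = importlib.import_module("faiss")
    except ImportError:
        return "missing"
    # Assumed heuristic: only GPU builds ship the StandardGpuResources class.
    return "gpu" if hasattr(faiss, "StandardGpuResources") else "cpu"


print(faiss_backend())
```

A "gpu" result only says that the GPU build is installed, not that a usable GPU is visible at runtime.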


Data & paths

The original scripts contained hard-coded absolute paths (e.g., /home/...).
This repo version expects relative paths by default and exposes common paths as CLI arguments.

Typical inputs:

  • Clinical note dataset: a HuggingFace dataset directory loaded via datasets.load_from_disk(...)
  • Reference file: reference_data/disease_name_full.csv (or your own list)
  • FAISS index: a local directory created previously for retrieval

Usage

Each script now supports a consistent set of path arguments. Examples below assume:

  • dataset directory: datasets/CoTRAG_clinical_notes
  • FAISS index directory: datasets/rag_embedding
  • reference CSV: reference_data/disease_name_full.csv
  • outputs: outputs/

1) Base Model

python main_script/RareDxGPT_inference_vllm.py \
  --data_dir datasets/CoTRAG_clinical_notes \
  --reference_csv reference_data/disease_name_full.csv \
  --output_dir outputs/vllm \
  --task disease

2) CoT

python main_script/RareDxGPT_inference_CoT.py \
  --data_dir datasets/CoTRAG_clinical_notes \
  --reference_csv reference_data/disease_name_full.csv \
  --output_dir outputs/cot \
  --task disease

3) RAG

python main_script/RareDxGPT_inference_RAG.py \
  --data_dir datasets/CoTRAG_clinical_notes \
  --index_dir datasets/rag_embedding \
  --reference_csv reference_data/disease_name_full.csv \
  --output_dir outputs/rag \
  --task gene

4) CoT-driven RAG

python main_script/RareDxGPT_inference_CoT_driven_RAG.py \
  --data_dir datasets/CoTRAG_clinical_notes \
  --index_dir datasets/rag_embedding \
  --reference_csv reference_data/disease_name_full.csv \
  --output_dir outputs/cot_driven_rag \
  --task gene

5) RAG-driven CoT

python main_script/RareDxGPT_inference_RAG_driven_CoT.py \
  --data_dir datasets/CoTRAG_clinical_notes \
  --index_dir datasets/rag_embedding \
  --reference_csv reference_data/disease_name_full.csv \
  --output_dir outputs/rag_driven_cot \
  --task disease
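For readers adapting the scripts, the shared flags used in the commands above could be declared roughly as follows. This is a sketch built only from the flag names shown in this README; the defaults, help strings, and which flags are required are assumptions and may differ from the actual scripts.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Declare the path and task flags that appear in the usage examples."""
    parser = argparse.ArgumentParser(description="CoT-RAG inference (sketch)")
    parser.add_argument("--data_dir", required=True,
                        help="HuggingFace dataset directory saved to disk")
    parser.add_argument("--index_dir", default=None,
                        help="FAISS index directory (retrieval methods only)")
    parser.add_argument("--reference_csv",
                        default="reference_data/disease_name_full.csv",
                        help="reference name list")
    parser.add_argument("--output_dir", default="outputs",
                        help="where generations and extracted lists are written")
    parser.add_argument("--task", choices=["disease", "gene"], default="disease",
                        help="disease diagnosis or gene prioritization")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(vars(args))
```

Restricting `--task` with `choices` makes a mistyped task fail at startup rather than after model loading.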

Outputs

All scripts write:

  • a raw generation CSV in --output_dir
  • extracted Top-10 lists (and evaluator summaries if enabled)
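Once the extracted Top-10 lists are available, a simple way to summarize them is a top-k hit rate. The sketch below assumes one ground-truth label per note and case-insensitive exact matching; the repository's AutoEvaluator may match names differently (for example, by handling synonyms), so this is only an illustrative baseline with a hypothetical function name.

```python
def top_k_accuracy(predictions: list[list[str]], truths: list[str], k: int = 10) -> float:
    """Fraction of notes whose true label appears in the first k predictions."""
    if len(predictions) != len(truths):
        raise ValueError("predictions and truths must have the same length")
    if not truths:
        return 0.0
    hits = 0
    for ranked, truth in zip(predictions, truths):
        # Case-insensitive exact match; a simplifying assumption.
        candidates = {p.strip().lower() for p in ranked[:k]}
        if truth.strip().lower() in candidates:
            hits += 1
    return hits / len(truths)
```

Calling it with k=1, k=5 and k=10 on the same lists gives a quick view of how often the correct answer is ranked near the top.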

Citation

If you use this codebase, please cite:

@misc{wang2026integratingchainofthoughtretrievalaugmented,
      title={Integrating Chain-of-Thought and Retrieval Augmented Generation Enhances Rare Disease Diagnosis from Clinical Notes}, 
      author={Zhanliang Wang and Da Wu and Quan Nguyen and Kai Wang},
      year={2026},
      eprint={2503.12286},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2503.12286}, 
}

License

MIT License. See LICENSE.
