This repository contains the official implementation of a CoT–RAG framework that supports both disease diagnosis and gene prioritization across five inference methods. It includes end-to-end inference scripts, retrieval components (FAISS + biomedical embeddings + reranking), and an automated evaluator.
Paper (arXiv): https://arxiv.org/abs/2503.12286
All publicly available datasets used for evaluation are in the dataset folder, including the synthetic clinical notes that we generated with GPT-4 from Phenopacket-Store data and the selected cohort of 5,980 clinical notes used for testing. We also include pubmed_free_text, a set of 255 literature-derived clinical notes that was originally compiled in LLM-Gene-Prioritization for this paper.
We synthesized the Phenopacket-derived clinical notes with ChatGPT using a specific prompting strategy. If you use these notes in your studies, please cite both the Phenopacket paper and our paper.
All five methods can be run for disease diagnosis (Top-10 diseases) or gene prioritization (Top-10 genes) by switching the task prompt / extraction function where applicable.
Script: main_script/RareDxGPT_inference_vllm.py
A vLLM-based script for running base models.
Script: main_script/RareDxGPT_inference_CoT.py
Single-pass CoT prompting with a strict Top-10 output format (e.g., POTENTIAL_DISEASES / POTENTIAL_GENES). Uses vLLM for generation.
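A strict output format is what makes the final answer machine-readable. The following is a minimal sketch of how such a list could be pulled out of a generation, assuming the model prints the header (POTENTIAL_DISEASES or POTENTIAL_GENES) followed by one numbered item per line. The function name extract_top10 and the exact line layout are assumptions for illustration; the extraction helpers in utils/ may work differently.

```python
import re


def extract_top10(generation: str, header: str = "POTENTIAL_DISEASES") -> list[str]:
    """Return up to ten ranked names listed after `header` (hypothetical helper)."""
    # Keep only the text that follows the header; no header means no list.
    _, found, tail = generation.partition(header)
    if not found:
        return []
    # Drop the colon and spacing that usually trail the header.
    tail = tail.lstrip(": \t")
    names: list[str] = []
    for line in tail.splitlines():
        # Accept numbered items such as "1. Marfan syndrome" or "2) FBN1".
        match = re.match(r"\s*(\d+)[.)]\s*(.+?)\s*$", line)
        if match:
            names.append(match.group(2))
        elif names and line.strip():
            # The first non-empty, non-numbered line after the list ends it.
            break
    return names[:10]
```

Passing header="POTENTIAL_GENES" applies the same parsing to the gene prioritization task.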
Script: main_script/RareDxGPT_inference_RAG.py
Retrieves knowledge from a FAISS index using PubMedBERT embeddings, optionally reranks with ColBERTv2, and generates a grounded Top-10 list.
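The retrieve-then-rerank step can be sketched as below. This is a dependency-free illustration, not the script's code: plain Python lists stand in for PubMedBERT embeddings, an exhaustive cosine-similarity scan stands in for the FAISS index lookup, and the optional rerank callable stands in for ColBERTv2. All names here (retrieve, rerank, query_text) are hypothetical.

```python
import math
from typing import Callable, Optional, Sequence

Vector = Sequence[float]


def normalize(vec: Vector) -> list[float]:
    """Scale a vector to unit length so a dot product equals cosine similarity."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm else list(vec)


def retrieve(
    query_vec: Vector,
    doc_vecs: Sequence[Vector],
    docs: Sequence[str],
    k: int = 3,
    rerank: Optional[Callable[[str, str], float]] = None,
    query_text: str = "",
) -> list[str]:
    """Return the k most similar documents, optionally reordered by `rerank`."""
    q = normalize(query_vec)
    scored = []
    for doc, vec in zip(docs, doc_vecs):
        d = normalize(vec)
        scored.append((sum(a * b for a, b in zip(q, d)), doc))
    # Exhaustive search: sort every document by similarity, highest first.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    candidates = [doc for _, doc in scored[:k]]
    if rerank is not None:
        # Second stage: reorder only the shortlisted candidates.
        candidates = sorted(
            candidates, key=lambda doc: rerank(query_text, doc), reverse=True
        )
    return candidates
```

Reranking only the shortlist keeps the more expensive second-stage scorer off the full corpus, which is the usual reason for splitting retrieval into two stages.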
Script: main_script/RareDxGPT_inference_CoT_driven_RAG.py
Runs multi-step reasoning first, uses intermediate reasoning as the retrieval query, then finalizes the ranked Top-10 list with retrieved evidence.
Script: main_script/RareDxGPT_inference_RAG_driven_CoT.py
Retrieves first, injects retrieved evidence into a CoT-style reasoning prompt, then outputs the Top-10 list.
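The two combined methods differ only in the order of reasoning and retrieval. The sketch below makes that ordering explicit with two hypothetical callables, generate (prompt to text) and retrieve (query to a list of passages). The prompt wording is illustrative, and using the clinical note itself as the query in the retrieve-first variant is an assumption; the actual scripts may build prompts and queries differently.

```python
from typing import Callable

Generate = Callable[[str], str]
Retrieve = Callable[[str], list[str]]


def cot_driven_rag(note: str, generate: Generate, retrieve: Retrieve) -> str:
    """Reason first, then retrieve with the reasoning as the query."""
    reasoning = generate(f"Reason step by step about this clinical note:\n{note}")
    evidence = retrieve(reasoning)
    prompt = (
        f"Clinical note:\n{note}\n\n"
        f"Reasoning:\n{reasoning}\n\n"
        "Evidence:\n" + "\n".join(evidence) + "\n\n"
        "List the Top-10 candidates."
    )
    return generate(prompt)


def rag_driven_cot(note: str, generate: Generate, retrieve: Retrieve) -> str:
    """Retrieve first, then reason over the note together with the evidence."""
    evidence = retrieve(note)  # assumption: the note is the retrieval query
    prompt = (
        "Evidence:\n" + "\n".join(evidence) + "\n\n"
        f"Clinical note:\n{note}\n\n"
        "Think step by step, then list the Top-10 candidates."
    )
    return generate(prompt)
```

The first variant costs two generation calls per note and the second costs one, which is worth keeping in mind when comparing runtimes.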
.
├── AutoEvaluator/ # evaluation utilities
├── dataset/ # (optional) local datasets or dataset scripts
├── main_script/ # inference entrypoints (five methods)
├── utils/ # helper utilities (seed, extraction, etc.)
├── requirements.txt
├── LICENSE
└── README.md
git clone https://github.com/WGLab/CoT-RAG-LLM-Gene-Prioritization-Disease-Diagnosis.git
cd CoT-RAG-LLM-Gene-Prioritization-Disease-Diagnosis
conda create -n cotrag python=3.10 -y
conda activate cotrag
pip install -U pip
pip install -r requirements.txt
Note: vllm and FAISS may require CUDA / platform-specific installation. If faiss-gpu is not available for your platform, use faiss-cpu.
The original scripts contained hard-coded absolute paths (e.g., /home/...).
This repo version expects relative paths by default and exposes common paths as CLI arguments.
Typical inputs:
- Clinical note dataset: a HuggingFace dataset directory loaded via datasets.load_from_disk(...)
- Reference file: reference_data/disease_name_full.csv (or your own list)
- FAISS index: a local directory created previously for retrieval
Each script now supports a consistent set of path arguments. Examples below assume:
- dataset directory: datasets/CoTRAG_clinical_notes
- FAISS index directory: datasets/rag_embedding
- reference CSV: reference_data/disease_name_full.csv
- outputs: outputs/
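For readers adapting their own scripts, a shared set of path arguments could be declared roughly as follows. This is a sketch only: the flag names and the two task values come from this README, while the defaults, help strings, and the function name build_parser are assumptions and may not match the scripts.

```python
import argparse
from pathlib import Path


def build_parser() -> argparse.ArgumentParser:
    """Declare the path and task flags described in this README (hypothetical defaults)."""
    parser = argparse.ArgumentParser(description="CoT-RAG inference (sketch)")
    parser.add_argument("--data_dir", type=Path, required=True,
                        help="HuggingFace dataset directory saved to disk")
    parser.add_argument("--index_dir", type=Path, default=None,
                        help="FAISS index directory (retrieval methods only)")
    parser.add_argument("--reference_csv", type=Path,
                        default=Path("reference_data/disease_name_full.csv"),
                        help="CSV with reference names")
    parser.add_argument("--output_dir", type=Path, default=Path("outputs"),
                        help="where generations and extracted lists are written")
    parser.add_argument("--task", choices=["disease", "gene"], default="disease",
                        help="disease diagnosis or gene prioritization")
    return parser
```

Parsing into pathlib.Path objects keeps the paths relative to the working directory, in line with the move away from hard-coded absolute paths.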
python main_script/RareDxGPT_inference_vllm.py \
--data_dir datasets/bws \
--reference_csv reference_data/disease_name_full.csv \
--output_dir outputs/vllm \
--task disease

python main_script/RareDxGPT_inference_CoT.py \
--data_dir datasets/CoTRAG_clinical_notes \
--reference_csv reference_data/disease_name_full.csv \
--output_dir outputs/cot \
--task disease

python main_script/RareDxGPT_inference_RAG.py \
--data_dir datasets/CoTRAG_clinical_notes \
--index_dir datasets/rag_embedding \
--reference_csv reference_data/disease_name_full.csv \
--output_dir outputs/rag \
--task gene

python main_script/RareDxGPT_inference_CoT_driven_RAG.py \
--data_dir datasets/CoTRAG_clinical_notes \
--index_dir datasets/rag_embedding \
--reference_csv reference_data/disease_name_full.csv \
--output_dir outputs/cot_driven_rag \
--task gene

python main_script/RareDxGPT_inference_RAG_driven_CoT.py \
--data_dir datasets/CoTRAG_clinical_notes \
--index_dir datasets/rag_embedding \
--reference_csv reference_data/disease_name_full.csv \
--output_dir outputs/rag_driven_cot \
--task disease

All scripts write:
- a raw generation CSV in --output_dir
- extracted Top-10 lists (and evaluator summaries if enabled)
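Once the Top-10 lists are extracted, a quick sanity check is to compute Top-k accuracy against the known answers. The sketch below uses case-insensitive exact string matching, which is a simplifying assumption; the AutoEvaluator in this repository may match names differently, and top_k_accuracy is a hypothetical helper rather than part of the codebase.

```python
def top_k_accuracy(predictions: list[list[str]], truths: list[str], k: int = 10) -> float:
    """Fraction of cases whose true label appears among the first k predictions."""
    if len(predictions) != len(truths):
        raise ValueError("predictions and truths must have the same length")
    if not truths:
        return 0.0
    hits = 0
    for ranked, truth in zip(predictions, truths):
        # Compare case-insensitively and ignore surrounding whitespace.
        top = {name.strip().lower() for name in ranked[:k]}
        hits += truth.strip().lower() in top
    return hits / len(truths)
```

Calling it with k=1, k=5, and k=10 on the same lists shows how much of the performance comes from the top-ranked candidate alone.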
If you use this codebase, please cite:
@misc{wang2026integratingchainofthoughtretrievalaugmented,
title={Integrating Chain-of-Thought and Retrieval Augmented Generation Enhances Rare Disease Diagnosis from Clinical Notes},
author={Zhanliang Wang and Da Wu and Quan Nguyen and Kai Wang},
year={2026},
eprint={2503.12286},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2503.12286},
}
MIT License. See LICENSE.