# GAN-Based MRI Augmentation for Early AD Detection
This repository contains the full pipeline for **synthetic MRI slice generation using GANs** and **CN vs EMCI classification** for early Alzheimer's Disease (AD) detection. The project explores **data augmentation with class-specific Wasserstein GANs with Gradient Penalty (WGAN-GP)** to enhance classifier performance on subtle EMCI structural changes.
---
## Overview
Alzheimer’s Disease (AD) is difficult to detect at early stages because structural changes in the brain are subtle, especially in the **Early Mild Cognitive Impairment (EMCI)** stage. To overcome limited and imbalanced datasets, this project implements a pipeline that:
1. Trains **baseline ResNet18 classifiers** on real MRI slices.
2. Trains **class-specific WGAN-GP models** for CN and EMCI MRI slices.
3. Generates synthetic slices at controlled augmentation levels (10%, 20%, 30% per class).
4. Applies automatic filtering of low-quality synthetic slices.
5. Trains GAN-augmented classifiers under controlled ablation settings.
6. Performs subject-level evaluation and statistical comparison (McNemar test).
All dataset splits are performed at the **subject level**.
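
As a minimal sketch of what a subject-level split can look like, the code below groups slice files by ADNI subject ID before shuffling, so no subject contributes slices to more than one split. The filename layout, the `subject_id` and `subject_level_split` helpers, and the 70/15/15 proportions are hypothetical; `data_pipeline.ipynb` may split differently.

```python
import random
import re
from collections import defaultdict

# ADNI subject IDs have the form "<site>_S_<subject>", e.g. "002_S_0295".
SUBJECT_RE = re.compile(r"\d{3}_S_\d{4}")


def subject_id(filename: str) -> str:
    """Pull the ADNI subject ID out of a (hypothetical) slice filename."""
    match = SUBJECT_RE.search(filename)
    if match is None:
        raise ValueError(f"no ADNI subject ID in {filename!r}")
    return match.group(0)


def subject_level_split(filenames, val_frac=0.15, test_frac=0.15, seed=42):
    """Assign whole subjects (never individual slices) to train/val/test."""
    by_subject = defaultdict(list)
    for name in filenames:
        by_subject[subject_id(name)].append(name)

    subjects = sorted(by_subject)          # deterministic order before shuffling
    random.Random(seed).shuffle(subjects)

    n_test = round(len(subjects) * test_frac)
    n_val = round(len(subjects) * val_frac)
    split_subjects = {
        "test": subjects[:n_test],
        "val": subjects[n_test:n_test + n_val],
        "train": subjects[n_test + n_val:],
    }
    # Expand each subject back into all of its slice files.
    return {split: [f for s in subs for f in by_subject[s]]
            for split, subs in split_subjects.items()}
```

Splitting by subject matters because neighbouring slices from one brain are highly correlated; a slice-level split would let near-duplicate images leak from training into evaluation. To keep CN and EMCI proportions similar across splits, the same function could be run on each class's files separately.
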
---
## Data Preprocessing
- **`clean_dataset.py`**:
Organizes raw ADNI NIfTI files into class directories, resolves `.nii.gz` inconsistencies, and validates dataset integrity.
- **Slice Extraction**:
Extracts 2D slices from 3D volumes and saves them as 256 × 256 PNGs for 2D CNN training.
- **Dataset Class**:
`MRISliceDataset` handles loading, subject ID extraction, and transforms. Supports both real and synthetic slices.
- **Transforms**:
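```python
from torchvision import transforms

# Resize to 256x256, convert to a [0, 1] tensor, then map to [-1, 1],
# matching the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
```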
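
The slice-extraction step can be sketched with NumPy alone. This assumes the volume has already been loaded into an array (for example with `nibabel.load(path).get_fdata()`) and that the third array axis is the axial one, which depends on the image orientation. The helper name, the 20-slice central window and the nearest-neighbour resize are illustrative choices, not the project's documented settings.

```python
import numpy as np


def extract_axial_slices(volume: np.ndarray, n_slices: int = 20, size: int = 256) -> list:
    """Return up to `n_slices` central axial slices as size x size uint8 arrays.

    Assumes axis 2 of `volume` is the axial (inferior-superior) axis.
    """
    depth = volume.shape[2]
    start = max(0, depth // 2 - n_slices // 2)
    stop = min(depth, start + n_slices)
    slices = []
    for z in range(start, stop):
        sl = volume[:, :, z].astype(np.float64)
        lo, hi = sl.min(), sl.max()
        # Min-max scale to [0, 1]; a constant (empty) slice becomes all zeros.
        sl = (sl - lo) / (hi - lo) if hi > lo else np.zeros_like(sl)
        # Nearest-neighbour resize via integer index maps.
        rows = np.arange(size) * sl.shape[0] // size
        cols = np.arange(size) * sl.shape[1] // size
        slices.append(np.round(sl[np.ix_(rows, cols)] * 255).astype(np.uint8))
    return slices
```

Each returned array can be written to disk with Pillow via `Image.fromarray(arr).save(path)`. Bilinear or bicubic interpolation is usually preferable to nearest neighbour for final training data.
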
---
# GAN-Based MRI Augmentation for CN vs EMCI Classification
This project implements a **Wasserstein GAN with Gradient Penalty (WGAN-GP)** to address class imbalance and the subtle structural changes seen in Early Mild Cognitive Impairment (EMCI) brain MRI slices. The generated synthetic slices augment the training set of a **ResNet18 classifier** to improve its robustness and accuracy.
---
## 🧠 Motivation
- **Subtlety**: EMCI is a transitional stage of Alzheimer’s disease, and EMCI slices are often difficult to distinguish from Cognitively Normal (CN) slices.
- **Data Scarcity**: EMCI data is frequently underrepresented in neuroimaging datasets.
- **Solution**: GAN augmentation generates high-fidelity synthetic slices to diversify the training set and improve the classifier's ability to generalize.
---
## 🏗️ Architecture
### Wasserstein GAN with Gradient Penalty (WGAN-GP)
- Separate models are trained for **CN** and **EMCI** classes to ensure class-specific feature generation.
**Generator:**
- **Input**: 128-dimensional latent vector
- **Structure**: Progressive `ConvTranspose2d` layers scaling from `4 × 4` to `256 × 256`
- **Features**: BatchNorm + ReLU activations, final Tanh layer
**Critic (Discriminator):**
- **Input**: 1-channel (grayscale) `256 × 256` slice
- **Structure**: Strided `Conv2d` layers with InstanceNorm + LeakyReLU
- **Output**: Scalar score (no Sigmoid, per WGAN-GP requirements)
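
The generator and critic described above might look like the following PyTorch sketch. The layer counts follow from the stated resolutions (six stride-2 stages between 4 × 4 and 256 × 256), but the channel widths, kernel sizes and the choice to leave the critic's first layer unnormalized are assumptions. The critic avoids BatchNorm because the WGAN-GP gradient penalty is defined per sample, which batch-level statistics would interfere with.

```python
import torch
import torch.nn as nn

LATENT_DIM = 128  # from the training table


class Generator(nn.Module):
    """Maps a 128-d latent vector to a 1 x 256 x 256 slice in [-1, 1]."""

    def __init__(self, latent_dim: int = LATENT_DIM, base: int = 512):
        super().__init__()
        # z (latent_dim x 1 x 1) -> base x 4 x 4
        layers = [nn.ConvTranspose2d(latent_dim, base, 4, 1, 0, bias=False),
                  nn.BatchNorm2d(base), nn.ReLU(inplace=True)]
        ch = base
        # Five doubling stages: 4 -> 8 -> 16 -> 32 -> 64 -> 128
        for _ in range(5):
            layers += [nn.ConvTranspose2d(ch, ch // 2, 4, 2, 1, bias=False),
                       nn.BatchNorm2d(ch // 2), nn.ReLU(inplace=True)]
            ch //= 2
        # Last doubling stage to 256 x 256 with one channel, squashed by Tanh
        layers += [nn.ConvTranspose2d(ch, 1, 4, 2, 1), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z.view(z.size(0), -1, 1, 1))


class Critic(nn.Module):
    """Maps a 1 x 256 x 256 slice to one unbounded realness score."""

    def __init__(self, base: int = 16):
        super().__init__()
        layers, ch = [], 1
        # Six halving stages: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4
        for i in range(6):
            out = base * 2 ** i
            layers.append(nn.Conv2d(ch, out, 4, 2, 1))
            if i > 0:  # leave the input layer unnormalized
                layers.append(nn.InstanceNorm2d(out, affine=True))
            layers.append(nn.LeakyReLU(0.2))
            ch = out
        # 4 x 4 -> 1 x 1 score; no Sigmoid, as WGAN-GP requires
        layers.append(nn.Conv2d(ch, 1, 4, 1, 0))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).view(-1)
```

Following the README, one `Generator`/`Critic` pair would be trained for CN and a separate pair for EMCI.
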
**Classification Model (ResNet18):**
- **Baseline**: Standard ResNet18 adapted for 1-channel grayscale input
- **Augmented**: Same architecture trained on a combined dataset of real and filtered synthetic slices (~24,000 slices)
---
## ⚙️ Training Configuration
| Hyperparameter | Value |
|------------------------|-----------------------------|
| Image Size | 256 × 256 |
| Latent Dimension | 128 |
| Batch Size | 32 |
| Learning Rate | 5e-5 |
| Optimizer | Adam |
| Critic Iterations | 3 per generator step |
| Epochs | 80 |
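
A minimal sketch of one WGAN-GP update with the table's settings (three critic steps per generator step) is shown below. The penalty weight λ = 10 and Adam betas of (0.0, 0.9) come from Gulrajani et al. (2017) and are assumptions here, since this README does not list them. For brevity the sketch reuses one real batch for all critic steps, whereas a full training loop would usually draw a fresh batch each time.

```python
import torch

LAMBDA_GP = 10.0   # assumed penalty weight (Gulrajani et al. default)
N_CRITIC = 3       # critic updates per generator update (table above)


def gradient_penalty(critic, real, fake):
    """Mean of (||grad_x critic(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat, create_graph=True)
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()


def train_step(gen, critic, opt_g, opt_c, real, latent_dim=128):
    """N_CRITIC critic updates followed by one generator update."""
    for _ in range(N_CRITIC):
        z = torch.randn(real.size(0), latent_dim, device=real.device)
        fake = gen(z).detach()  # no generator gradients during critic steps
        loss_c = (critic(fake).mean() - critic(real).mean()
                  + LAMBDA_GP * gradient_penalty(critic, real, fake))
        opt_c.zero_grad()
        loss_c.backward()
        opt_c.step()

    # Generator tries to raise the critic's score on fresh fakes.
    z = torch.randn(real.size(0), latent_dim, device=real.device)
    loss_g = -critic(gen(z)).mean()
    opt_g.zero_grad()
    loss_g.backward()
    opt_g.step()
    return loss_c.item(), loss_g.item()
```

With a generator and critic like those described under Architecture, the optimizers would be created as `torch.optim.Adam(gen.parameters(), lr=5e-5, betas=(0.0, 0.9))` (likewise for the critic), and `train_step` would be called once per real batch for 80 epochs.
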
---
## 🚀 Workflow & Pipeline
### 1. Filtering Synthetic Slices
To ensure quality, synthetic slices are filtered automatically; a slice is retained only if it meets the following criteria:
- **Mean Intensity**: > 0.05
- **Foreground Area**: > 5% of the total image
- **Edge Strength**: Laplacian variance > 15
Filtering retention varies slightly across augmentation regimes (10%, 20%, 30%) but remains stable overall, indicating consistent GAN output quality as synthetic volume increases.
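
The three filtering rules can be expressed as a single predicate. This NumPy sketch assumes 8-bit grayscale PNGs, evaluates the mean-intensity rule on a [0, 1] scale and the Laplacian-variance rule on the raw [0, 255] scale, and treats pixels above 0.1 as foreground; those scale conventions and the 0.1 cutoff are assumptions, not values taken from `filter_synthetic_cn.py` or `filter_synthetic_emci.py`.

```python
import numpy as np

MIN_MEAN_INTENSITY = 0.05    # README: mean intensity > 0.05 (assumed [0, 1] scale)
MIN_FOREGROUND_FRAC = 0.05   # README: foreground area > 5% of the image
MIN_LAPLACIAN_VAR = 15.0     # README: Laplacian variance > 15 (assumed [0, 255] scale)
FOREGROUND_PIXEL = 0.1       # hypothetical: pixels above this count as foreground


def laplacian_var(img: np.ndarray) -> float:
    """Variance of the 4-neighbour Laplacian, kernel [[0,1,0],[1,-4,1],[0,1,0]]."""
    p = np.pad(img.astype(np.float64), 1, mode="reflect")
    lap = (p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]
           - 4.0 * p[1:-1, 1:-1])
    return float(lap.var())


def keep_slice(img_u8: np.ndarray) -> bool:
    """Return True if an 8-bit synthetic slice passes all three quality checks."""
    img = img_u8.astype(np.float64) / 255.0
    return bool(
        img.mean() > MIN_MEAN_INTENSITY
        and (img > FOREGROUND_PIXEL).mean() > MIN_FOREGROUND_FRAC
        and laplacian_var(img_u8) > MIN_LAPLACIAN_VAR
    )
```

In an OpenCV-based script the Laplacian term would typically be `cv2.Laplacian(img, cv2.CV_64F).var()`, which applies the same 3 × 3 kernel at its default `ksize=1`. The edge check rejects blank or blurry outputs that a brightness check alone would accept.
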
### 2. Subject-Level Evaluation
- Predictions are aggregated **per subject** by averaging the slice-level probabilities of all of that subject's slices, so performance is measured per patient rather than per slice.
**Metrics:**
- Accuracy
- Sensitivity / Specificity
- AUC-ROC
- Balanced Accuracy
- Confusion Matrix
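
The subject-level aggregation and the threshold-based metrics can be sketched in plain Python. Treating EMCI as the positive class (label 1) and using a 0.5 decision threshold are assumptions. AUC-ROC needs the continuous subject probabilities and is most easily computed with `sklearn.metrics.roc_auc_score(y_true, y_prob)`.

```python
from collections import defaultdict
from statistics import mean


def aggregate_by_subject(subject_ids, slice_probs, labels):
    """Average EMCI probabilities over each subject's slices."""
    probs, label = defaultdict(list), {}
    for sid, p, y in zip(subject_ids, slice_probs, labels):
        probs[sid].append(p)
        label[sid] = y
    subjects = sorted(probs)
    return subjects, [mean(probs[s]) for s in subjects], [label[s] for s in subjects]


def binary_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy, sensitivity, specificity, balanced accuracy (EMCI = 1)."""
    y_pred = [int(p >= threshold) for p in y_prob]
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in pairs)
    tn = sum(t == 0 and p == 0 for t, p in pairs)
    fp = sum(t == 0 and p == 1 for t, p in pairs)
    fn = sum(t == 1 and p == 0 for t, p in pairs)
    sens = tp / (tp + fn) if tp + fn else float("nan")
    spec = tn / (tn + fp) if tn + fp else float("nan")
    return {
        "accuracy": (tp + tn) / len(y_true),
        "sensitivity": sens,
        "specificity": spec,
        "balanced_accuracy": (sens + spec) / 2,
        # Rows = true class (CN, EMCI), columns = predicted class.
        "confusion_matrix": [[tn, fp], [fn, tp]],
    }
```

Balanced accuracy is the mean of sensitivity and specificity, so it is not inflated when one class dominates the test subjects.
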
Paired statistical comparison between baseline and augmented models was conducted using **McNemar’s test**.
- **Visualization**: Training and validation loss curves, ROC curves, and confusion matrix heatmaps
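
A minimal sketch of the exact (binomial) form of McNemar's test on paired per-subject correctness is shown below. The README does not say whether the exact or the chi-square variant was used, so this is one possible implementation; `statsmodels.stats.contingency_tables.mcnemar(table, exact=True)` should give the same p-value and can serve as a cross-check.

```python
from math import comb


def mcnemar_exact(baseline_correct, augmented_correct):
    """Exact two-sided McNemar test on paired per-subject correctness flags.

    Returns (b, c, p): b = subjects only the baseline got right,
    c = subjects only the augmented model got right.
    """
    pairs = list(zip(baseline_correct, augmented_correct))
    b = sum(1 for base, aug in pairs if base and not aug)
    c = sum(1 for base, aug in pairs if aug and not base)
    n = b + c
    if n == 0:  # the models never disagree
        return b, c, 1.0
    # Under H0 the discordant pairs follow Binomial(n, 0.5); double the smaller tail.
    tail = sum(comb(n, i) for i in range(min(b, c) + 1)) / 2 ** n
    return b, c, min(1.0, 2 * tail)
```

Only discordant subjects (where exactly one model is correct) enter the test, which is why it suits paired comparisons of two classifiers on the same test set.
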
---

*Figure: Example synthetic MRI slices generated by the WGAN-GP for EMCI and CN augmentation.*
## 📈 Augmentation Ablation Study
To evaluate the effect of synthetic data proportion, experiments were conducted using:
- **10% synthetic augmentation**
- **20% synthetic augmentation**
- **30% synthetic augmentation**
Performance was reported as **mean ± standard deviation across runs**.
This controlled ablation allows analysis of:
- Whether performance improvements scale with augmentation level
- Whether gains plateau beyond a certain synthetic ratio
- Whether augmentation remains stable as synthetic volume increases
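
For the ablation, the synthetic budget per class and the reported mean ± standard deviation can be computed as below. The sketch assumes each percentage is taken relative to that class's real training slices and that the reported spread is the sample standard deviation; both readings are assumptions, and the slice counts in the example are illustrative.

```python
from statistics import mean, stdev


def synthetic_budget(n_real_per_class, ratios=(0.10, 0.20, 0.30)):
    """Synthetic slices to add per class at each augmentation level,
    assuming the ratio is relative to that class's real training slices."""
    return {r: {cls: round(n * r) for cls, n in n_real_per_class.items()}
            for r in ratios}


def mean_std(values):
    """Format run-level scores as 'mean ± std' (sample standard deviation)."""
    return f"{mean(values):.3f} ± {stdev(values):.3f}"
```

For example, `synthetic_budget({"CN": 1000, "EMCI": 800})[0.20]` gives `{"CN": 200, "EMCI": 160}`.
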
## 💻 Installation
```bash
# Clone the repository
git clone https://github.com/yourusername/gan-mri-augmentation.git
cd gan-mri-augmentation

# Install dependencies
pip install -r requirements.txt
```
**Key dependencies:** `torch`, `torchvision`, `hd-bet`, `opencv-python`, `streamlit`, `scikit-learn`
---
## 🛠️ Usage
### Jupyter Notebooks
- `data_pipeline.ipynb`: MRI preprocessing and strict subject-level data splitting
- `cn_augmentation.ipynb` / `emci_augmentation.ipynb`: Train class-specific WGAN-GP models for CN and EMCI slice generation and produce synthetic samples at 10%, 20%, and 30% augmentation levels
- `base_classifier.ipynb`: Train the baseline ResNet18 classifier on real MRI slices
- `augmented_10.ipynb`: Train the classifier with 10% synthetic augmentation
- `augmented_20.ipynb`: Train the classifier with 20% synthetic augmentation
- `augmented_30.ipynb`: Train the classifier with 30% synthetic augmentation
### Filtering & Deployment
**Filter synthetic data:**
```bash
python filter_synthetic_cn.py
python filter_synthetic_emci.py
```
**Run the Streamlit Dashboard:**
```bash
streamlit run app.py
```
## 📊 Streamlit Demo
The included dashboard provides:
- **Project Introduction:** Motivation and methodology
- **MRI Operations:** Upload MRI slices, classify them in real-time, and visualize GAN-generated augmentations
⚠️ **Research Disclaimer:**
This system is intended strictly for research and educational purposes. Results are based on limited cohort size and internal validation only. The model is not approved for clinical diagnosis or medical decision-making.
---
## 📚 References
- **ADNI Database:** [adni.loni.usc.edu](http://adni.loni.usc.edu)
- **WGAN-GP:** Gulrajani et al., 2017
- **HD-BET:** Isensee et al., 2019 (brain extraction tool used for skull stripping)
- **ResNet:** He et al., 2015
---
## 📄 License
This repository is for research and educational purposes only. Please cite the ADNI database and relevant papers when utilizing this code.