
Moses-Mk/GAN-Based-Synthetic-MRI-Augmentation


GAN-Based MRI Augmentation for Early AD Detection

This repository contains the full pipeline for generating synthetic MRI slices with GANs and for classifying Cognitively Normal (CN) vs. Early Mild Cognitive Impairment (EMCI) subjects, with the goal of early Alzheimer's Disease (AD) detection. The project explores data augmentation with class-specific Wasserstein GANs with Gradient Penalty (WGAN-GP) to improve classifier performance on the subtle structural changes seen in EMCI.


Overview

Alzheimer’s Disease (AD) is difficult to detect at early stages because structural changes in the brain are subtle, especially in the Early Mild Cognitive Impairment (EMCI) stage. To overcome limited and imbalanced datasets, this project implements a pipeline that:

  1. Trains baseline ResNet18 classifiers on real MRI slices.
  2. Trains class-specific WGAN-GP models for CN and EMCI MRI slices.
  3. Generates synthetic slices at controlled augmentation levels (10%, 20%, 30% per class).
  4. Applies automatic filtering of low-quality synthetic slices.
  5. Trains GAN-augmented classifiers under controlled ablation settings.
  6. Performs subject-level evaluation and statistical comparison (McNemar test).

All dataset splits are performed at the subject level.
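
A minimal sketch of what a subject-level split can look like, so that no subject contributes slices to more than one partition. The filename convention (`<subject>_slice<k>.png`), the helper names, and the split fractions are assumptions made for illustration; they are not taken from the repository.

```python
import random
from collections import defaultdict


def subject_from_filename(path):
    # Hypothetical convention: "<subject>_slice<k>.png"
    return path.rsplit("_", 1)[0]


def subject_level_split(slice_paths, subject_of=subject_from_filename,
                        val_frac=0.15, test_frac=0.15, seed=42):
    """Split slices into train/val/test so that each subject lands in exactly one split."""
    by_subject = defaultdict(list)
    for path in slice_paths:
        by_subject[subject_of(path)].append(path)

    # Shuffle subjects (not slices) with a fixed seed for reproducibility.
    subjects = sorted(by_subject)
    random.Random(seed).shuffle(subjects)

    n_test = max(1, round(len(subjects) * test_frac))
    n_val = max(1, round(len(subjects) * val_frac))
    groups = {
        "test": subjects[:n_test],
        "val": subjects[n_test:n_test + n_val],
        "train": subjects[n_test + n_val:],
    }
    return {name: [p for s in subs for p in by_subject[s]]
            for name, subs in groups.items()}
```

This sketch does not stratify by class; a stratified variant would apply the same procedure to CN and EMCI subjects separately.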


Data Preprocessing

  • clean_dataset.py:
    Organizes raw ADNI NIfTI files into class directories, resolves .nii.gz inconsistencies, and validates dataset integrity.

  • Slice Extraction:
    Extracts 2D slices from 3D volumes and saves them as 256 × 256 PNGs for 2D CNN training.

  • Dataset Class:
    MRISliceDataset handles loading, subject ID extraction, and transforms. Supports both real and synthetic slices.

  • Transforms:

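    from torchvision import transforms

    # Resize to the model input size, convert to a tensor in [0, 1],
    # then normalize the single grayscale channel to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])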
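
The slice extraction step listed above could be sketched as follows. This version works on a volume that is already loaded as a NumPy array and makes several assumptions that the README does not confirm: the axial axis is the last one, a fixed number of central slices is kept, and each slice is min-max scaled to 8-bit before saving. Resizing to 256 × 256 and writing the PNG are left out.

```python
import numpy as np


def extract_axial_slices(volume, num_slices=20):
    """Return up to `num_slices` central axial slices as uint8 arrays.

    Assumes `volume` has shape (H, W, D) with the axial axis last.
    """
    depth = volume.shape[2]
    start = max((depth - num_slices) // 2, 0)
    stop = min(start + num_slices, depth)

    slices = []
    for z in range(start, stop):
        sl = volume[:, :, z].astype(np.float32)
        lo, hi = float(sl.min()), float(sl.max())
        if hi > lo:
            sl = (sl - lo) / (hi - lo)   # min-max scale to [0, 1]
        else:
            sl = np.zeros_like(sl)       # constant slice: nothing to scale
        slices.append((sl * 255.0).round().astype(np.uint8))
    return slices
```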

GAN-Based MRI Augmentation for CN vs EMCI classification

This project implements a Wasserstein GAN with Gradient Penalty (WGAN-GP) to address class imbalance and the subtlety of structural changes in Early Mild Cognitive Impairment (EMCI) brain MRI slices. Realistic synthetic slices are generated with the aim of improving the robustness and accuracy of a ResNet18 classifier.


🧠 Motivation

  • Subtlety: EMCI is a transitional stage toward Alzheimer’s, and EMCI slices are often difficult to distinguish from Cognitively Normal (CN) slices.
  • Data Scarcity: EMCI data is frequently underrepresented in neuroimaging datasets.
  • Solution: GAN augmentation generates high-fidelity synthetic slices to diversify the training set and improve the classifier's ability to generalize.

🏗️ Architecture

Wasserstein GAN with Gradient Penalty (WGAN-GP)

  • Separate models are trained for CN and EMCI classes to ensure class-specific feature generation.

Generator:

  • Input: 128-dimensional latent vector
  • Structure: Progressive ConvTranspose2d layers scaling from 4 × 4 to 256 × 256
  • Features: BatchNorm + ReLU activations, final Tanh layer
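
To make the 4 × 4 to 256 × 256 progression concrete, the sketch below applies the `ConvTranspose2d` output-size formula (with dilation 1 and no output padding) layer by layer. The kernel, stride, and padding values are DCGAN-style assumptions; the README does not state the ones actually used.

```python
def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    """Output size of ConvTranspose2d with dilation=1 and output_padding=0."""
    return (size - 1) * stride - 2 * padding + kernel


def generator_resolutions():
    # The latent vector is treated as a 1 x 1 feature map.
    # Assumed first layer: kernel 4, stride 1, padding 0 -> 4 x 4.
    sizes = [conv_transpose_out(1, kernel=4, stride=1, padding=0)]
    # Assumed remaining layers: kernel 4, stride 2, padding 1 -> doubles each time.
    while sizes[-1] < 256:
        sizes.append(conv_transpose_out(sizes[-1]))
    return sizes
```

Under these assumptions the generator needs one projection layer plus six doubling layers to reach 256 × 256.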

Critic (Discriminator):

  • Input: 1-channel (grayscale) 256 × 256 slice
  • Structure: Strided Conv2d layers with InstanceNorm + LeakyReLU
  • Output: Scalar score (no Sigmoid, per WGAN-GP requirements)

Classification Model (ResNet18):

  • Baseline: Standard ResNet18 adapted for 1-channel grayscale input
  • Augmented: Same architecture trained on a combined dataset of real and filtered synthetic slices (~24,000 slices)

⚙️ Training Configuration

Hyperparameter | Value
--- | ---
Image Size | 256 × 256
Latent Dimension | 128
Batch Size | 32
Learning Rate | 5e-5
Optimizer | Adam
Critic Iterations | 3 per generator step
Epochs | 80
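
For readers unfamiliar with WGAN-GP, the critic objective can be sketched as below. The penalty weight `lambda_gp=10.0` is the value proposed by Gulrajani et al. and is an assumption here, since the configuration above does not list it; the function names are likewise illustrative.

```python
import torch
from torch import nn


def gradient_penalty(critic, real, fake):
    """Mean of (||grad critic(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    batch = real.size(0)
    # One interpolation weight per sample, broadcast over C, H, W.
    eps = torch.rand(batch, 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).detach().requires_grad_(True)
    scores = critic(x_hat)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # keep the graph so the penalty can be backpropagated
    )
    grad_norm = grads.reshape(batch, -1).norm(2, dim=1)
    return ((grad_norm - 1.0) ** 2).mean()


def critic_loss(critic, real, fake, lambda_gp=10.0):
    """Wasserstein critic loss plus gradient penalty (lambda_gp is assumed)."""
    fake = fake.detach()  # the generator is not updated in the critic step
    wasserstein = critic(fake).mean() - critic(real).mean()
    return wasserstein + lambda_gp * gradient_penalty(critic, real, fake)
```

With the configuration above, this loss would be minimized for 3 critic updates before each generator update, and the generator would minimize `-critic(fake).mean()`.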

🚀 Workflow & Pipeline

1. Filtering Synthetic Slices

To ensure quality, synthetic slices are automatically passed through a filtering script based on:

  • Mean Intensity: > 0.05
  • Foreground Area: > 5% of the total image
  • Edge Strength: Laplacian variance > 15

Filtering retention varies slightly across augmentation regimes (10%, 20%, 30%) but remains stable overall, indicating consistent GAN output quality as synthetic volume increases.
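
The three criteria can be sketched with NumPy alone. The thresholds come from the list above, but the scales are assumptions: mean intensity and foreground are measured on a [0, 1] scale, the foreground cutoff `fg_level` is hypothetical, and the Laplacian variance is computed on the 0-255 scale with a 4-neighbour kernel over interior pixels (so values may differ slightly from an OpenCV-based filter).

```python
import numpy as np


def laplacian_variance(img):
    """Variance of the 4-neighbour Laplacian over interior pixels."""
    f = img.astype(np.float64)
    lap = (f[:-2, 1:-1] + f[2:, 1:-1] + f[1:-1, :-2] + f[1:-1, 2:]
           - 4.0 * f[1:-1, 1:-1])
    return float(lap.var())


def keep_slice(img_u8, min_mean=0.05, min_fg_frac=0.05,
               min_lap_var=15.0, fg_level=0.1):
    """Return True if a uint8 grayscale slice passes all three checks."""
    scaled = img_u8.astype(np.float64) / 255.0
    mean_ok = scaled.mean() > min_mean                 # not too dark overall
    fg_ok = (scaled > fg_level).mean() > min_fg_frac   # enough foreground pixels
    edge_ok = laplacian_variance(img_u8) > min_lap_var # enough edge detail
    return bool(mean_ok and fg_ok and edge_ok)
```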

2. Subject-Level Evaluation

  • Predictions are aggregated per subject by averaging individual slice probabilities to ensure clinical relevance.
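
A minimal sketch of this aggregation, assuming each slice comes with a subject ID and a predicted EMCI probability. The 0.5 decision threshold and the function name are assumptions.

```python
from collections import defaultdict


def aggregate_by_subject(subject_ids, slice_probs, threshold=0.5):
    """Average slice-level probabilities per subject and threshold the mean."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for sid, prob in zip(subject_ids, slice_probs):
        sums[sid] += prob
        counts[sid] += 1
    mean_prob = {sid: sums[sid] / counts[sid] for sid in sums}
    prediction = {sid: int(p >= threshold) for sid, p in mean_prob.items()}
    return mean_prob, prediction
```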

Metrics:

  • Accuracy
  • Sensitivity / Specificity
  • AUC-ROC
  • Balanced Accuracy
  • Confusion Matrix
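
The threshold-based metrics above all follow from the subject-level confusion matrix, as the sketch below shows. It treats EMCI as the positive class (label 1), which is an assumption; AUC-ROC is omitted because it needs the subject-level probabilities rather than hard predictions.

```python
def subject_metrics(y_true, y_pred):
    """Accuracy, sensitivity, specificity and balanced accuracy from binary labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return {
        "confusion_matrix": [[tn, fp], [fn, tp]],  # rows: true CN, true EMCI
        "accuracy": (tp + tn) / len(y_true),
        "sensitivity": sensitivity,
        "specificity": specificity,
        "balanced_accuracy": (sensitivity + specificity) / 2,
    }
```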

Paired statistical comparison between baseline and augmented models was conducted using McNemar’s test.
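
McNemar’s test only looks at subjects on which the two models disagree. The sketch below uses the exact binomial form of the test, which is an assumption: the README does not say whether the exact or the chi-square variant was used.

```python
from math import comb


def mcnemar_exact(y_true, pred_baseline, pred_augmented):
    """Return (b, c, p_value) for the exact two-sided McNemar test.

    b: baseline correct, augmented wrong; c: baseline wrong, augmented correct.
    """
    b = sum(1 for t, base, aug in zip(y_true, pred_baseline, pred_augmented)
            if base == t and aug != t)
    c = sum(1 for t, base, aug in zip(y_true, pred_baseline, pred_augmented)
            if base != t and aug == t)
    n = b + c
    if n == 0:
        return b, c, 1.0  # the models never disagree
    # Two-sided exact p-value under Binomial(n, 0.5), capped at 1.
    tail = sum(comb(n, k) for k in range(min(b, c) + 1)) / 2 ** n
    return b, c, min(1.0, 2.0 * tail)
```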

  • Visualization: Training and validation loss curves, ROC curves, and confusion matrix heatmaps

GAN-Generated MRI Samples
Figure: Example synthetic MRI slices generated by the WGAN-GP for EMCI and CN augmentation.

📈 Augmentation Ablation Study

To evaluate the effect of synthetic data proportion, experiments were conducted using:

  • 10% synthetic augmentation
  • 20% synthetic augmentation
  • 30% synthetic augmentation
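
As an illustration of how these levels translate into slice counts, the sketch below assumes that each percentage is taken relative to the number of real training slices in that class. This interpretation, the function name, and the class sizes in the usage note are assumptions, not figures from the repository.

```python
def synthetic_counts(real_counts, percents=(10, 20, 30)):
    """Number of synthetic slices to add per class at each augmentation level."""
    return {
        pct: {cls: round(n * pct / 100) for cls, n in real_counts.items()}
        for pct in percents
    }
```

For example, with hypothetical class sizes `{"CN": 1000, "EMCI": 800}`, the 30% level would add 300 CN and 240 EMCI synthetic slices.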

Performance was reported as mean ± standard deviation across runs.

This controlled ablation allows analysis of:

  • Whether performance improvements scale with augmentation level
  • Whether gains plateau beyond a certain synthetic ratio
  • Whether augmentation remains stable as synthetic volume increases

💻 Installation

    # Clone the repository
    git clone https://github.com/yourusername/gan-mri-augmentation.git
    cd gan-mri-augmentation

    # Install dependencies
    pip install -r requirements.txt

Key dependencies: torch, torchvision, hd-bet, opencv-python, streamlit, scikit-learn


🛠️ Usage

Jupyter Notebooks

  • data_pipeline.ipynb: MRI preprocessing and strict subject-level data splitting
  • cn_augmentation.ipynb / emci_augmentation.ipynb: Train class-specific WGAN-GP models for CN and EMCI slice generation and produce synthetic samples at 10%, 20%, and 30% augmentation levels
  • base_classifier.ipynb: Train the baseline ResNet18 classifier on real MRI slices
  • augmented_10.ipynb: Train the classifier with 10% synthetic augmentation
  • augmented_20.ipynb: Train the classifier with 20% synthetic augmentation
  • augmented_30.ipynb: Train the classifier with 30% synthetic augmentation

Filtering & Deployment

Filter synthetic data:

    python filter_synthetic_cn.py
    python filter_synthetic_emci.py

Run the Streamlit dashboard:

    streamlit run app.py

📊 Streamlit Demo

The included dashboard provides:

  • Project Introduction: Motivation and methodology
  • MRI Operations: Upload MRI slices, classify them in real-time, and visualize GAN-generated augmentations

⚠️ Research Disclaimer:
This system is intended strictly for research and educational purposes. Results are based on a limited cohort size and internal validation only. The model is not approved for clinical diagnosis or medical decision-making.


📚 References

  • ADNI Database: adni.loni.usc.edu
  • WGAN-GP: Gulrajani et al., 2017
  • HD-BET: Brain Extraction Tool for skull stripping
  • ResNet: He et al., 2015

📄 License

This repository is for research and educational purposes only. Please cite the ADNI database and relevant papers when utilizing this code.
