
Retinal Vessel Segmentation based on Optimized Swin-UNet

1. Project Overview

This project implements a retinal vessel segmentation pipeline on the DRIVE dataset using the Swin-UNet (SwinUNETR) architecture.

Despite the Transformer's lack of inductive bias and the extreme scarcity of training data (only 20 images), this project achieves a Dice score of 0.7302. The model establishes a robust baseline through the following key engineering optimizations:

  • Pre-trained Weights Mapping: Successfully mapped Microsoft's official Swin-Tiny weights to the UNETR architecture.
  • Distribution Alignment: Standard ImageNet normalization to prevent covariate shift.
  • Optimization Strategy: AdamW optimizer combined with Differential Learning Rates and Gradient Accumulation.
  • Loss Function: A rebalanced Hybrid Loss ($0.5 \times BCE + 0.5 \times Dice$).
  • Post-processing: Physical FOV masking and rigorous threshold sweeping.

2. Methodology & Improvements

2.1 Pretrained Weights & Distribution Alignment

Training a Transformer from scratch on 20 images fails to converge. I implemented a custom weight mapping function to load swin_tiny_patch4_window7_224.pth into the encoder. To protect these weights, inputs were strictly normalized using ImageNet statistics (Mean: [0.485, 0.456, 0.406], Std: [0.229, 0.224, 0.225]).
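These two steps can be sketched as follows. This is a minimal sketch, not the repo's actual code: the `swinViT.` key prefix and the exact checkpoint layout are assumptions, and the real mapping function may be more involved.

```python
import torch

# ImageNet statistics, matching the distribution the Swin-Tiny encoder saw
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

def normalize_imagenet(img: torch.Tensor) -> torch.Tensor:
    """Normalize a (3, H, W) float tensor in [0, 1] with ImageNet stats."""
    return (img - IMAGENET_MEAN) / IMAGENET_STD

def remap_swin_weights(ckpt_state: dict, model_state: dict,
                       prefix: str = "swinViT.") -> dict:
    """Copy Swin-Tiny checkpoint tensors into matching encoder slots.

    Assumes encoder keys differ from checkpoint keys only by a prefix
    (hypothetical); shape-mismatched tensors are skipped.
    """
    return {prefix + k: v for k, v in ckpt_state.items()
            if prefix + k in model_state
            and model_state[prefix + k].shape == v.shape}
```

The returned dict can be merged into `model.state_dict()` and loaded with `load_state_dict(..., strict=False)`, so decoder weights keep their random initialization.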

2.2 Optimizer & Gradient Accumulation

Standard RMSprop caused severe gradient explosions (NaN) in the Swin architecture. This was mitigated by switching to AdamW with a lower weight decay ($10^{-5}$). To counter the high gradient variance caused by batch_size=1, I introduced Gradient Accumulation (accumulation_steps=4) to simulate a larger batch size and stabilize the loss landscape.
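A minimal training step with accumulation might look like the following; the loader/criterion interface is illustrative and not necessarily the repo's actual API.

```python
import torch
from torch import nn

def train_epoch(model, loader, criterion, optimizer, accumulation_steps=4):
    """Run one epoch, stepping the optimizer every `accumulation_steps`
    micro-batches so batch_size=1 behaves like an effective batch of 4."""
    model.train()
    optimizer.zero_grad()
    for step, (images, masks) in enumerate(loader):
        loss = criterion(model(images), masks) / accumulation_steps
        loss.backward()                      # gradients accumulate in .grad
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()                 # one update per 4 micro-batches
            optimizer.zero_grad()
```

Dividing the loss by `accumulation_steps` keeps the accumulated gradient magnitude comparable to a true batch of 4; the optimizer itself would be created as `torch.optim.AdamW(model.parameters(), lr=..., weight_decay=1e-5)`.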

2.3 Differential Learning Rates

A global learning rate either destroyed pretrained features or starved the randomly initialized decoder. A differential strategy was applied: $10^{-4}$ for the pre-trained Encoder, and $5 \times 10^{-4}$ for the Decoder.
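With AdamW this amounts to a two-group optimizer. A sketch, assuming the encoder parameters live under a `swinViT` attribute (a hypothetical name; the actual module name depends on the SwinUNETR implementation):

```python
import torch

def build_optimizer(model, enc_lr=1e-4, dec_lr=5e-4, weight_decay=1e-5):
    """AdamW with differential learning rates: a conservative LR for the
    pretrained encoder, a larger one for the randomly initialized decoder."""
    enc, dec = [], []
    for name, param in model.named_parameters():
        (enc if name.startswith("swinViT") else dec).append(param)
    return torch.optim.AdamW(
        [{"params": enc, "lr": enc_lr},
         {"params": dec, "lr": dec_lr}],
        weight_decay=weight_decay)
```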

2.4 Hybrid Loss Function

To address severe class imbalance and gradient oscillations caused by the non-convex Dice Loss, the final objective function was rebalanced: $$\mathcal{L} = 0.5 \times \mathrm{BCEWithLogitsLoss}(\text{pos\_weight} = 7.0) + 0.5 \times \mathrm{DiceLoss}$$
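A sketch of this objective evaluated on raw logits; the smoothing constant and the global-sum Dice formulation are assumptions, and the repo's exact Dice implementation may differ.

```python
import torch
from torch import nn

class HybridLoss(nn.Module):
    """0.5 * weighted BCE + 0.5 * soft Dice, evaluated on raw logits."""

    def __init__(self, pos_weight: float = 7.0, smooth: float = 1.0):
        super().__init__()
        # pos_weight=7.0 upweights the sparse vessel pixels in the BCE term
        self.bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        probs = torch.sigmoid(logits)
        inter = (probs * targets).sum()
        dice = 1 - (2 * inter + self.smooth) / (
            probs.sum() + targets.sum() + self.smooth)
        return 0.5 * self.bce(logits, targets) + 0.5 * dice
```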


3. Results & Ablation Study

Visual Comparison

(Input Image; Ground Truth; Swin-UNet Prediction; Overlay Analysis)

Swin-UNet Final Results

Ablation Study (Experimental Log)

The following table demonstrates the step-by-step optimization process:

| Exp ID | Strategy | Key Modification | Train Loss | Dice Score | Observation |
|---|---|---|---|---|---|
| Baseline | Swin-UNet (Scratch) | Initialized from scratch | 0.312 | < 0.50 | Failed to converge due to lack of inductive bias. |
| Exp 1 | + Pretrained Weights | Official Swin-Tiny | 0.297 | 0.7465* | False prosperity (over-segmentation); distribution shifted. |
| Exp 2 | + Normalization | ImageNet standard | 0.284 | ~0.60 | Features protected, but decoder starved by the low LR. |
| Exp 3 | + Diff LR | Enc $10^{-5}$ / Dec $10^{-4}$ | NaN | N/A | RMSprop incompatible with Transformer; gradients exploded. |
| Exp 4 | + AdamW | Optimizer swap | 0.451 | ~0.60 | Stuck in a local minimum (predicted all background). |
| Exp 5 | + Grad Accumulation | accumulation_steps=4 | 0.385 | 0.6384 | FOV boundary generated false-positive rings. |
| Exp 6 | + FOV Mask | Bitwise multiplication | 0.385 | 0.7012 | Boundary artifacts physically eliminated. |
| Exp 7 | + Loss Rebalance | 0.5 BCE + 0.5 Dice | 0.329 | 0.7253 | Oscillation suppressed; vessel fragmentation fixed. |
| Final | + Threshold Tuning | Threshold = 0.45 | 0.329 | 0.7302 | Achieved optimal precision-recall balance. |

(Note: The early 0.7465 score was an artifact of severe over-segmentation and lacks clinical validity.)
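The FOV masking and threshold-sweeping steps from the log above can be sketched with NumPy; function names and the sweep grid are illustrative, not the repo's actual interface.

```python
import numpy as np

def apply_fov_and_threshold(probs, fov_mask, threshold=0.45):
    """Zero predictions outside the circular field of view, then binarize.

    probs:    (H, W) array of sigmoid outputs in [0, 1]
    fov_mask: (H, W) boolean array, True inside the retinal FOV
    """
    return (probs * fov_mask) > threshold

def dice_score(pred, target, eps=1e-7):
    """Dice coefficient between two boolean masks."""
    inter = np.logical_and(pred, target).sum()
    return (2.0 * inter + eps) / (pred.sum() + target.sum() + eps)

def sweep_threshold(probs, fov_mask, target,
                    thresholds=np.arange(0.30, 0.71, 0.05)):
    """Return the (threshold, dice) pair maximizing Dice on held-out data."""
    return max(((t, dice_score(apply_fov_and_threshold(probs, fov_mask, t),
                               target)) for t in thresholds),
               key=lambda pair: pair[1])
```

Multiplying by the boolean FOV mask is the "bitwise multiplication" from Exp 6: any probability outside the retinal disc is forced to zero before thresholding, which removes the false-positive rings at the boundary.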

Final Metrics (at Optimal Threshold 0.45)

  • Dice Score (F1): 0.7302
  • Sensitivity (Recall): 0.7419
  • Specificity: 0.9716
  • Accuracy: 0.9511

4. Quick Start

Prerequisites

```bash
pip install -r requirements.txt
```

Training

```bash
python train.py --epochs 50 --batch-size 1 --learning-rate 1e-4 --classes 1 --amp
```

Inference / Testing

```bash
python predict.py -i data/test_imgs/01.png -o result_01.png -m checkpoints/checkpoint_epoch50.pth --fov-dir data/mask --mask-threshold 0.45
```
