This project implements a retinal vessel segmentation pipeline on the DRIVE dataset using the Swin-UNet (SwinUNETR) architecture.
Despite the lack of inductive bias inherent to Transformers and the severe constraint of limited training data (only 20 images), the pipeline achieves a Dice Score of 0.7302. The model establishes a robust baseline through the following key engineering optimizations:
- Pre-trained Weights Mapping: Successfully mapped Microsoft's official Swin-Tiny weights to the UNETR architecture.
- Distribution Alignment: Standard ImageNet normalization to prevent covariate shift.
- Optimization Strategy: AdamW optimizer combined with Differential Learning Rates and Gradient Accumulation.
- Loss Function: A rebalanced hybrid loss ($0.5 \times BCE + 0.5 \times Dice$).
- Post-processing: Physical FOV masking and rigorous threshold sweeping.
Training a Transformer from scratch on 20 images fails to converge. I implemented a custom weight mapping function to load swin_tiny_patch4_window7_224.pth into the encoder. To protect these weights, inputs were strictly normalized using ImageNet statistics (Mean: [0.485, 0.456, 0.406], Std: [0.229, 0.224, 0.225]).
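The distribution-alignment step can be sketched as follows; the `normalize` helper below is illustrative, not the repo's actual preprocessing function, but the mean/std constants are the ones stated above:

```python
import torch

# ImageNet statistics, matching the pretrained Swin-Tiny input distribution
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def normalize(batch: torch.Tensor) -> torch.Tensor:
    """Normalize an (N, 3, H, W) float batch already scaled to [0, 1]."""
    return (batch - IMAGENET_MEAN) / IMAGENET_STD
```

Without this step, inputs in raw [0, 1] range sit far from the distribution the pretrained encoder expects, which is exactly the covariate shift noted above.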
Standard RMSprop caused severe gradient explosions (NaN) in the Swin architecture. This was mitigated by switching to AdamW with a lower weight decay. Since training ran at batch_size=1, I introduced Gradient Accumulation (accumulation_steps=4) to simulate a larger effective batch size and stabilize the loss landscape.
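The accumulation loop looks like the sketch below. The tiny linear model and the weight-decay value are placeholders (the README does not state the exact decay); only `accumulation_steps=4` and the AdamW choice come from the text:

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 1)  # stand-in for the Swin-UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.BCEWithLogitsLoss()
accumulation_steps = 4   # effective batch of 4 despite batch_size=1

optimizer.zero_grad()
steps_taken = 0
for i in range(8):  # 8 micro-batches of size 1
    x, y = torch.randn(1, 8), torch.rand(1, 1)
    loss = criterion(model(x), y)
    # Scale so the accumulated gradient is an average, not a sum
    (loss / accumulation_steps).backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        steps_taken += 1
```

Gradients are summed across the four backward passes and only then applied, so the update direction averages over four images.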
A global learning rate either destroyed pretrained features or starved the randomly initialized decoder. A differential strategy was therefore applied: a lower learning rate for the pretrained encoder and a higher one for the decoder trained from scratch.
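With AdamW this is expressed via parameter groups. The sketch below uses a toy module and illustrative learning rates (the README does not state the actual values); the pattern of a smaller encoder LR and larger decoder LR is what matters:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):  # stand-in: pretrained encoder + fresh decoder
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 1)

model = TinyNet()
optimizer = torch.optim.AdamW([
    # Gentle LR: protect the pretrained Swin features
    {"params": model.encoder.parameters(), "lr": 1e-5},
    # Aggressive LR: the decoder starts from random initialization
    {"params": model.decoder.parameters(), "lr": 1e-4},
])
```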
To address severe class imbalance and gradient oscillations caused by the non-convex Dice Loss, the final objective function was rebalanced to $\mathcal{L} = 0.5 \times \mathcal{L}_{BCE} + 0.5 \times \mathcal{L}_{Dice}$.
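A minimal sketch of this hybrid loss, using the standard soft-Dice formulation on sigmoid probabilities (the smoothing term `eps` and the function name are illustrative, not taken from the repo):

```python
import torch
import torch.nn.functional as F

def hybrid_loss(logits: torch.Tensor, target: torch.Tensor,
                eps: float = 1e-6) -> torch.Tensor:
    """0.5 * BCE + 0.5 * soft Dice, both on the same logits."""
    bce = F.binary_cross_entropy_with_logits(logits, target)
    probs = torch.sigmoid(logits)
    inter = (probs * target).sum()
    dice = 1.0 - (2.0 * inter + eps) / (probs.sum() + target.sum() + eps)
    return 0.5 * bce + 0.5 * dice
```

The BCE term supplies smooth, convex per-pixel gradients that damp the oscillation of the Dice term, while the Dice term keeps the objective aligned with the reported evaluation metric.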
(Figure: Input Image; Ground Truth; Swin-UNet Prediction; Overlay Analysis)
The following table demonstrates the step-by-step optimization process:
| Exp ID | Strategy | Key Modification | Train Loss | Dice Score | Observation |
|---|---|---|---|---|---|
| Baseline | Swin-UNet (Scratch) | Initialized from scratch | 0.312 | < 0.50 | Failed to converge due to lack of inductive bias. |
| Exp 1 | + Pretrained Weights | Official Swin-Tiny | 0.297 | 0.7465* | False prosperity (over-segmentation); distribution shifted. |
| Exp 2 | + Normalization | ImageNet standard | 0.284 | ~0.60 | Features protected, but Decoder starved due to low LR. |
| Exp 3 | + Diff LR | Enc/Dec LR split | NaN | N/A | RMSprop incompatible with Transformer; gradient exploded. |
| Exp 4 | + AdamW | Optimizer swap | 0.451 | ~0.60 | Stuck in local minimum (predicted all background). |
| Exp 5 | + Grad Accumulation | accumulation_steps=4 | 0.385 | 0.6384 | FOV boundary generated false-positive rings. |
| Exp 6 | + FOV Mask | Bitwise multiplication | 0.385 | 0.7012 | Boundary artifacts physically eliminated. |
| Exp 7 | + Loss Rebalance | 0.5 BCE + 0.5 Dice | 0.329 | 0.7253 | Oscillation suppressed; vessel fragmentation fixed. |
| Final | + Threshold Tuning | Threshold = 0.45 | 0.329 | 0.7302 | Achieved optimal Precision-Recall balance. |
(Note: The early 0.7465 score was an artifact of severe over-segmentation and lacks clinical validity.)
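The Exp 6/Final post-processing steps (FOV masking plus threshold sweeping) can be sketched in numpy as below; the function names and the sweep grid are illustrative, while the 0.45 threshold is the tuned value from the table:

```python
import numpy as np

def dice_score(pred: np.ndarray, gt: np.ndarray, eps: float = 1e-6) -> float:
    """Dice/F1 between two binary masks."""
    inter = np.logical_and(pred, gt).sum()
    return (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def postprocess(prob_map: np.ndarray, fov_mask: np.ndarray,
                threshold: float = 0.45) -> np.ndarray:
    """Zero out predictions outside the circular FOV, then binarize."""
    masked = prob_map * (fov_mask > 0)  # elementwise FOV suppression
    return (masked >= threshold).astype(np.uint8)

def sweep_threshold(prob_map, fov_mask, gt,
                    thresholds=np.arange(0.30, 0.61, 0.05)) -> float:
    """Pick the threshold maximizing Dice on a validation pair."""
    return max(thresholds,
               key=lambda t: dice_score(postprocess(prob_map, fov_mask, t), gt))
```

Masking before thresholding is what removes the false-positive ring at the FOV boundary noted in Exp 5.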
- Dice Score (F1): 0.7302
- Sensitivity (Recall): 0.7419
- Specificity: 0.9716
- Accuracy: 0.9511
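These four metrics follow directly from the pixel-level confusion counts; a minimal sketch (standard definitions, not the repo's evaluation code):

```python
import numpy as np

def segmentation_metrics(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Dice, Sensitivity, Specificity, Accuracy from binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    tp = np.logical_and(pred, gt).sum()
    tn = np.logical_and(~pred, ~gt).sum()
    fp = np.logical_and(pred, ~gt).sum()
    fn = np.logical_and(~pred, gt).sum()
    return {
        "dice": 2 * tp / (2 * tp + fp + fn),
        "sensitivity": tp / (tp + fn),   # recall on vessel pixels
        "specificity": tn / (tn + fp),   # recall on background pixels
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
    }
```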
Prerequisites
```bash
pip install -r requirements.txt
python train.py --epochs 50 --batch-size 1 --learning-rate 1e-4 --classes 1 --amp
python predict.py -i data/test_imgs/01.png -o result_01.png -m checkpoints/checkpoint_epoch50.pth --fov-dir data/mask --mask-threshold 0.45
```