This repository implements a connectivity-preserving loss function that improves segmentation of curvilinear structures by penalizing both structure-level and voxel-level mistakes. The structure-level loss is calculated by computing supervoxels (i.e. connected components) in the false positive and false negative masks, then assigning higher penalties to critical supervoxels that introduce connectivity errors.
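Voxel-level accuracy alone can miss connectivity errors. The minimal NumPy sketch below removes a single voxel from a straight 100-voxel structure. The voxel error is only 1%, yet the structure is split into two pieces. The helper `count_components_1d` is hypothetical and handles 1D masks only.

```python
import numpy as np


def count_components_1d(mask):
    """Count runs of foreground voxels in a 1D binary mask (hypothetical helper)."""
    mask = np.asarray(mask, dtype=bool)
    # A run starts wherever a foreground voxel follows background (or the array start).
    starts = mask & ~np.concatenate(([False], mask[:-1]))
    return int(starts.sum())


target = np.ones(100, dtype=bool)  # one continuous 100-voxel structure
pred = target.copy()
pred[50] = False                   # a single missed voxel in the middle

voxel_error = float(np.mean(pred != target))
print(voxel_error)                                              # 0.01
print(count_components_1d(target), count_components_1d(pred))  # 1 2
```

A voxel-level loss sees one wrong voxel out of 100, but that voxel changes the resulting instance segmentation. The structure-level term is meant to penalize this kind of mistake more heavily.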
@inproceedings{grim2025,
title={Efficient Connectivity-Preserving Instance Segmentation with Supervoxel-Based Loss Function},
author={Grim, Anna and Chandrashekar, Jayaram and Sumbul, Uygar},
booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
year={2025}
}
💡 If you find this work useful, please consider citing it and ⭐ starring this repository. It helps others discover it!
The loss computation consists of three main steps:
1. Binarized Prediction: The prediction generated by the neural network is thresholded into a binary mask that separates foreground from background.
2. False Positive/Negative Masks: These are computed by comparing the binarized prediction to the ground truth.
3. Critical Supervoxels: Connected components in the false positive/negative masks that cause connectivity errors are detected.
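The three steps can be sketched in NumPy as follows. This is a simplified illustration under several assumptions, not the repository's implementation:

- It works on 2D masks with 4-connectivity.
- It binarizes logits at 0, which corresponds to a probability of 0.5.
- It uses a simplified criticality test. A supervoxel counts as critical if deleting it changes the number of connected components of the mask it came from: the ground truth for false negatives, and the binarized prediction for false positives.

The helpers `label_components` and `critical_supervoxels` are hypothetical.

```python
from collections import deque

import numpy as np

# 4-connectivity in 2D (assumption; 3D volumes would use 6- or 26-connectivity).
OFFSETS = ((1, 0), (-1, 0), (0, 1), (0, -1))


def label_components(mask):
    """Label 4-connected components of a 2D boolean mask with BFS.

    Returns (labels, count); background is 0 and components are 1..count.
    """
    mask = np.asarray(mask, dtype=bool)
    labels = np.zeros(mask.shape, dtype=int)
    count = 0
    for start in zip(*np.nonzero(mask)):
        if labels[start]:
            continue  # already reached from an earlier seed
        count += 1
        labels[start] = count
        queue = deque([start])
        while queue:
            r, c = queue.popleft()
            for dr, dc in OFFSETS:
                nr, nc = r + dr, c + dc
                if (0 <= nr < mask.shape[0] and 0 <= nc < mask.shape[1]
                        and mask[nr, nc] and not labels[nr, nc]):
                    labels[nr, nc] = count
                    queue.append((nr, nc))
    return labels, count


def critical_supervoxels(logits, target, threshold=0.0):
    """Return boolean masks of critical false-negative and false-positive supervoxels."""
    # Step 1: binarize the prediction (logit 0 corresponds to probability 0.5).
    pred = np.asarray(logits) > threshold
    target = np.asarray(target, dtype=bool)

    # Step 2: false negative and false positive masks.
    false_neg = target & ~pred
    false_pos = pred & ~target

    # Step 3: supervoxels are connected components of each error mask; flag those
    # whose deletion changes the component count of the mask they came from.
    n_target = label_components(target)[1]
    n_pred = label_components(pred)[1]
    crit_fn = np.zeros_like(target)
    crit_fp = np.zeros_like(target)
    for errors, source, n_source, out in ((false_neg, target, n_target, crit_fn),
                                          (false_pos, pred, n_pred, crit_fp)):
        sv_labels, n_sv = label_components(errors)
        for i in range(1, n_sv + 1):
            sv = sv_labels == i
            if label_components(source & ~sv)[1] != n_source:
                out |= sv  # in-place update of crit_fn or crit_fp
    return crit_fn, crit_fp


# Ground truth: one horizontal line across row 2.
target = np.zeros((5, 7), dtype=bool)
target[2, :] = True
logits = np.where(target, 2.0, -2.0)  # start from a perfect prediction
logits[2, 3] = -2.0  # missed voxel that splits the line
logits[1, 5] = 2.0   # extra voxel touching the line (thickening only)
logits[4, 0] = 2.0   # isolated spurious voxel

crit_fn, crit_fp = critical_supervoxels(logits, target)
print(np.argwhere(crit_fn).tolist(), np.argwhere(crit_fp).tolist())  # [[2, 3]] [[4, 0]]
```

The missed voxel at (2, 3) splits the line, so it is flagged. The isolated false positive at (4, 0) creates a spurious object, so it is also flagged. The false positive at (1, 5) only thickens the line and is not flagged.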
Finally, the loss is computed by comparing the prediction with the ground truth segmentation, applying higher penalties to voxels within critical supervoxels that affect connectivity.
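One way to apply higher penalties to critical voxels is a per-voxel weighted binary cross-entropy. In the hypothetical sketch below, every voxel has weight 1. Voxels in critical false-negative supervoxels get an extra `alpha`, and voxels in critical false-positive supervoxels get an extra `beta`. The parameter names mirror the `SuperVoxelLoss2D(alpha=..., beta=...)` arguments used in the training example, but this weighting scheme is an assumption and may differ from how the package combines its terms. The sketch uses NumPy for readability; a training loss would be written in PyTorch so that gradients flow.

```python
import numpy as np


def weighted_bce_with_logits(logits, target, crit_fn, crit_fp, alpha=0.5, beta=0.5):
    """Mean BCE-with-logits with extra weight on voxels in critical supervoxels.

    Hypothetical weighting: weight = 1 + alpha * crit_fn + beta * crit_fp.
    """
    logits = np.asarray(logits, dtype=float)
    target = np.asarray(target, dtype=float)
    # Numerically stable BCE with logits: max(x, 0) - x * y + log(1 + exp(-|x|)).
    bce = np.maximum(logits, 0) - logits * target + np.log1p(np.exp(-np.abs(logits)))
    weights = 1.0 + alpha * np.asarray(crit_fn, dtype=float) + beta * np.asarray(crit_fp, dtype=float)
    return float(np.mean(weights * bce))


logits = np.array([2.0, -2.0, 2.0, -2.0])
target = np.array([1, 1, 0, 0])
crit_fn = np.array([False, True, False, False])  # voxel 1 assumed to lie in a critical supervoxel
no_crit = np.zeros(4, dtype=bool)

plain = weighted_bce_with_logits(logits, target, no_crit, no_crit)
weighted = weighted_bce_with_logits(logits, target, crit_fn, no_crit)
print(round(plain, 4), round(weighted, 4))  # 1.1269 1.3928
```

Only the missed voxel inside the critical supervoxel contributes more to the loss. Errors that do not affect connectivity keep their ordinary voxel-level weight.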
Figure: Visualization of supervoxel-based loss computation.
To install the package, run the following from the repository's root directory:
pip install -e .
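Here is a simple example of training a model with this loss function. The model is first trained with a voxel-level loss for a fixed number of epochs and then switches to the supervoxel loss.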
```python
import torch.nn as nn
import torch.optim as optim

from supervoxel_loss.loss import SuperVoxelLoss2D

# Initialization
# UNet, n_epochs, and dataloader are placeholders: supply your own model,
# number of training epochs, and data loader yielding (inputs, targets) batches.
model = UNet()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

loss_switch_epoch = 10
voxel_loss = nn.BCEWithLogitsLoss()
supervoxel_loss = SuperVoxelLoss2D(alpha=0.5, beta=0.5)

# Main
for epoch in range(n_epochs):
    # Set loss function based on the current epoch:
    # voxel-level loss first, then the supervoxel loss
    if epoch < loss_switch_epoch:
        loss_function = voxel_loss
    else:
        loss_function = supervoxel_loss

    # Training loop
    model.train()
    for inputs, targets in dataloader:
        # Forward pass
        preds = model(inputs)

        # Compute loss
        loss = loss_function(preds, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
For any inquiries, feedback, or contributions, please do not hesitate to contact us. You can reach us via email at [email protected] or connect on LinkedIn.
supervoxel-loss is licensed under the MIT License.