
Efficient Connectivity-Preserving Instance Segmentation with Supervoxel-Based Loss Function


paper | poster

This repository implements a connectivity-preserving loss function that improves segmentation of curvilinear structures by penalizing both structure-level and voxel-level mistakes. The structure-level loss is computed by extracting supervoxels (i.e., connected components) from the false positive and false negative masks, then assigning higher penalties to the critical supervoxels that introduce connectivity errors.
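For intuition, the supervoxels can be extracted by labeling connected components of the error masks. Below is a minimal sketch using scipy.ndimage; the function name and arguments are illustrative, not part of this package's API.

import numpy as np
from scipy import ndimage

def error_supervoxels(pred_binary, target):
    # `pred_binary` and `target` are binary numpy arrays of the same shape.
    false_pos = np.logical_and(pred_binary == 1, target == 0)
    false_neg = np.logical_and(pred_binary == 0, target == 1)
    # Each connected component of an error mask is one supervoxel.
    fp_labels, n_fp = ndimage.label(false_pos)
    fn_labels, n_fn = ndimage.label(false_neg)
    return fp_labels, n_fp, fn_labels, n_fn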

@inproceedings{grim2025,
  title={Efficient Connectivity-Preserving Instance Segmentation with Supervoxel-Based Loss Function},
  author={Grim, Anna and Chandrashekar, Jayaram and Sumbul, Uygar},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  year={2025}
}

💡 If you found this useful, please consider citing our work and ⭐ starring this repository — it helps others discover it!

Method

The loss computation consists of three main steps:

1. Binarized Prediction: The prediction generated by the neural network is thresholded into a binary mask that separates foreground from background.

2. False Positive/Negative Masks: Computed by comparing the binarized prediction to the ground truth.

3. Critical Supervoxels: Detect connected components in the false positive/negative masks that cause connectivity errors.

Finally, the loss is computed by comparing the prediction with the ground truth segmentation, applying higher penalties to voxels within critical supervoxels that affect connectivity.
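Putting these steps together, the sketch below shows one way such a weighted loss could be assembled. It is illustrative only: supervoxel_weighted_loss and critical_mask are hypothetical names, and alpha/beta mirror the constructor arguments shown under Usage.

import torch
import torch.nn.functional as F

def supervoxel_weighted_loss(logits, targets, critical_mask, alpha=0.5, beta=0.5):
    # `targets` and `critical_mask` are float tensors with values in {0, 1}.
    # Per-voxel binary cross-entropy, unreduced so it can be reweighted.
    per_voxel = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    voxel_term = per_voxel.mean()
    # Structure term: the same loss restricted to critical supervoxels.
    n_critical = critical_mask.sum().clamp(min=1)
    structure_term = (per_voxel * critical_mask).sum() / n_critical
    return alpha * voxel_term + beta * structure_term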

Figure: Visualization of the supervoxel-based loss computation pipeline.

Installation

To install the package, run the following from the root directory:

pip install -e .

Usage

Here is a simple example of using this loss function to train a model.

import torch.nn as nn
import torch.optim as optim

from supervoxel_loss.loss import SuperVoxelLoss2D

# Initialization (assumes a UNet model class is defined elsewhere)
model = UNet()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

loss_switch_epoch = 10
voxel_loss = nn.BCEWithLogitsLoss()
supervoxel_loss = SuperVoxelLoss2D(alpha=0.5, beta=0.5)

# Main training loop (assumes `n_epochs` and `dataloader` are defined)
for epoch in range(n_epochs):
    # Set loss function based on the current epoch
    if epoch < loss_switch_epoch:
        loss_function = voxel_loss
    else:
        loss_function = supervoxel_loss

    # Iterate over training batches
    for inputs, targets in dataloader:
        # Forward pass
        preds = model(inputs)

        # Compute loss
        loss = loss_function(preds, targets)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Contact Information

For inquiries, feedback, or contributions, please reach out via email at [email protected] or connect on LinkedIn.

License

supervoxel-loss is licensed under the MIT License.
