MambaMIM: Pre-training Mamba with State Space Token Interpolation and its Application to Medical Image Segmentation
2 Suzhou Institute for Advanced Research, University of Science and Technology of China
3 Department of Automation, Institute of Image Processing and Pattern Recognition, Shanghai Jiao Tong University
- MambaMIM accepted by Medical Image Analysis (MedIA'25) ! 🥰
- Weights released ! 😎
- Code released ! 😘
- Paper released (2024/08/16) !
- [x] Paper released
- [x] Code released
- [x] Weights released
| Name | Resolution | Intensities (HU) | Spacing | Weights |
|---|---|---|---|---|
| MambaMIM | 96x96x96 | [-175, 250] | 1.5x1.5x1.5 mm | Google Drive (87MB) |
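To feed data in the configuration listed above, volumes should be resampled to 1.5 mm spacing, clipped to the [-175, 250] HU window, and cropped to 96x96x96. A minimal MONAI sketch (the table supplies the values; the specific transform choices and their ordering are assumptions):

```python
from monai import transforms

# Preprocessing that matches the pre-training configuration in the table above.
pretrain_transforms = transforms.Compose([
    transforms.LoadImaged(keys=["image"]),
    transforms.EnsureChannelFirstd(keys=["image"]),
    transforms.Orientationd(keys=["image"], axcodes="RAS"),
    # Resample to 1.5 x 1.5 x 1.5 mm spacing
    transforms.Spacingd(keys=["image"], pixdim=(1.5, 1.5, 1.5), mode="bilinear"),
    # Clip to the [-175, 250] HU window and rescale to [0, 1]
    transforms.ScaleIntensityRanged(
        keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True
    ),
    # Pad small volumes, then take a random 96x96x96 crop
    transforms.SpatialPadd(keys=["image"], spatial_size=(96, 96, 96)),
    transforms.RandSpatialCropd(keys=["image"], roi_size=(96, 96, 96), random_size=False),
])
```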
```bash
# Create the environment. Note: the prebuilt wheels below target Python 3.8 (cp38).
conda create -n mambamim python=3.8
conda activate mambamim

# PyTorch 1.13 with CUDA 11.7
pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu117

pip install packaging timm==0.5.4
pip install transformers==4.34.1 typed-argument-parser
pip install numpy==1.21.2 opencv-python==4.5.5.64 opencv-python-headless==4.5.5.64

# MONAI (pinned to 1.2.0)
pip install 'monai[all]'
pip install monai==1.2.0

# Install the prebuilt causal-conv1d and mamba-ssm wheels (download them first)
pip install causal_conv1d-1.2.0.post2+cu118torch1.13cxx11abiTRUE-cp38-cp38-linux_x86_64.whl
pip install mamba_ssm-1.2.0.post1+cu118torch1.13cxx11abiFALSE-cp38-cp38-linux_x86_64.whl
```
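A quick way to verify the installation is a forward pass through a small Mamba block (a minimal sketch; the tensor sizes here are arbitrary and unrelated to MambaMIM itself):

```python
# Sanity check: run a tiny Mamba block on random data.
import torch
from mamba_ssm import Mamba

block = Mamba(d_model=16, d_state=16, d_conv=4, expand=2).cuda()
x = torch.randn(2, 64, 16).cuda()  # (batch, sequence length, d_model)
print(block(x).shape)              # expected: torch.Size([2, 64, 16])
```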
We recommend that you convert the dataset into the nnUNet format.
```
└── MambaMIM
    ├── data
        ├── Dataset060_TotalSegmentator
        │   └── imagesTr
        │       ├── xxx_0000.nii.gz
        │       ├── ...
        ├── Dataset006_FLARE2022
        │   └── imagesTr
        │       ├── xxx_0000.nii.gz
        │       ├── ...
        └── Other_dataset
            └── imagesTr
                ├── xxx_0000.nii.gz
                ├── ...
```
An example dataset.json will be generated in ./data. Its content should look like the following:
```json
{
    "training": [
        {
            "image": "./Dataset060_TotalSegmentator/imagesTr/xxx_0000.nii.gz"
        },
        {
            "image": "./Dataset006_FLARE2022/imagesTr/xxx_0000.nii.gz"
        }
    ]
}
```
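If you need to regenerate dataset.json yourself, a short script that globs every imagesTr folder under ./data is enough (a minimal sketch; the repository's own generation code may differ):

```python
# Build dataset.json by listing all imagesTr volumes under ./data.
import json
from pathlib import Path

data_root = Path("./data")
training = [
    {"image": "./" + str(p.relative_to(data_root))}
    for p in sorted(data_root.glob("*/imagesTr/*.nii.gz"))
]
(data_root / "dataset.json").write_text(json.dumps({"training": training}, indent=4))
```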
Run training on multi-GPU:

```bash
# An example of training on 4 GPUs with DDP
torchrun --nproc_per_node=4 --nnodes=1 --node_rank=0 --master_addr=localhost --master_port=12351 main.py --exp_name=debug --data_path=./data --model=mambamim --bs=16 --exp_dir=debug_mambamim_ddp_4
```

Run training on a single GPU:

```bash
# An example of training on a single GPU
python main.py --exp_name=debug --data_path=./data --model=mambamim --bs=4 --exp_dir=debug_mambamim
```

Load pre-training weights:
```python
# An example of fine-tuning on BTCV (num_classes=14)
import torch
from models.network.hymamba import build_hybird

model = build_hybird(in_channel=1, n_classes=14, img_size=96).cuda()
model_dict = torch.load("mambamim_mask75.pth")
model.load_state_dict(model_dict, strict=False)
print("MambaMIM pretrained weights loaded successfully!")
```

The downstream fine-tuning pipeline can follow UNETR.
This code uses helper functions from SparK and HySparK.
If the code, paper, or weights help your research, please cite:
```bibtex
@article{tang2025mambamim,
  title={MambaMIM: Pre-training Mamba with state space token interpolation and its application to medical image segmentation},
  author={Tang, Fenghe and Nian, Bingkun and Li, Yingtai and Jiang, Zihang and Yang, Jie and Liu, Wei and Zhou, S Kevin},
  journal={Medical Image Analysis},
  pages={103606},
  year={2025},
  publisher={Elsevier}
}
```
This project is released under the Apache 2.0 license. Please see the LICENSE file for more information.

