|
1 | 1 | <img src="./med-seg-diff.png" width="450px"></img> |
2 | 2 |
|
3 | | -## MedSegDiff - Pytorch (wip) |
| 3 | +## MedSegDiff - Pytorch |
4 | 4 |
|
5 | 5 | Implementation of <a href="https://arxiv.org/abs/2211.00611">MedSegDiff</a> in Pytorch - SOTA medical segmentation out of Baidu using DDPM and enhanced conditioning on the feature level, with filtering of features in fourier space. |
6 | 6 |
|
7 | 7 | I will also add attention and introduce an extended type of cross modulation on the attention matrices, alphafold2 style. |
8 | 8 |
|
| 9 | +## Install |
| 10 | + |
| 11 | +```bash |
| 12 | +$ pip install med-seg-diff-pytorch |
| 13 | +``` |
| 14 | + |
| 15 | +## Usage |
| 16 | + |
| 17 | +```python |
| 18 | +import torch |
| 19 | +from med_seg_diff_pytorch import Unet, MedSegDiff |
| 20 | + |
| 21 | +model = Unet( |
| 22 | + dim = 64, |
| 23 | + dim_mults = (1, 2, 4, 8) |
| 24 | +) |
| 25 | + |
| 26 | +diffusion = MedSegDiff( |
| 27 | + model, |
| 28 | + image_size = 128, |
| 29 | + timesteps = 1000 |
| 30 | +).cuda() |
| 31 | + |
| 32 | +segmented_imgs = torch.rand(8, 3, 128, 128) # inputs are normalized from 0 to 1 |
| 33 | +input_imgs = torch.rand(8, 3, 128, 128) |
| 34 | + |
| 35 | +loss = diffusion(segmented_imgs, input_imgs) |
| 36 | +loss.backward() |
| 37 | + |
| 38 | +# after a lot of training |
| 39 | + |
| 40 | +pred = diffusion.sample(input_imgs) # pass in your unsegmented images |
| 41 | +pred.shape # predicted segmented images - (8, 3, 128, 128) |
| 42 | +``` |
| 43 | + |
9 | 44 | ## Appreciation |
10 | 45 |
|
11 | 46 | - <a href="https://stability.ai/">StabilityAI</a> for the generous sponsorship, as well as my other sponsors out there |
12 | 47 |
|
| 48 | +## Todo |
| 49 | + |
| 50 | +- [ ] add a cross attention variant for generating the attentive map (A) |
| 51 | +- [ ] modulate attention matrices in middle and other self attention layers, wherever full attention is used |
| 52 | +- [ ] some basic training code, with Trainer taking in custom dataset tailored for medical image formats |
| 53 | + |
13 | 54 | ## Citations |
14 | 55 |
|
15 | 56 | ```bibtex |
|
0 commit comments