Skip to content

Commit fd6ba36

Browse files
committed
fix readme, add some better asserts, allow for passing in mask as (b, h, w) shape, and also plan on putting more work into the repository now that a researcher has seen good results
1 parent 378ec4e commit fd6ba36

File tree

3 files changed

+20
-5
lines changed

3 files changed

+20
-5
lines changed

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ from med_seg_diff_pytorch import Unet, MedSegDiff
2020
model = Unet(
2121
dim = 64,
2222
image_size = 128,
23+
mask_channels = 1, # segmentation has 1 channel
24+
input_img_channels = 3, # input images have 3 channels
2325
dim_mults = (1, 2, 4, 8)
2426
)
2527

@@ -28,7 +30,7 @@ diffusion = MedSegDiff(
2830
timesteps = 1000
2931
).cuda()
3032

31-
segmented_imgs = torch.rand(8, 3, 128, 128) # inputs are normalized from 0 to 1
33+
segmented_imgs = torch.rand(8, 1, 128, 128) # inputs are normalized from 0 to 1
3234
input_imgs = torch.rand(8, 3, 128, 128)
3335

3436
loss = diffusion(segmented_imgs, input_imgs)
@@ -56,6 +58,8 @@ If you want to add in self condition where we condition with the mask we have so
5658

5759
- [x] some basic training code, with Trainer taking in custom dataset tailored for medical image formats - thanks to <a href="https://github.com/isamu-isozaki">@isamu-isozaki</a>
5860

61+
- [ ] full blown transformer of any depth in the middle, as done in <a href="https://arxiv.org/abs/2301.11093">simple diffusion</a>
62+
5963
## Citations
6064

6165
```bibtex

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def __init__(
238238
self,
239239
dim,
240240
image_size,
241-
mask_channels=1,
242-
input_img_channels=3,
241+
mask_channels = 1,
242+
input_img_channels = 3,
243243
init_dim = None,
244244
out_dim = None,
245245
dim_mults: tuple = (1, 2, 4, 8),
@@ -258,6 +258,7 @@ def __init__(
258258
self.input_img_channels = input_img_channels
259259
self.mask_channels = mask_channels
260260
self.self_condition = self_condition
261+
261262
output_channels = mask_channels
262263
mask_channels = mask_channels * (2 if self_condition else 1)
263264

@@ -699,11 +700,21 @@ def p_losses(self, x_start, t, cond, noise = None):
699700
return F.mse_loss(model_out, target)
700701

701702
def forward(self, img, cond_img, *args, **kwargs):
703+
if img.ndim == 3:
704+
img = rearrange(img, 'b h w -> b 1 h w')
705+
706+
if cond_img.ndim == 3:
707+
cond_img = rearrange(cond_img, 'b h w -> b 1 h w')
708+
702709
device = self.device
703710
img, cond_img = img.to(device), cond_img.to(device)
704711

705-
b, c, h, w, device, img_size, = *img.shape, img.device, self.image_size
712+
b, c, h, w, device, img_size, img_channels, mask_channels = *img.shape, img.device, self.image_size, self.input_img_channels, self.mask_channels
713+
706714
assert h == img_size and w == img_size, f'height and width of image must be {img_size}'
715+
assert cond_img.shape[1] == img_channels, f'your input medical must have {img_channels} channels'
716+
assert img.shape[1] == mask_channels, f'the segmented image must have {mask_channels} channels'
717+
707718
times = torch.randint(0, self.num_timesteps, (b,), device = device).long()
708719

709720
img = normalize_to_neg_one_to_one(img)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'med-seg-diff-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.1.1',
6+
version = '0.1.2',
77
license='MIT',
88
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)