Skip to content

Commit 1e7e9f5

Browse files
committed
move condition logic to another class, to prepare for alternatives
1 parent 0226d77 commit 1e7e9f5

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

README.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
Implementation of <a href="https://arxiv.org/abs/2211.00611">MedSegDiff</a> in Pytorch - SOTA medical segmentation out of Baidu using DDPM and enhanced conditioning on the feature level, with filtering of features in fourier space.
66

7-
I will also add attention and introduce an extended type of cross modulation on the attention matrices, alphafold2 style.
8-
97
## Install
108

119
```bash
@@ -47,8 +45,6 @@ pred.shape # predicted segmented images - (8, 3, 12
4745

4846
## Todo
4947

50-
- [ ] add a cross attention variant for generating the attentive map (A)
51-
- [ ] modulate attention matrices in middle and other self attention layers, wherever full attention is used
5248
- [ ] some basic training code, with Trainer taking in custom dataset tailored for medical image formats
5349

5450
## Citations

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,19 @@ def forward(self, x):
196196
out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
197197
return self.to_out(out)
198198

199+
# conditioning class
200+
201+
class FourierConditioning(nn.Module):
202+
def __init__(self, dim):
203+
super().__init__()
204+
self.norm_input = LayerNorm(dim, bias = True)
205+
self.norm_condition = LayerNorm(dim, bias = True)
206+
207+
def forward(self, x, c):
208+
normed_x = self.norm_input(x)
209+
normed_c = self.norm_condition(c)
210+
return (normed_x * normed_c) * c # eq 3 in paper
211+
199212
# model
200213

201214
class Unet(nn.Module):
@@ -207,7 +220,8 @@ def __init__(
207220
dim_mults=(1, 2, 4, 8),
208221
channels = 3,
209222
self_condition = False,
210-
resnet_block_groups = 8
223+
resnet_block_groups = 8,
224+
conditioning_klass = FourierConditioning
211225
):
212226
super().__init__()
213227

@@ -242,18 +256,21 @@ def __init__(
242256

243257
num_resolutions = len(in_out)
244258

259+
self.conditioners = nn.ModuleList([])
260+
245261
# downsampling encoding blocks
246262

247263
self.downs = nn.ModuleList([])
248264

249265
for ind, (dim_in, dim_out) in enumerate(in_out):
250266
is_last = ind >= (num_resolutions - 1)
251267

268+
self.conditioners.append(conditioning_klass(dim_in))
269+
252270
self.downs.append(nn.ModuleList([
253271
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
254272
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
255273
Residual(LinearAttention(dim_in)),
256-
LayerNorm(dim_in, bias = True),
257274
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
258275
]))
259276

@@ -310,7 +327,7 @@ def forward(
310327

311328
h = []
312329

313-
for (block1, block2, attn, norm, downsample), (cond_block1, cond_block2, cond_attn, cond_norm, cond_downsample) in zip(self.downs, self.cond_downs):
330+
for (block1, block2, attn, downsample), (cond_block1, cond_block2, cond_attn, cond_downsample), conditioner in zip(self.downs, self.cond_downs, self.conditioners):
314331
x = block1(x, t)
315332
c = cond_block1(c, t)
316333

@@ -322,22 +339,10 @@ def forward(
322339
x = attn(x)
323340
c = cond_attn(c)
324341

325-
# they create an attentive map A by element-wise multiplication of the normalized input features, the normalized condition features, and the condition itself
326-
# then they use it to modulate the condition in fourier space (FF-Parser)
327-
# eq. 3 in the paper
328-
329-
A = norm(x) * cond_norm(c) * c
330-
331-
# fc stands for conditioning in fourier space
332-
333-
fc = fft2(c)
334-
335-
fc = fc * A # eq. 5 in paper
336-
337-
c = ifft2(fc).real
338-
c = c.type(dtype)
342+
# condition using modulation of fourier frequencies with attentive map
343+
# you can test your own conditioners by passing in a different conditioning_klass, if you believe you can best the paper
339344

340-
# </conditioning>
345+
c = conditioner(x, c)
341346

342347
h.append(x)
343348

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'med-seg-diff-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.2',
6+
version = '0.0.3',
77
license='MIT',
88
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)