Skip to content

Commit ddd6702

Browse files
committed
attend to the conditioning image as well for dynamic ff parser attn map
1 parent 11dc571 commit ddd6702

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -307,8 +307,9 @@ def __init__(
307307
if dynamic:
308308
self.to_dynamic_ff_parser_attn_map = ViT(
309309
dim = dim,
310-
channels = dim * 2,
310+
channels = dim * 2 * 2, # both input and condition, and account for complex (real and imag components)
311311
channels_out = dim,
312+
image_size = image_size,
312313
patch_size = patch_size,
313314
heads = heads,
314315
dim_head = dim_head
@@ -328,9 +329,14 @@ def forward(self, x, c):
328329
x = fft2(x)
329330

330331
if self.dynamic:
331-
x_real = torch.view_as_real(x)
332-
x_real = rearrange(x_real, 'b d h w ri -> b (d ri) h w')
333-
dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(x_real)
332+
c_complex = fft2(c)
333+
x_as_real, c_as_real = map(torch.view_as_real, (x, c_complex))
334+
x_as_real, c_as_real = map(lambda t: rearrange(t, 'b d h w ri -> b (d ri) h w'), (x_as_real, c_as_real))
335+
336+
to_dynamic_input = torch.cat((x_as_real, c_as_real), dim = 1)
337+
338+
dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(to_dynamic_input)
339+
334340
ff_parser_attn_map = ff_parser_attn_map + dynamic_ff_parser_attn_map
335341

336342
x = x * ff_parser_attn_map
@@ -423,7 +429,6 @@ def __init__(
423429
if conditioning_klass == Conditioning:
424430
conditioning_klass = partial(
425431
Conditioning,
426-
image_size = image_size,
427432
dynamic = dynamic_ff_parser_attn_map,
428433
**conditioning_kwargs
429434
)
@@ -447,8 +452,7 @@ def __init__(
447452
is_last = ind >= (num_resolutions - 1)
448453
attn_klass = Attention if full_attn else LinearAttention
449454

450-
self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))
451-
455+
self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in, image_size = curr_fmap_size))
452456

453457
self.downs.append(ModuleList([
454458
block_klass(dim_in, dim_in, time_emb_dim = time_dim),

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'med-seg-diff-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.3.1',
6+
version = '0.3.2',
77
license='MIT',
88
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)