Skip to content

Commit 0707009

Browse files
committed
also include feature maps from conditioning encoder in the decoder, if skip_connect_condition_fmaps is set to True
1 parent 453b495 commit 0707009

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch.nn.functional as F
1010
from torch.fft import fft2, ifft2
1111

12-
from einops import rearrange, reduce
12+
from einops import rearrange, reduce, pack
1313
from einops.layers.torch import Rearrange
1414

1515
from tqdm.auto import tqdm
@@ -235,7 +235,8 @@ def __init__(
235235
channels = 3,
236236
self_condition = False,
237237
resnet_block_groups = 8,
238-
conditioning_klass = Conditioning
238+
conditioning_klass = Conditioning,
239+
skip_connect_condition_fmaps = False # whether to concatenate the conditioning fmaps in the latter decoder upsampling portion of unet
239240
):
240241
super().__init__()
241242

@@ -274,6 +275,8 @@ def __init__(
274275

275276
self.conditioners = nn.ModuleList([])
276277

278+
self.skip_connect_condition_fmaps = skip_connect_condition_fmaps
279+
277280
# downsampling encoding blocks
278281

279282
self.downs = nn.ModuleList([])
@@ -314,9 +317,11 @@ def __init__(
314317
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
315318
is_last = ind == (len(in_out) - 1)
316319

320+
skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)
321+
317322
self.ups.append(nn.ModuleList([
318-
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
319-
block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
323+
block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
324+
block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
320325
Residual(LinearAttention(dim_out)),
321326
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
322327
]))
@@ -333,7 +338,7 @@ def forward(
333338
cond,
334339
x_self_cond = None
335340
):
336-
dtype = x.dtype
341+
dtype, skip_connect_c = x.dtype, self.skip_connect_condition_fmaps
337342

338343
if self.self_condition:
339344
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
@@ -352,7 +357,7 @@ def forward(
352357
x = block1(x, t)
353358
c = cond_block1(c, t)
354359

355-
h.append(x)
360+
h.append([x, c] if skip_connect_c else [x])
356361

357362
x = block2(x, t)
358363
c = cond_block2(c, t)
@@ -365,7 +370,7 @@ def forward(
365370

366371
c = conditioner(x, c)
367372

368-
h.append(x)
373+
h.append([x, c] if skip_connect_c else [x])
369374

370375
x = downsample(x)
371376
c = cond_downsample(c)
@@ -379,10 +384,10 @@ def forward(
379384
x = self.mid_block2(x, t)
380385

381386
for block1, block2, attn, upsample in self.ups:
382-
x = torch.cat((x, h.pop()), dim = 1)
387+
x = torch.cat((x, *h.pop()), dim = 1)
383388
x = block1(x, t)
384389

385-
x = torch.cat((x, h.pop()), dim = 1)
390+
x = torch.cat((x, *h.pop()), dim = 1)
386391
x = block2(x, t)
387392
x = attn(x)
388393

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'med-seg-diff-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.4',
6+
version = '0.0.5',
77
license='MIT',
88
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)