Skip to content

Commit 453b495

Browse files
committed
fix all misunderstandings about the paper, hopefully
1 parent 1e7e9f5 commit 453b495

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@ from med_seg_diff_pytorch import Unet, MedSegDiff
1818

1919
model = Unet(
2020
dim = 64,
21+
image_size = 128,
2122
dim_mults = (1, 2, 4, 8)
2223
)
2324

2425
diffusion = MedSegDiff(
2526
model,
26-
image_size = 128,
2727
timesteps = 1000
2828
).cuda()
2929

med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,33 +198,49 @@ def forward(self, x):
198198

199199
# conditioning class
200200

201-
class FourierConditioning(nn.Module):
202-
def __init__(self, dim):
201+
class Conditioning(nn.Module):
202+
def __init__(self, fmap_size, dim):
203203
super().__init__()
204+
self.ff_parser_attn_map = nn.Parameter(torch.ones(dim, fmap_size, fmap_size))
205+
204206
self.norm_input = LayerNorm(dim, bias = True)
205207
self.norm_condition = LayerNorm(dim, bias = True)
206208

207209
def forward(self, x, c):
210+
211+
# FF-Parser from the paper: reweights the feature map in the frequency domain to filter out high-frequency components
212+
213+
dtype = x.dtype
214+
x = fft2(x)
215+
x = x * self.ff_parser_attn_map
216+
x = ifft2(x).real
217+
x = x.type(dtype)
218+
219+
# eq 3 in paper
220+
208221
normed_x = self.norm_input(x)
209222
normed_c = self.norm_condition(c)
210-
return (normed_x * normed_c) * c # eq 3 in paper
223+
return (normed_x * normed_c) * c
211224

212225
# model
213226

214227
class Unet(nn.Module):
215228
def __init__(
216229
self,
217230
dim,
231+
image_size,
218232
init_dim = None,
219233
out_dim = None,
220234
dim_mults=(1, 2, 4, 8),
221235
channels = 3,
222236
self_condition = False,
223237
resnet_block_groups = 8,
224-
conditioning_klass = FourierConditioning
238+
conditioning_klass = Conditioning
225239
):
226240
super().__init__()
227241

242+
self.image_size = image_size
243+
228244
# determine dimensions
229245

230246
self.channels = channels
@@ -262,10 +278,12 @@ def __init__(
262278

263279
self.downs = nn.ModuleList([])
264280

281+
curr_fmap_size = image_size
282+
265283
for ind, (dim_in, dim_out) in enumerate(in_out):
266284
is_last = ind >= (num_resolutions - 1)
267285

268-
self.conditioners.append(conditioning_klass(dim_in))
286+
self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))
269287

270288
self.downs.append(nn.ModuleList([
271289
block_klass(dim_in, dim_in, time_emb_dim = time_dim),
@@ -274,6 +292,9 @@ def __init__(
274292
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
275293
]))
276294

295+
if not is_last:
296+
curr_fmap_size //= 2
297+
277298
# middle blocks
278299

279300
mid_dim = dims[-1]
@@ -402,7 +423,6 @@ def __init__(
402423
self,
403424
model,
404425
*,
405-
image_size,
406426
timesteps = 1000,
407427
sampling_timesteps = None,
408428
objective = 'pred_noise',
@@ -415,7 +435,7 @@ def __init__(
415435
self.channels = self.model.channels
416436
self.self_condition = self.model.self_condition
417437

418-
self.image_size = image_size
438+
self.image_size = model.image_size
419439

420440
self.objective = objective
421441

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'med-seg-diff-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.3',
6+
version = '0.0.4',
77
license='MIT',
88
description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)