Commit f9447b0

add ability to modulate the ff parser attn map based on the input using a vision transformer
1 parent 512eb49 commit f9447b0
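
In short, the FF-Parser's frequency-domain attention map is no longer a purely learned, input-independent parameter: when the new option is enabled, a small ViT reads the real-viewed Fourier spectrum of the incoming feature map and predicts a per-input residual that is added to the learned base map before the multiplicative filtering; the ViT's output projection is zero-initialized, so the dynamic term starts out as a no-op. A minimal standalone sketch of that idea in Python (not the repository's exact code; the ff_parse, base_map and to_dynamic_map names are illustrative):

import torch
from torch.fft import fft2, ifft2
from einops import rearrange

def ff_parse(x, base_map, to_dynamic_map = None):
    # x: (batch, dim, height, width) feature map
    # base_map: learned attention map of shape (dim, height, width)
    # to_dynamic_map: optional module (e.g. a small ViT) mapping the real-viewed
    #                 spectrum (batch, dim * 2, h, w) to a residual map (batch, dim, h, w)
    dtype = x.dtype
    freq = fft2(x)                                   # complex spectrum

    attn_map = base_map

    if to_dynamic_map is not None:
        # split complex values into two real channels and predict an
        # input-dependent residual for the attention map
        freq_real = torch.view_as_real(freq)
        freq_real = rearrange(freq_real, 'b d h w ri -> b (d ri) h w')
        attn_map = attn_map + to_dynamic_map(freq_real)

    x = ifft2(freq * attn_map).real                  # modulate the spectrum, return to spatial domain
    return x.type(dtype)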

File tree

2 files changed (+115, -22 lines)


med_seg_diff_pytorch/med_seg_diff_pytorch.py

Lines changed: 114 additions & 21 deletions
@@ -8,6 +8,7 @@

 import torch
 from torch import nn, einsum
+from torch.nn import Module, ModuleList
 import torch.nn.functional as F
 from torch.fft import fft2, ifft2

@@ -43,7 +44,7 @@ def unnormalize_to_zero_to_one(t):

 # small helper modules

-class Residual(nn.Module):
+class Residual(Module):
     def __init__(self, fn):
         super().__init__()
         self.fn = fn
@@ -63,7 +64,7 @@ def Downsample(dim, dim_out = None):
         nn.Conv2d(dim * 4, default(dim_out, dim), 1)
     )

-class LayerNorm(nn.Module):
+class LayerNorm(Module):
     def __init__(self, dim, bias = False):
         super().__init__()
         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
@@ -75,7 +76,7 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) * (var + eps).rsqrt() * self.g + default(self.b, 0)

-class SinusoidalPosEmb(nn.Module):
+class SinusoidalPosEmb(Module):
     def __init__(self, dim):
         super().__init__()
         self.dim = dim
@@ -91,7 +92,7 @@ def forward(self, x):

 # building block modules

-class Block(nn.Module):
+class Block(Module):
     def __init__(self, dim, dim_out, groups = 8):
         super().__init__()
         self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
@@ -109,7 +110,7 @@ def forward(self, x, scale_shift = None):
         x = self.act(x)
         return x

-class ResnetBlock(nn.Module):
+class ResnetBlock(Module):
     def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
         super().__init__()
         self.mlp = nn.Sequential(
@@ -144,7 +145,7 @@ def FeedForward(dim, mult = 4):
         nn.Conv2d(inner_dim, dim, 1),
     )

-class LinearAttention(nn.Module):
+class LinearAttention(Module):
     def __init__(self, dim, heads = 4, dim_head = 32):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -178,7 +179,7 @@ def forward(self, x):
         out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
         return self.to_out(out)

-class Attention(nn.Module):
+class Attention(Module):
     def __init__(self, dim, heads = 4, dim_head = 32):
         super().__init__()
         self.scale = dim_head ** -0.5
@@ -206,7 +207,7 @@ def forward(self, x):
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
         return self.to_out(out)

-class Transformer(nn.Module):
+class Transformer(Module):
     def __init__(
         self,
         dim,
@@ -215,9 +216,9 @@ def __init__(
         depth = 1
     ):
         super().__init__()
-        self.layers = nn.ModuleList([])
+        self.layers = ModuleList([])
         for _ in range(depth):
-            self.layers.append(nn.ModuleList([
+            self.layers.append(ModuleList([
                 Residual(Attention(dim, dim_head = dim_head, heads = heads)),
                 Residual(FeedForward(dim))
             ]))
@@ -228,25 +229,101 @@ def forward(self, x):
             x = ff(x)
         return x

+# vision transformer for dynamic ff-parser
+
+class ViT(Module):
+    def __init__(
+        self,
+        dim,
+        *,
+        patch_size,
+        channels = 3,
+        channels_out = None,
+        dim_head = 32,
+        heads = 4,
+        depth = 4,
+    ):
+        super().__init__()
+        channels_out = default(channels_out, channels)
+
+        patch_dim = channels * (patch_size ** 2)
+        output_patch_dim = channels_out * (patch_size ** 2)
+
+        self.to_tokens = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = patch_size, p2 = patch_size),
+            nn.Conv2d(patch_dim, dim, 1),
+            LayerNorm(dim)
+        )
+
+        self.transformer = Transformer(
+            dim = dim,
+            dim_head = dim_head,
+            depth = depth
+        )
+
+        self.to_patches = nn.Sequential(
+            LayerNorm(dim),
+            nn.Conv2d(dim, output_patch_dim, 1),
+            Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
+        )
+
+        nn.init.zeros_(self.to_patches[-2].weight)
+        nn.init.zeros_(self.to_patches[-2].bias)
+
+    def forward(self, x):
+        x = self.to_tokens(x)
+        x = self.transformer(x)
+        return self.to_patches(x)
+
 # conditioning class

-class Conditioning(nn.Module):
-    def __init__(self, fmap_size, dim):
+class Conditioning(Module):
+    def __init__(
+        self,
+        fmap_size,
+        dim,
+        dynamic = True,
+        dim_head = 32,
+        heads = 4,
+        depth = 4,
+        patch_size = 16
+    ):
         super().__init__()
         self.ff_parser_attn_map = nn.Parameter(torch.ones(dim, fmap_size, fmap_size))

+        self.dynamic = dynamic
+
+        if dynamic:
+            self.to_dynamic_ff_parser_attn_map = ViT(
+                dim = dim,
+                channels = dim * 2,
+                channels_out = dim,
+                patch_size = patch_size,
+                heads = heads,
+                dim_head = dim_head
+            )
+
         self.norm_input = LayerNorm(dim, bias = True)
         self.norm_condition = LayerNorm(dim, bias = True)

         self.block = ResnetBlock(dim, dim)

     def forward(self, x, c):
+        ff_parser_attn_map = self.ff_parser_attn_map

         # ff-parser in the paper, for modulating out the high frequencies

         dtype = x.dtype
         x = fft2(x)
-        x = x * self.ff_parser_attn_map
+
+        if self.dynamic:
+            x_real = torch.view_as_real(x)
+            x_real = rearrange(x_real, 'b d h w ri -> b (d ri) h w')
+            dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(x_real)
+            ff_parser_attn_map = ff_parser_attn_map + dynamic_ff_parser_attn_map
+
+        x = x * ff_parser_attn_map
+
         x = ifft2(x).real
         x = x.type(dtype)

@@ -264,7 +341,7 @@ def forward(self, x, c):
 # model

 @beartype
-class Unet(nn.Module):
+class Unet(Module):
     def __init__(
         self,
         dim,
@@ -281,7 +358,14 @@ def __init__(
         self_condition = False,
         resnet_block_groups = 8,
         conditioning_klass = Conditioning,
-        skip_connect_condition_fmaps = False # whether to concatenate the conditioning fmaps in the latter decoder upsampling portion of unet
+        skip_connect_condition_fmaps = False, # whether to concatenate the conditioning fmaps in the latter decoder upsampling portion of unet
+        dynamic_ff_parser_attn_map = False,   # allow for ff-parser to be dynamic based on the input. will exclude condition for now
+        conditioning_kwargs: dict = dict(
+            dim_head = 32,
+            heads = 4,
+            depth = 4,
+            patch_size = 16
+        )
     ):
         super().__init__()

@@ -323,18 +407,27 @@ def __init__(
             heads = attn_heads
         )

+        # conditioner settings
+
+        if conditioning_klass == Conditioning:
+            conditioning_klass = partial(
+                Conditioning,
+                dynamic = dynamic_ff_parser_attn_map,
+                **conditioning_kwargs
+            )
+
         # layers

         num_resolutions = len(in_out)
         assert len(full_self_attn) == num_resolutions

-        self.conditioners = nn.ModuleList([])
+        self.conditioners = ModuleList([])

         self.skip_connect_condition_fmaps = skip_connect_condition_fmaps

         # downsampling encoding blocks

-        self.downs = nn.ModuleList([])
+        self.downs = ModuleList([])

         curr_fmap_size = image_size

@@ -345,7 +438,7 @@ def __init__(
             self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))


-            self.downs.append(nn.ModuleList([
+            self.downs.append(ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                 Residual(attn_klass(dim_in, **attn_kwargs)),
@@ -369,15 +462,15 @@ def __init__(

         # upsampling decoding blocks

-        self.ups = nn.ModuleList([])
+        self.ups = ModuleList([])

         for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(reversed(in_out), reversed(full_self_attn))):
             is_last = ind == (len(in_out) - 1)
             attn_klass = Attention if full_attn else LinearAttention

             skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)

-            self.ups.append(nn.ModuleList([
+            self.ups.append(ModuleList([
                 block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                 block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                 Residual(attn_klass(dim_out, **attn_kwargs)),
@@ -481,7 +574,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)

-class MedSegDiff(nn.Module):
+class MedSegDiff(Module):
     def __init__(
         self,
         model,

setup.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'med-seg-diff-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.2.6',
+  version = '0.3.0',
   license='MIT',
   description = 'MedSegDiff - SOTA medical image segmentation - Pytorch',
   author = 'Phil Wang',
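
With the version bump to 0.3.0, the new behavior is opt-in on the Unet constructor: pass dynamic_ff_parser_attn_map = True and, optionally, tune each conditioner's ViT through conditioning_kwargs. A usage sketch; the two new arguments come from this commit, while the surrounding arguments and values follow the repository's README-style example and should be treated as assumptions:

import torch
from med_seg_diff_pytorch import Unet, MedSegDiff

model = Unet(
    dim = 64,
    image_size = 128,
    mask_channels = 1,                   # segmentation mask channels (assumed, per README)
    input_img_channels = 3,              # conditioning image channels (assumed, per README)
    dim_mults = (1, 2, 4, 8),
    dynamic_ff_parser_attn_map = True,   # new in this commit: input-dependent ff-parser
    conditioning_kwargs = dict(          # forwarded to each Conditioning block's ViT
        dim_head = 32,
        heads = 4,
        depth = 4,
        patch_size = 16
    )
)

diffusion = MedSegDiff(model, timesteps = 1000)

segmented_imgs = torch.rand(8, 1, 128, 128)   # inputs normalized to [0, 1]
input_imgs = torch.rand(8, 3, 128, 128)

loss = diffusion(segmented_imgs, input_imgs)
loss.backward()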
