
import torch
from torch import nn, einsum
+ from torch.nn import Module, ModuleList
import torch.nn.functional as F
from torch.fft import fft2, ifft2

@@ -43,7 +44,7 @@ def unnormalize_to_zero_to_one(t):

# small helper modules

- class Residual(nn.Module):
+ class Residual(Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
@@ -63,7 +64,7 @@ def Downsample(dim, dim_out = None):
        nn.Conv2d(dim * 4, default(dim_out, dim), 1)
    )

- class LayerNorm(nn.Module):
+ class LayerNorm(Module):
    def __init__(self, dim, bias = False):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
@@ -75,7 +76,7 @@ def forward(self, x):
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * (var + eps).rsqrt() * self.g + default(self.b, 0)

- class SinusoidalPosEmb(nn.Module):
+ class SinusoidalPosEmb(Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
@@ -91,7 +92,7 @@ def forward(self, x):

# building block modules

- class Block(nn.Module):
+ class Block(Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
@@ -109,7 +110,7 @@ def forward(self, x, scale_shift = None):
        x = self.act(x)
        return x

- class ResnetBlock(nn.Module):
+ class ResnetBlock(Module):
    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
        super().__init__()
        self.mlp = nn.Sequential(
@@ -144,7 +145,7 @@ def FeedForward(dim, mult = 4):
        nn.Conv2d(inner_dim, dim, 1),
    )

- class LinearAttention(nn.Module):
+ class LinearAttention(Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
@@ -178,7 +179,7 @@ def forward(self, x):
        out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w)
        return self.to_out(out)

- class Attention(nn.Module):
+ class Attention(Module):
    def __init__(self, dim, heads = 4, dim_head = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
@@ -206,7 +207,7 @@ def forward(self, x):
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
        return self.to_out(out)

- class Transformer(nn.Module):
+ class Transformer(Module):
    def __init__(
        self,
        dim,
@@ -215,9 +216,9 @@ def __init__(
        depth = 1
    ):
        super().__init__()
-         self.layers = nn.ModuleList([])
+         self.layers = ModuleList([])
        for _ in range(depth):
-             self.layers.append(nn.ModuleList([
+             self.layers.append(ModuleList([
                Residual(Attention(dim, dim_head = dim_head, heads = heads)),
                Residual(FeedForward(dim))
            ]))
@@ -228,25 +229,101 @@ def forward(self, x):
            x = ff(x)
        return x

+ # vision transformer for dynamic ff-parser
+
+ class ViT(Module):
+     def __init__(
+         self,
+         dim,
+         *,
+         patch_size,
+         channels = 3,
+         channels_out = None,
+         dim_head = 32,
+         heads = 4,
+         depth = 4,
+     ):
+         super().__init__()
+         channels_out = default(channels_out, channels)
+
+         patch_dim = channels * (patch_size ** 2)
+         output_patch_dim = channels_out * (patch_size ** 2)
+
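+         # fold non-overlapping patches into the channel dimension, then project to the transformer width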
+         self.to_tokens = nn.Sequential(
+             Rearrange('b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = patch_size, p2 = patch_size),
+             nn.Conv2d(patch_dim, dim, 1),
+             LayerNorm(dim)
+         )
+
+         self.transformer = Transformer(
+             dim = dim,
+             dim_head = dim_head,
+             depth = depth
+         )
+
+         self.to_patches = nn.Sequential(
+             LayerNorm(dim),
+             nn.Conv2d(dim, output_patch_dim, 1),
+             Rearrange('b (c p1 p2) h w -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size),
+         )
+
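+         # zero-init the output projection so the predicted dynamic attention map starts out as a zero contribution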
+         nn.init.zeros_(self.to_patches[-2].weight)
+         nn.init.zeros_(self.to_patches[-2].bias)
+
+     def forward(self, x):
+         x = self.to_tokens(x)
+         x = self.transformer(x)
+         return self.to_patches(x)
+
# conditioning class

- class Conditioning(nn.Module):
-     def __init__(self, fmap_size, dim):
+ class Conditioning(Module):
+     def __init__(
+         self,
+         fmap_size,
+         dim,
+         dynamic = True,
+         dim_head = 32,
+         heads = 4,
+         depth = 4,
+         patch_size = 16
+     ):
        super().__init__()
        self.ff_parser_attn_map = nn.Parameter(torch.ones(dim, fmap_size, fmap_size))

+         self.dynamic = dynamic
+
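+         # optionally predict an input-dependent correction to the ff-parser attention map with a small ViT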
+         if dynamic:
+             self.to_dynamic_ff_parser_attn_map = ViT(
+                 dim = dim,
+                 channels = dim * 2,
+                 channels_out = dim,
+                 patch_size = patch_size,
+                 heads = heads,
+                 dim_head = dim_head
+             )
+
        self.norm_input = LayerNorm(dim, bias = True)
        self.norm_condition = LayerNorm(dim, bias = True)

        self.block = ResnetBlock(dim, dim)

    def forward(self, x, c):
+         ff_parser_attn_map = self.ff_parser_attn_map

        # ff-parser in the paper, for modulating out the high frequencies

        dtype = x.dtype
        x = fft2(x)
-         x = x * self.ff_parser_attn_map
+
+         if self.dynamic:
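+             # fold the complex fourier coefficients into real channels (matching channels = dim * 2 above) before feeding the ViT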
+             x_real = torch.view_as_real(x)
+             x_real = rearrange(x_real, 'b d h w ri -> b (d ri) h w')
+             dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(x_real)
+             ff_parser_attn_map = ff_parser_attn_map + dynamic_ff_parser_attn_map
+
+         x = x * ff_parser_attn_map
+
        x = ifft2(x).real
        x = x.type(dtype)

@@ -264,7 +341,7 @@ def forward(self, x, c):
# model

@beartype
- class Unet(nn.Module):
+ class Unet(Module):
    def __init__(
        self,
        dim,
@@ -281,7 +358,14 @@ def __init__(
        self_condition = False,
        resnet_block_groups = 8,
        conditioning_klass = Conditioning,
-         skip_connect_condition_fmaps = False # whether to concatenate the conditioning fmaps in the latter decoder upsampling portion of unet
+         skip_connect_condition_fmaps = False, # whether to concatenate the conditioning fmaps in the latter decoder upsampling portion of unet
+         dynamic_ff_parser_attn_map = False, # allow for ff-parser to be dynamic based on the input. will exclude condition for now
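+         # hyperparameters forwarded to the default Conditioning module (ViT attention settings and patch size)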
+         conditioning_kwargs: dict = dict(
+             dim_head = 32,
+             heads = 4,
+             depth = 4,
+             patch_size = 16
+         )
    ):
        super().__init__()

@@ -323,18 +407,27 @@ def __init__(
            heads = attn_heads
        )

+         # conditioner settings
+
+         if conditioning_klass == Conditioning:
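+             # bake the dynamic flag and the ViT hyperparameters into the default conditioner via partial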
+             conditioning_klass = partial(
+                 Conditioning,
+                 dynamic = dynamic_ff_parser_attn_map,
+                 **conditioning_kwargs
+             )
+
        # layers

        num_resolutions = len(in_out)
        assert len(full_self_attn) == num_resolutions

-         self.conditioners = nn.ModuleList([])
+         self.conditioners = ModuleList([])

        self.skip_connect_condition_fmaps = skip_connect_condition_fmaps

        # downsampling encoding blocks

-         self.downs = nn.ModuleList([])
+         self.downs = ModuleList([])

        curr_fmap_size = image_size

@@ -345,7 +438,7 @@ def __init__(
            self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))


-             self.downs.append(nn.ModuleList([
+             self.downs.append(ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                Residual(attn_klass(dim_in, **attn_kwargs)),
@@ -369,15 +462,15 @@ def __init__(

        # upsampling decoding blocks

-         self.ups = nn.ModuleList([])
+         self.ups = ModuleList([])

        for ind, ((dim_in, dim_out), full_attn) in enumerate(zip(reversed(in_out), reversed(full_self_attn))):
            is_last = ind == (len(in_out) - 1)
            attn_klass = Attention if full_attn else LinearAttention

            skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)

-             self.ups.append(nn.ModuleList([
+             self.ups.append(ModuleList([
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                Residual(attn_klass(dim_out, **attn_kwargs)),
@@ -481,7 +574,7 @@ def cosine_beta_schedule(timesteps, s = 0.008):
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

- class MedSegDiff(nn.Module):
+ class MedSegDiff(Module):
    def __init__(
        self,
        model,
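Below is a minimal usage sketch of the dynamic ff-parser added in this commit. It assumes the module definitions above are importable (for example from med_seg_diff_pytorch.med_seg_diff_pytorch), and that the remainder of Conditioning.forward, not shown in this hunk, preserves the spatial shape as in the existing code; the sizes are arbitrary. In the Unet, the same behaviour is switched on by passing dynamic_ff_parser_attn_map = True.

import torch

# with dynamic = True, Conditioning runs the fourier coefficients of the incoming
# feature map through the ViT defined above and adds the prediction to the learned
# ff_parser_attn_map before modulating the frequencies
cond = Conditioning(fmap_size = 64, dim = 32, dynamic = True, patch_size = 16)

x = torch.randn(1, 32, 64, 64) # feature map from the noised-mask branch
c = torch.randn(1, 32, 64, 64) # feature map from the conditioning image branch

out = cond(x, c) # expected shape (1, 32, 64, 64)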