@@ -9,7 +9,7 @@
 import torch.nn.functional as F
 from torch.fft import fft2, ifft2

-from einops import rearrange, reduce
+from einops import rearrange, reduce, pack
 from einops.layers.torch import Rearrange

 from tqdm.auto import tqdm
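The newly imported pack is not used in the hunks shown here, so its role elsewhere in the file is an assumption. As a reference, a minimal sketch of what einops.pack does, with illustrative shapes that are not taken from the commit:

    import torch
    from einops import pack, unpack

    x = torch.randn(2, 3, 32, 32)
    c = torch.randn(2, 5, 32, 32)

    # pack concatenates along the axis marked '*', remembering each part's shape
    packed, ps = pack([x, c], 'b * h w')
    print(packed.shape)        # torch.Size([2, 8, 32, 32])

    # unpack reverses the operation using the remembered shapes
    x2, c2 = unpack(packed, ps, 'b * h w')
    print(x2.shape, c2.shape)  # torch.Size([2, 3, 32, 32]) torch.Size([2, 5, 32, 32])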
@@ -235,7 +235,8 @@ def __init__(
         channels = 3,
         self_condition = False,
         resnet_block_groups = 8,
-        conditioning_klass = Conditioning
+        conditioning_klass = Conditioning,
+        skip_connect_condition_fmaps = False   # whether to concatenate the conditioning fmaps into the skip connections used by the decoder (upsampling) half of the unet
     ):
         super().__init__()

@@ -274,6 +275,8 @@ def __init__(

         self.conditioners = nn.ModuleList([])

+        self.skip_connect_condition_fmaps = skip_connect_condition_fmaps
+
         # downsampling encoding blocks

         self.downs = nn.ModuleList([])
@@ -314,9 +317,11 @@ def __init__(
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
             is_last = ind == (len(in_out) - 1)

+            skip_connect_dim = dim_in * (2 if self.skip_connect_condition_fmaps else 1)
+
             self.ups.append(nn.ModuleList([
-                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
-                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
+                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
+                block_klass(dim_out + skip_connect_dim, dim_out, time_emb_dim = time_dim),
                 Residual(LinearAttention(dim_out)),
                 Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
             ]))
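The channel arithmetic above can be checked in isolation: with skip_connect_condition_fmaps enabled, each stored skip connection carries both the image features x and the conditioning features c, each with dim_in channels at that resolution, so every decoder block's input width grows by dim_in twice over. A standalone sketch, where the in_out pairs are assumed values, not taken from the commit:

    # illustrative decoder channel widths for an assumed dims = (64, 128, 256, 512)
    in_out = [(64, 128), (128, 256), (256, 512)]

    for skip_connect_condition_fmaps in (False, True):
        print('skip_connect_condition_fmaps =', skip_connect_condition_fmaps)
        for dim_in, dim_out in reversed(in_out):
            # each popped skip connection holds [x] or [x, c], all with dim_in channels
            skip_connect_dim = dim_in * (2 if skip_connect_condition_fmaps else 1)
            print(f'  block in: {dim_out + skip_connect_dim} -> out: {dim_out}')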
@@ -333,7 +338,7 @@ def forward(
         cond,
         x_self_cond = None
     ):
-        dtype = x.dtype
+        dtype, skip_connect_c = x.dtype, self.skip_connect_condition_fmaps

         if self.self_condition:
             x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
@@ -352,7 +357,7 @@ def forward(
             x = block1(x, t)
             c = cond_block1(c, t)

-            h.append(x)
+            h.append([x, c] if skip_connect_c else [x])

             x = block2(x, t)
             c = cond_block2(c, t)
@@ -365,7 +370,7 @@ def forward(

             c = conditioner(x, c)

-            h.append(x)
+            h.append([x, c] if skip_connect_c else [x])

             x = downsample(x)
             c = cond_downsample(c)
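Note that even with the flag disabled, the feature map is stored as a one-element list, which keeps the decoder branch-free: whatever h.pop() returns can be star-unpacked uniformly. A minimal sketch of the pattern, using placeholder values instead of tensors:

    h = []

    for skip_connect_c in (False, True):
        x, c = 'x', 'c'   # placeholders standing in for feature map tensors
        h.append([x, c] if skip_connect_c else [x])

    # the consumer never needs to know which case it got
    while h:
        print(*h.pop())   # prints 'x c', then 'x'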
@@ -379,10 +384,10 @@ def forward(
         x = self.mid_block2(x, t)

         for block1, block2, attn, upsample in self.ups:
-            x = torch.cat((x, h.pop()), dim = 1)
+            x = torch.cat((x, *h.pop()), dim = 1)
             x = block1(x, t)

-            x = torch.cat((x, h.pop()), dim = 1)
+            x = torch.cat((x, *h.pop()), dim = 1)
             x = block2(x, t)
             x = attn(x)

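Putting the two halves together: in the decoder, h.pop() yields one or two tensors of matching spatial size, and the star-unpacking splices all of them into the channel-wise concat. A shape-level sketch, with assumed channel counts (dim_out = 128, dim_in = 64) that are illustrative only:

    import torch

    x = torch.randn(1, 128, 16, 16)                                  # features coming up from below
    h = [[torch.randn(1, 64, 16, 16), torch.randn(1, 64, 16, 16)]]   # [x, c] stashed by the encoder

    x = torch.cat((x, *h.pop()), dim = 1)
    print(x.shape)   # torch.Size([1, 256, 16, 16]), i.e. dim_out + skip_connect_dim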