@@ -307,8 +307,9 @@ def __init__(
         if dynamic:
             self.to_dynamic_ff_parser_attn_map = ViT(
                 dim = dim,
-                channels = dim * 2,
+                channels = dim * 2 * 2, # both input and condition, and account for complex (real and imag components)
                 channels_out = dim,
+                image_size = image_size,
                 patch_size = patch_size,
                 heads = heads,
                 dim_head = dim_head
@@ -328,9 +329,14 @@ def forward(self, x, c):
         x = fft2(x)
 
         if self.dynamic:
-            x_real = torch.view_as_real(x)
-            x_real = rearrange(x_real, 'b d h w ri -> b (d ri) h w')
-            dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(x_real)
+            c_complex = fft2(c)
+            x_as_real, c_as_real = map(torch.view_as_real, (x, c_complex))
+            x_as_real, c_as_real = map(lambda t: rearrange(t, 'b d h w ri -> b (d ri) h w'), (x_as_real, c_as_real))
+
+            to_dynamic_input = torch.cat((x_as_real, c_as_real), dim = 1)
+
+            dynamic_ff_parser_attn_map = self.to_dynamic_ff_parser_attn_map(to_dynamic_input)
+
             ff_parser_attn_map = ff_parser_attn_map + dynamic_ff_parser_attn_map
 
         x = x * ff_parser_attn_map
@@ -423,7 +429,6 @@ def __init__(
         if conditioning_klass == Conditioning:
             conditioning_klass = partial(
                 Conditioning,
-                image_size = image_size,
                 dynamic = dynamic_ff_parser_attn_map,
                 **conditioning_kwargs
             )
@@ -447,8 +452,7 @@ def __init__(
             is_last = ind >= (num_resolutions - 1)
             attn_klass = Attention if full_attn else LinearAttention
 
-            self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in))
-
+            self.conditioners.append(conditioning_klass(curr_fmap_size, dim_in, image_size = curr_fmap_size))
 
             self.downs.append(ModuleList([
                 block_klass(dim_in, dim_in, time_emb_dim = time_dim),
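
For context, here is a minimal, standalone shape check (not part of the commit; the names and sizes are hypothetical, and it assumes the fft2 in the diff is torch.fft.fft2 over the spatial dims) showing why the dynamic ViT now expects dim * 2 * 2 input channels: view_as_real doubles the channel count (real and imaginary parts), and concatenating the input with the condition doubles it again.

    import torch
    from einops import rearrange

    dim, h, w = 8, 16, 16          # hypothetical feature map: dim channels, h x w spatial
    x = torch.randn(1, dim, h, w)  # input feature map
    c = torch.randn(1, dim, h, w)  # conditioning feature map

    # 2d FFT over the spatial dims yields complex-valued tensors
    x_f, c_f = map(torch.fft.fft2, (x, c))

    # view_as_real appends a trailing (real, imag) axis of size 2,
    # which rearrange folds into the channel dim -> dim * 2 channels each
    x_r, c_r = map(torch.view_as_real, (x_f, c_f))
    x_r, c_r = map(lambda t: rearrange(t, 'b d h w ri -> b (d ri) h w'), (x_r, c_r))

    # concatenating input and condition doubles channels again -> dim * 2 * 2
    vit_input = torch.cat((x_r, c_r), dim = 1)
    assert vit_input.shape[1] == dim * 2 * 2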