@@ -381,6 +381,20 @@ def forward(
381381                If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise 
382382                a `tuple` is returned where the first element is the sample tensor. 
383383        """ 
384+         # By default samples have to be at least a multiple of the overall upsampling factor. 
385+         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 
386+         # However, the upsampling interpolation output size can be forced to fit any upsampling size 
387+         # on the fly if necessary. 
388+         default_overall_up_factor  =  2 ** self .num_upsamplers 
389+ 
390+         # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 
391+         forward_upsample_size  =  False 
392+         upsample_size  =  None 
393+ 
394+         if  any (s  %  default_overall_up_factor  !=  0  for  s  in  sample .shape [- 2 :]):
395+             logger .info ("Forward upsample size to force interpolation output size." )
396+             forward_upsample_size  =  True 
397+ 
384398        # 1. time 
385399        timesteps  =  timestep 
386400        if  not  torch .is_tensor (timesteps ):
@@ -456,22 +470,31 @@ def forward(
456470
457471        # 5. up 
458472        for  i , upsample_block  in  enumerate (self .up_blocks ):
473+             is_final_block  =  i  ==  len (self .up_blocks ) -  1 
474+ 
459475            res_samples  =  down_block_res_samples [- len (upsample_block .resnets ) :]
460476            down_block_res_samples  =  down_block_res_samples [: - len (upsample_block .resnets )]
461477
478+             # if we have not reached the final block and need to forward the 
479+             # upsample size, we do it here 
480+             if  not  is_final_block  and  forward_upsample_size :
481+                 upsample_size  =  down_block_res_samples [- 1 ].shape [2 :]
482+ 
462483            if  hasattr (upsample_block , "has_cross_attention" ) and  upsample_block .has_cross_attention :
463484                sample  =  upsample_block (
464485                    hidden_states = sample ,
465486                    temb = emb ,
466487                    res_hidden_states_tuple = res_samples ,
467488                    encoder_hidden_states = encoder_hidden_states ,
489+                     upsample_size = upsample_size ,
468490                    image_only_indicator = image_only_indicator ,
469491                )
470492            else :
471493                sample  =  upsample_block (
472494                    hidden_states = sample ,
473495                    temb = emb ,
474496                    res_hidden_states_tuple = res_samples ,
497+                     upsample_size = upsample_size ,
475498                    image_only_indicator = image_only_indicator ,
476499                )
477500
0 commit comments