@@ -382,6 +382,20 @@ def forward(
382382                If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is 
383383                returned, otherwise a `tuple` is returned where the first element is the sample tensor. 
384384        """ 
385+         # By default samples have to be at least a multiple of the overall upsampling factor. 
386+         # The overall upsampling factor is equal to 2 ** (number of upsampling layers). 
387+         # However, the upsampling interpolation output size can be forced to fit any upsampling size 
388+         # on the fly if necessary. 
389+         default_overall_up_factor  =  2 ** self .num_upsamplers 
390+ 
391+         # upsample size should be forwarded when either of the last two sample dimensions is not a multiple of `default_overall_up_factor` 
392+         forward_upsample_size  =  False 
393+         upsample_size  =  None 
394+ 
395+         if  any (s  %  default_overall_up_factor  !=  0  for  s  in  sample .shape [- 2 :]):
396+             logger .info ("Forward upsample size to force interpolation output size." )
397+             forward_upsample_size  =  True 
398+ 
385399        # 1. time 
386400        timesteps  =  timestep 
387401        if  not  torch .is_tensor (timesteps ):
@@ -457,22 +471,31 @@ def forward(
457471
458472        # 5. up 
459473        for  i , upsample_block  in  enumerate (self .up_blocks ):
474+             is_final_block  =  i  ==  len (self .up_blocks ) -  1 
475+ 
460476            res_samples  =  down_block_res_samples [- len (upsample_block .resnets ) :]
461477            down_block_res_samples  =  down_block_res_samples [: - len (upsample_block .resnets )]
462478
479+             # if we have not reached the final block and need to forward the 
480+             # upsample size, we do it here 
481+             if  not  is_final_block  and  forward_upsample_size :
482+                 upsample_size  =  down_block_res_samples [- 1 ].shape [2 :]
483+ 
463484            if  hasattr (upsample_block , "has_cross_attention" ) and  upsample_block .has_cross_attention :
464485                sample  =  upsample_block (
465486                    hidden_states = sample ,
466487                    temb = emb ,
467488                    res_hidden_states_tuple = res_samples ,
468489                    encoder_hidden_states = encoder_hidden_states ,
490+                     upsample_size = upsample_size ,
469491                    image_only_indicator = image_only_indicator ,
470492                )
471493            else :
472494                sample  =  upsample_block (
473495                    hidden_states = sample ,
474496                    temb = emb ,
475497                    res_hidden_states_tuple = res_samples ,
498+                     upsample_size = upsample_size ,
476499                    image_only_indicator = image_only_indicator ,
477500                )
478501
0 commit comments