@@ -627,7 +627,7 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tenso
627627        batch_size , num_channels , height , width  =  x .shape 
628628        latent_height  =  height  //  self .spatial_compression_ratio 
629629        latent_width  =  width  //  self .spatial_compression_ratio 
630-          
630+ 
631631        tile_latent_min_height  =  self .tile_sample_min_height  //  self .spatial_compression_ratio 
632632        tile_latent_min_width  =  self .tile_sample_min_width  //  self .spatial_compression_ratio 
633633        tile_latent_stride_height  =  self .tile_sample_stride_height  //  self .spatial_compression_ratio 
@@ -642,7 +642,10 @@ def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tenso
642642            row  =  []
643643            for  j  in  range (0 , x .shape [3 ], self .tile_sample_stride_width ):
644644                tile  =  x [:, :, i  : i  +  self .tile_sample_min_height , j  : j  +  self .tile_sample_min_width ]
645-                 if  tile .shape [2 ] %  self .spatial_compression_ratio  !=  0  or  tile .shape [3 ] %  self .spatial_compression_ratio  !=  0 :
645+                 if  (
646+                     tile .shape [2 ] %  self .spatial_compression_ratio  !=  0 
647+                     or  tile .shape [3 ] %  self .spatial_compression_ratio  !=  0 
648+                 ):
646649                    pad_h  =  (self .spatial_compression_ratio  -  tile .shape [2 ]) %  self .spatial_compression_ratio 
647650                    pad_w  =  (self .spatial_compression_ratio  -  tile .shape [3 ]) %  self .spatial_compression_ratio 
648651                    tile  =  F .pad (tile , (0 , pad_w , 0 , pad_h ))
0 commit comments