@@ -1278,17 +1278,15 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
12781278 `torch.Tensor`:
12791279 The latent representation of the encoded videos.
12801280 """
1281- if self .config .patch_size is not None :
1282- x = patchify (x , patch_size = self .config .patch_size )
1283-
1281+ spatial_compression_ratio = self .spatial_compression_ratio // self .config .patch_size # remove compression_ratio by patchify
12841282 _ , _ , num_frames , height , width = x .shape
1285- latent_height = height // self . spatial_compression_ratio
1286- latent_width = width // self . spatial_compression_ratio
1283+ latent_height = height // spatial_compression_ratio
1284+ latent_width = width // spatial_compression_ratio
12871285
1288- tile_latent_min_height = self .tile_sample_min_height // self . spatial_compression_ratio
1289- tile_latent_min_width = self .tile_sample_min_width // self . spatial_compression_ratio
1290- tile_latent_stride_height = self .tile_sample_stride_height // self . spatial_compression_ratio
1291- tile_latent_stride_width = self .tile_sample_stride_width // self . spatial_compression_ratio
1286+ tile_latent_min_height = self .tile_sample_min_height // spatial_compression_ratio
1287+ tile_latent_min_width = self .tile_sample_min_width // spatial_compression_ratio
1288+ tile_latent_stride_height = self .tile_sample_stride_height // spatial_compression_ratio
1289+ tile_latent_stride_width = self .tile_sample_stride_width // spatial_compression_ratio
12921290
12931291 blend_height = tile_latent_min_height - tile_latent_stride_height
12941292 blend_width = tile_latent_min_width - tile_latent_stride_width
@@ -1353,15 +1351,16 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13531351 returned.
13541352 """
13551353 z = self .post_quant_conv (z )
1356-
1354+
1355+ spatial_compression_ratio = self .spatial_compression_ratio // self .config .patch_size # remove compression_ratio by patchify
13571356 _ , _ , num_frames , height , width = z .shape
1358- sample_height = height * self . spatial_compression_ratio
1359- sample_width = width * self . spatial_compression_ratio
1357+ sample_height = height * spatial_compression_ratio
1358+ sample_width = width * spatial_compression_ratio
13601359
1361- tile_latent_min_height = self .tile_sample_min_height // self . spatial_compression_ratio
1362- tile_latent_min_width = self .tile_sample_min_width // self . spatial_compression_ratio
1363- tile_latent_stride_height = self .tile_sample_stride_height // self . spatial_compression_ratio
1364- tile_latent_stride_width = self .tile_sample_stride_width // self . spatial_compression_ratio
1360+ tile_latent_min_height = self .tile_sample_min_height // spatial_compression_ratio
1361+ tile_latent_min_width = self .tile_sample_min_width // spatial_compression_ratio
1362+ tile_latent_stride_height = self .tile_sample_stride_height // spatial_compression_ratio
1363+ tile_latent_stride_width = self .tile_sample_stride_width // spatial_compression_ratio
13651364
13661365 blend_height = self .tile_sample_min_height - self .tile_sample_stride_height
13671366 blend_width = self .tile_sample_min_width - self .tile_sample_stride_width
0 commit comments