Skip to content

Commit c9c616f

Browse files
committed
1 parent 6582570 commit c9c616f

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,17 +1278,15 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
12781278
`torch.Tensor`:
12791279
The latent representation of the encoded videos.
12801280
"""
1281-
if self.config.patch_size is not None:
1282-
x = patchify(x, patch_size=self.config.patch_size)
1283-
1281+
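spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size # spatial ratio with the patchify factor divided out. NOTE(review): the removed guard implies patch_size may be None, in which case this floor division raises TypeError - confirm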
12841282
_, _, num_frames, height, width = x.shape
1285-
latent_height = height // self.spatial_compression_ratio
1286-
latent_width = width // self.spatial_compression_ratio
1283+
latent_height = height // spatial_compression_ratio
1284+
latent_width = width // spatial_compression_ratio
12871285

1288-
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1289-
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1290-
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1291-
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1286+
tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio
1287+
tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio
1288+
tile_latent_stride_height = self.tile_sample_stride_height // spatial_compression_ratio
1289+
tile_latent_stride_width = self.tile_sample_stride_width // spatial_compression_ratio
12921290

12931291
blend_height = tile_latent_min_height - tile_latent_stride_height
12941292
blend_width = tile_latent_min_width - tile_latent_stride_width
@@ -1353,15 +1351,16 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
13531351
returned.
13541352
"""
13551353
z = self.post_quant_conv(z)
1356-
1354+
1355+
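spatial_compression_ratio = self.spatial_compression_ratio // self.config.patch_size # spatial ratio with the patchify factor divided out. NOTE(review): if patch_size can be None this floor division raises TypeError - confirm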
13571356
_, _, num_frames, height, width = z.shape
1358-
sample_height = height * self.spatial_compression_ratio
1359-
sample_width = width * self.spatial_compression_ratio
1357+
sample_height = height * spatial_compression_ratio
1358+
sample_width = width * spatial_compression_ratio
13601359

1361-
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1362-
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1363-
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1364-
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1360+
tile_latent_min_height = self.tile_sample_min_height // spatial_compression_ratio
1361+
tile_latent_min_width = self.tile_sample_min_width // spatial_compression_ratio
1362+
tile_latent_stride_height = self.tile_sample_stride_height // spatial_compression_ratio
1363+
tile_latent_stride_width = self.tile_sample_stride_width // spatial_compression_ratio
13651364

13661365
blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
13671366
blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

0 commit comments

Comments
 (0)