Skip to content

Commit d06750a

Browse files
authored
Fix autoencoder_kl_wan.py bugs for Wan2.2 VAE (huggingface#12335)
* Update autoencoder_kl_wan.py: When using the Wan2.2 VAE, the spatial compression ratio calculated here is incorrect. It should be 16 instead of 8. Pass it in directly via the config to ensure it’s correct here. * Update autoencoder_kl_wan.py
1 parent 8c72cd1 commit d06750a

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,7 +1052,7 @@ def __init__(
10521052
is_residual=is_residual,
10531053
)
10541054

1055-
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
1055+
self.spatial_compression_ratio = scale_factor_spatial
10561056

10571057
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
10581058
# to perform decoding of a single video latent at a time.
@@ -1145,12 +1145,13 @@ def clear_cache(self):
11451145
def _encode(self, x: torch.Tensor):
11461146
_, _, num_frame, height, width = x.shape
11471147

1148-
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1149-
return self.tiled_encode(x)
1150-
11511148
self.clear_cache()
11521149
if self.config.patch_size is not None:
11531150
x = patchify(x, patch_size=self.config.patch_size)
1151+
1152+
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1153+
return self.tiled_encode(x)
1154+
11541155
iter_ = 1 + (num_frame - 1) // 4
11551156
for i in range(iter_):
11561157
self._enc_conv_idx = [0]

0 commit comments

Comments (0)