Skip to content

Commit 798b492

Browse files
committed
fix wan vae tiling bug
1 parent: 751e250 · commit: 798b492

File tree

1 file changed: +16 additions, −3 deletions

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 16 additions & 3 deletions
```diff
@@ -1277,6 +1277,9 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
             `torch.Tensor`:
                 The latent representation of the encoded videos.
         """
+        if self.config.patch_size is not None:
+            x = patchify(x, patch_size=self.config.patch_size)
+
         _, _, num_frames, height, width = x.shape
         latent_height = height // self.spatial_compression_ratio
         latent_width = width // self.spatial_compression_ratio
```
```diff
@@ -1311,7 +1314,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
                     j : j + self.tile_sample_min_width,
                 ]
                 tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
-                tile = self.quant_conv(tile)
+                # tile = self.quant_conv(tile)
                 time.append(tile)
             row.append(torch.cat(time, dim=2))
         rows.append(row)
```
```diff
@@ -1331,6 +1334,7 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
             result_rows.append(torch.cat(result_row, dim=-1))

         enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
+        enc = self.quant_conv(enc)
         return enc

     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
```
```diff
@@ -1347,6 +1351,8 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
                 If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                 returned.
         """
+        z = self.post_quant_conv(z)
+
         _, _, num_frames, height, width = z.shape
         sample_height = height * self.spatial_compression_ratio
         sample_width = width * self.spatial_compression_ratio
```
```diff
@@ -1370,8 +1376,11 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
                 for k in range(num_frames):
                     self._conv_idx = [0]
                     tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
-                    tile = self.post_quant_conv(tile)
-                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                    # tile = self.post_quant_conv(tile)
+                    if k == 0:
+                        decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx,first_chunk=True)
+                    else:
+                        decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                     time.append(decoded)
                 row.append(torch.cat(time, dim=2))
             rows.append(row)
```
```diff
@@ -1392,6 +1401,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod

         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+        dec = torch.clamp(dec, min=-1.0, max=1.0)
+
         if not return_dict:
             return (dec,)
         return DecoderOutput(sample=dec)
```

Comments: 0