From 0cc20ee3202521bfc0f5ae38a6b144da236fced2 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 21 Oct 2025 06:34:11 -0700 Subject: [PATCH 1/2] fix crash when tiling mode is enabled Signed-off-by: Wang, Yi A --- .../models/autoencoders/autoencoder_kl_wan.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index f95c4cf37475..ebba7d8fd818 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1355,9 +1355,18 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio - - blend_height = self.tile_sample_min_height - self.tile_sample_stride_height - blend_width = self.tile_sample_min_width - self.tile_sample_stride_width + tile_sample_stride_height = self.tile_sample_stride_height + tile_sample_stride_width = self.tile_sample_stride_width + if self.config.patch_size is not None: + sample_height = sample_height // self.config.patch_size + sample_width = sample_width // self.config.patch_size + tile_sample_stride_height = tile_sample_stride_height // self.config.patch_size + tile_sample_stride_width = tile_sample_stride_width // self.config.patch_size + blend_height = self.tile_sample_min_height // self.config.patch_size - tile_sample_stride_height + blend_width = self.tile_sample_min_width // self.config.patch_size - tile_sample_stride_width + else: + blend_height = self.tile_sample_min_height - tile_sample_stride_height + blend_width = self.tile_sample_min_width - tile_sample_stride_width # Split z into overlapping tiles and decode them separately. 
# The tiles have an overlap to avoid seams between tiles. @@ -1371,7 +1380,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod self._conv_idx = [0] tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] tile = self.post_quant_conv(tile) - decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx) + decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k==0)) time.append(decoded) row.append(torch.cat(time, dim=2)) rows.append(row) @@ -1387,11 +1396,15 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod tile = self.blend_v(rows[i - 1][j], tile, blend_height) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_width) - result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width]) + result_row.append(tile[:, :, :, : tile_sample_stride_height, : tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) - dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width] + if self.config.patch_size is not None: + dec = unpatchify(dec, patch_size=self.config.patch_size) + + dec = torch.clamp(dec, min=-1.0, max=1.0) + if not return_dict: return (dec,) return DecoderOutput(sample=dec) From d777895075bb4ac08c1c98319dfcc14e6d0272ff Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Tue, 21 Oct 2025 06:46:16 -0700 Subject: [PATCH 2/2] fmt Signed-off-by: Wang, Yi A --- src/diffusers/models/autoencoders/autoencoder_kl_wan.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py index ebba7d8fd818..8732d7368b4c 100644 --- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py +++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py @@ -1380,7 +1380,9 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) 
-> Union[Decod self._conv_idx = [0] tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width] tile = self.post_quant_conv(tile) - decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k==0)) + decoded = self.decoder( + tile, feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=(k == 0) + ) time.append(decoded) row.append(torch.cat(time, dim=2)) rows.append(row) @@ -1396,7 +1398,7 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod tile = self.blend_v(rows[i - 1][j], tile, blend_height) if j > 0: tile = self.blend_h(row[j - 1], tile, blend_width) - result_row.append(tile[:, :, :, : tile_sample_stride_height, : tile_sample_stride_width]) + result_row.append(tile[:, :, :, :tile_sample_stride_height, :tile_sample_stride_width]) result_rows.append(torch.cat(result_row, dim=-1)) dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]