Skip to content
Merged
23 changes: 14 additions & 9 deletions src/diffusers/models/autoencoders/autoencoder_kl_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,12 @@ def __init__(
self.tile_sample_stride_height = 192
self.tile_sample_stride_width = 192

# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
self._cached_conv_counts = {
"decoder": self._count_conv3d_fast(self.decoder),
"encoder": self._count_conv3d_fast(self.encoder),
}

def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
Expand Down Expand Up @@ -801,18 +807,12 @@ def disable_slicing(self) -> None:
self.use_slicing = False

def clear_cache(self):
    """Reset the feature caches and running indices for the decoder and encoder.

    The cache lengths are read from ``self._cached_conv_counts`` (keys
    ``"decoder"`` and ``"encoder"``), which ``__init__`` fills in once, so no
    module traversal is performed here.
    """
    # Decoder cache: one slot per counted conv layer, index restarted at 0.
    self._conv_num = self._cached_conv_counts["decoder"]
    self._conv_idx = [0]
    self._feat_map = [None] * self._conv_num
    # Encoder cache: same layout, sized from the encoder's count.
    self._enc_conv_num = self._cached_conv_counts["encoder"]
    self._enc_conv_idx = [0]
    self._enc_feat_map = [None] * self._enc_conv_num

Expand Down Expand Up @@ -1083,3 +1083,8 @@ def forward(
z = posterior.mode()
dec = self.decode(z, return_dict=return_dict)
return dec

@staticmethod
def _count_conv3d_fast(model):
    """Return how many of the modules yielded by ``model.modules()`` are ``WanCausalConv3d``.

    Makes a single pass over ``model.modules()`` and counts the instances that
    pass the ``isinstance`` check.
    """
    causal_convs = [layer for layer in model.modules() if isinstance(layer, WanCausalConv3d)]
    return len(causal_convs)