Skip to content

Commit 5f2518a

Browse files
committed
up more
1 parent b7a3900 commit 5f2518a

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,6 @@
3535
CACHE_T = 2
3636

3737

38-
def count_conv3d(model):
39-
count = 0
40-
for m in model.modules():
41-
if isinstance(m, WanCausalConv3d):
42-
count += 1
43-
return count
44-
4538
class WanCausalConv3d(nn.Conv3d):
4639
r"""
4740
A custom 3D causal convolution layer with feature caching support.
@@ -82,6 +75,7 @@ def __init__(
8275
0
8376
)
8477
self.padding = (0, 0, 0)
78+
8579
def forward(self, x, cache_x=None):
8680
padding = list(self._padding)
8781
if cache_x is not None and self._padding[4] > 0:
@@ -175,6 +169,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
175169
x = F.pad(x, self._padding)
176170
return super().forward(x)
177171

172+
178173
class WanRMS_norm(nn.Module):
179174
r"""
180175
A custom RMS normalization layer.
@@ -221,7 +216,6 @@ def forward(self, x):
221216
return super().forward(x.float()).type_as(x)
222217

223218

224-
225219
class WanResample(nn.Module):
226220
r"""
227221
A custom resampling module for 2D and 3D data.
@@ -311,8 +305,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
311305
feat_idx[0] += 1
312306
return x
313307

314-
315-
308+
316309
class WanResidualBlock(nn.Module):
317310
r"""
318311
A custom residual block module.
@@ -812,11 +805,19 @@ def __init__(
812805
self.temperal_upsample, dropout
813806
)
814807
def clear_cache(self):
815-
self._conv_num = count_conv3d(self.decoder)
808+
809+
def _count_conv3d(model):
810+
count = 0
811+
for m in model.modules():
812+
if isinstance(m, WanCausalConv3d):
813+
count += 1
814+
return count
815+
816+
self._conv_num = _count_conv3d(self.decoder)
816817
self._conv_idx = [0]
817818
self._feat_map = [None] * self._conv_num
818819
#cache encode
819-
self._enc_conv_num = count_conv3d(self.encoder)
820+
self._enc_conv_num = _count_conv3d(self.encoder)
820821
self._enc_conv_idx = [0]
821822
self._enc_feat_map = [None] * self._enc_conv_num
822823

0 commit comments

Comments
 (0)