|
35 | 35 | CACHE_T = 2 |
36 | 36 |
|
37 | 37 |
|
38 | | -def count_conv3d(model): |
39 | | - count = 0 |
40 | | - for m in model.modules(): |
41 | | - if isinstance(m, WanCausalConv3d): |
42 | | - count += 1 |
43 | | - return count |
44 | | - |
45 | 38 | class WanCausalConv3d(nn.Conv3d): |
46 | 39 | r""" |
47 | 40 | A custom 3D causal convolution layer with feature caching support. |
@@ -82,6 +75,7 @@ def __init__( |
82 | 75 | 0 |
83 | 76 | ) |
84 | 77 | self.padding = (0, 0, 0) |
| 78 | + |
85 | 79 | def forward(self, x, cache_x=None): |
86 | 80 | padding = list(self._padding) |
87 | 81 | if cache_x is not None and self._padding[4] > 0: |
@@ -175,6 +169,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): |
175 | 169 | x = F.pad(x, self._padding) |
176 | 170 | return super().forward(x) |
177 | 171 |
|
| 172 | + |
178 | 173 | class WanRMS_norm(nn.Module): |
179 | 174 | r""" |
180 | 175 | A custom RMS normalization layer. |
@@ -221,7 +216,6 @@ def forward(self, x): |
221 | 216 | return super().forward(x.float()).type_as(x) |
222 | 217 |
|
223 | 218 |
|
224 | | - |
225 | 219 | class WanResample(nn.Module): |
226 | 220 | r""" |
227 | 221 | A custom resampling module for 2D and 3D data. |
@@ -311,8 +305,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]): |
311 | 305 | feat_idx[0] += 1 |
312 | 306 | return x |
313 | 307 |
|
314 | | - |
315 | | - |
| 308 | + |
316 | 309 | class WanResidualBlock(nn.Module): |
317 | 310 | r""" |
318 | 311 | A custom residual block module. |
@@ -812,11 +805,19 @@ def __init__( |
812 | 805 | self.temperal_upsample, dropout |
813 | 806 | ) |
814 | 807 | def clear_cache(self): |
815 | | - self._conv_num = count_conv3d(self.decoder) |
| 808 | + |
| 809 | + def _count_conv3d(model): |
| 810 | + count = 0 |
| 811 | + for m in model.modules(): |
| 812 | + if isinstance(m, WanCausalConv3d): |
| 813 | + count += 1 |
| 814 | + return count |
| 815 | + |
| 816 | + self._conv_num = _count_conv3d(self.decoder) |
816 | 817 | self._conv_idx = [0] |
817 | 818 | self._feat_map = [None] * self._conv_num |
818 | 819 | #cache encode |
819 | | - self._enc_conv_num = count_conv3d(self.encoder) |
| 820 | + self._enc_conv_num = _count_conv3d(self.encoder) |
820 | 821 | self._enc_conv_idx = [0] |
821 | 822 | self._enc_feat_map = [None] * self._enc_conv_num |
822 | 823 |
|
|
0 commit comments