Skip to content

Commit 9e8ef93

Browse files
committed
more
1 parent 89f1c6a commit 9e8ef93

File tree

1 file changed

+22
-108
lines changed

1 file changed

+22
-108
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 22 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -86,90 +86,6 @@ def forward(self, x, cache_x=None):
8686
return super().forward(x)
8787

8888

89-
# TODO: not used yet, will not affect the state dict so can be refactored in follow up PR
90-
class WanCausalConv3dYiYi(nn.Conv3d):
91-
r"""
92-
A custom 3D causal convolution layer with feature caching support.
93-
94-
This layer extends the standard Conv3D layer by ensuring causality in the time dimension
95-
and handling feature caching for efficient inference.
96-
97-
Args:
98-
in_channels (int): Number of channels in the input image
99-
out_channels (int): Number of channels produced by the convolution
100-
kernel_size (int or tuple): Size of the convolving kernel
101-
stride (int or tuple, optional): Stride of the convolution. Default: 1
102-
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
103-
"""
104-
def __init__(
105-
self,
106-
in_channels: int,
107-
out_channels: int,
108-
kernel_size: Union[int, Tuple[int, int, int]],
109-
stride: Union[int, Tuple[int, int, int]] = 1,
110-
padding: Union[int, Tuple[int, int, int]] = 0,
111-
) -> None:
112-
super().__init__(
113-
in_channels=in_channels,
114-
out_channels=out_channels,
115-
kernel_size=kernel_size,
116-
stride=stride,
117-
padding=padding,
118-
)
119-
120-
# Set up causal padding
121-
self._padding = (
122-
self.padding[2],
123-
self.padding[2],
124-
self.padding[1],
125-
self.padding[1],
126-
2 * self.padding[0],
127-
0
128-
)
129-
self.padding = (0, 0, 0)
130-
131-
def forward(self, x, feat_cache=None, feat_idx=[0]):
132-
"""
133-
Forward pass with feature caching support.
134-
135-
Args:
136-
x (torch.Tensor): Input tensor
137-
feat_cache (list, optional): List to store cached features
138-
feat_idx (list, optional): List with a single integer indicating the current cache index
139-
140-
Returns:
141-
torch.Tensor: Output tensor after convolution
142-
"""
143-
# Handle feature caching
144-
if feat_cache is not None:
145-
idx = feat_idx[0]
146-
cache_x = x[:, :, -CACHE_T:, :, :].clone()
147-
148-
# Concatenate with cached frame if available
149-
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
150-
# cache last frame of last two chunk
151-
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
152-
153-
# Apply padding and convolution with cached data
154-
padding = list(self._padding)
155-
if feat_cache[idx] is not None and self._padding[4] > 0 :
156-
x = torch.cat([feat_cache[idx], x], dim=2)
157-
padding[4] -= feat_cache[idx].shape[2]
158-
159-
x = F.pad(x, padding)
160-
result = super().forward(x)
161-
162-
# Update cache
163-
feat_cache[idx] = cache_x
164-
feat_idx[0] += 1
165-
166-
return result
167-
else:
168-
# Standard forward pass without caching
169-
x = F.pad(x, self._padding)
170-
return super().forward(x)
171-
172-
17389
class WanRMS_norm(nn.Module):
17490
r"""
17591
A custom RMS normalization layer.
@@ -501,26 +417,26 @@ def __init__(
501417
scale = 1.0
502418

503419
# init block
504-
self.conv1 = WanCausalConv3d(3, dims[0], 3, padding=1)
420+
self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
505421

506422
# downsample blocks
507-
self.downsamples = nn.ModuleList([])
423+
self.down_blocks = nn.ModuleList([])
508424
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
509425
# residual (+attention) blocks
510426
for _ in range(num_res_blocks):
511-
self.downsamples.append(WanResidualBlock(in_dim, out_dim, dropout))
427+
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
512428
if scale in attn_scales:
513-
self.downsamples.append(WanAttentionBlock(out_dim))
429+
self.down_blocks.append(WanAttentionBlock(out_dim))
514430
in_dim = out_dim
515431

516432
# downsample block
517433
if i != len(dim_mult) - 1:
518434
mode = 'downsample3d' if temperal_downsample[i] else 'downsample2d'
519-
self.downsamples.append(WanResample(out_dim, mode=mode))
435+
self.down_blocks.append(WanResample(out_dim, mode=mode))
520436
scale /= 2.0
521437

522438
# middle blocks
523-
self.middle = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
439+
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
524440

525441
# output blocks
526442
self.norm_out = WanRMS_norm(out_dim, images=False)
@@ -535,21 +451,21 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
535451
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
536452
# cache last frame of last two chunk
537453
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
538-
x = self.conv1(x, feat_cache[idx])
454+
x = self.conv_in(x, feat_cache[idx])
539455
feat_cache[idx] = cache_x
540456
feat_idx[0] += 1
541457
else:
542-
x = self.conv1(x)
458+
x = self.conv_in(x)
543459

544460
## downsamples
545-
for layer in self.downsamples:
461+
for layer in self.down_blocks:
546462
if feat_cache is not None:
547463
x = layer(x, feat_cache, feat_idx)
548464
else:
549465
x = layer(x)
550466

551467
## middle
552-
x = self.middle(x, feat_cache, feat_idx)
468+
x = self.mid_block(x, feat_cache, feat_idx)
553469

554470
## head
555471
x = self.norm_out(x)
@@ -676,14 +592,14 @@ def __init__(
676592
scale = 1.0 / 2 ** (len(dim_mult) - 2)
677593

678594
# init block
679-
self.conv1 = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
595+
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
680596

681597
# middle blocks
682-
self.middle = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
598+
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
683599

684600

685601
# upsample blocks
686-
upsamples = nn.ModuleList([])
602+
self.up_blocks = nn.ModuleList([])
687603
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
688604
# residual (+attention) blocks
689605
if i > 0:
@@ -703,14 +619,12 @@ def __init__(
703619
upsample_mode=upsample_mode,
704620
non_linearity=non_linearity,
705621
)
706-
upsamples.append(up_block)
622+
self.up_blocks.append(up_block)
707623

708624
# Update scale for next iteration
709625
if upsample_mode is not None:
710626
scale *= 2.0
711627

712-
self.upsamples = upsamples
713-
714628
# output blocks
715629
self.norm_out = WanRMS_norm(out_dim, images=False)
716630
self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
@@ -725,17 +639,17 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
725639
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
726640
# cache last frame of last two chunk
727641
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
728-
x = self.conv1(x, feat_cache[idx])
642+
x = self.conv_in(x, feat_cache[idx])
729643
feat_cache[idx] = cache_x
730644
feat_idx[0] += 1
731645
else:
732-
x = self.conv1(x)
646+
x = self.conv_in(x)
733647

734648
## middle
735-
x = self.middle(x, feat_cache, feat_idx)
649+
x = self.mid_block(x, feat_cache, feat_idx)
736650

737651
## upsamples
738-
for up_block in self.upsamples:
652+
for up_block in self.up_blocks:
739653
x = up_block(x, feat_cache, feat_idx)
740654

741655
## head
@@ -796,8 +710,8 @@ def __init__(
796710
base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales,
797711
self.temperal_downsample, dropout
798712
)
799-
self.conv1 = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
800-
self.conv2 = WanCausalConv3d(z_dim, z_dim, 1)
713+
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
714+
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
801715

802716
self.decoder = WanDecoder3d(
803717
base_dim, z_dim, dim_mult, num_res_blocks, attn_scales,
@@ -834,7 +748,7 @@ def _encode(self, x: torch.Tensor) -> torch.Tensor:
834748
out_ = self.encoder(x[:,:,1+4*(i-1):1+4*i,:,:], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
835749
out = torch.cat([out, out_], 2)
836750

837-
enc = self.conv1(out)
751+
enc = self.quant_conv(out)
838752
mu, logvar = enc[:, :self.z_dim, :, :, :], enc[:, self.z_dim:, :, :, :]
839753
mu = (mu - self.scale[0].view(1, self.z_dim, 1, 1, 1)) * self.scale[1].view(1, self.z_dim, 1, 1, 1)
840754
logvar = (logvar - self.scale[0].view(1, self.z_dim, 1, 1, 1)) * self.scale[1].view(1, self.z_dim, 1, 1, 1)
@@ -870,7 +784,7 @@ def _decode(self, z: torch.Tensor, scale, return_dict: bool = True) -> Union[Dec
870784
z = z / self.scale[1].view(1, self.z_dim, 1, 1, 1) + self.scale[0].view(1, self.z_dim, 1, 1, 1)
871785

872786
iter_ = z.shape[2]
873-
x = self.conv2(z)
787+
x = self.post_quant_conv(z)
874788
for i in range(iter_):
875789
self._conv_idx = [0]
876790
if i == 0:

0 commit comments

Comments
 (0)