
Commit b7a3900

one less nn.sequential
1 parent 2da1feb commit b7a3900

File tree

1 file changed (+33, -33 lines)

src/diffusers/models/autoencoders/autoencoder_kl_wan.py

Lines changed: 33 additions & 33 deletions
@@ -530,11 +530,10 @@ def __init__(
         self.middle = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)

         # output blocks
-        self.head = nn.Sequential(
-            WanRMS_norm(out_dim, images=False),
-            self.nonlinearity,
-            WanCausalConv3d(out_dim, z_dim, 3, padding=1)
-        )
+        self.norm_out = WanRMS_norm(out_dim, images=False)
+        self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
+
+        self.gradient_checkpointing = False

     def forward(self, x, feat_cache=None, feat_idx=[0]):
         if feat_cache is not None:
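Why unpack the nn.Sequential: Sequential's forward threads a single tensor through its children, so there is no way to hand WanCausalConv3d its per-call feature cache without iterating over the container by hand, which is exactly what the old forward did. With norm_out and conv_out as named attributes, the forward pass can call the conv directly with the cache argument. A minimal sketch of the constraint, using stand-in modules rather than the actual Wan classes:

import torch
import torch.nn as nn

class CachedConv(nn.Module):
    # Stand-in for WanCausalConv3d: takes an optional second argument,
    # which nn.Sequential's forward can never supply.
    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv3d(dim, dim, 3, padding=1)

    def forward(self, x, cache=None):
        # A real causal conv would prepend `cache` along the time axis
        # before convolving; elided here.
        return self.conv(x)

dim = 4
x = torch.randn(1, dim, 2, 8, 8)  # (batch, channels, time, height, width)

head = nn.Sequential(nn.SiLU(), CachedConv(dim))
y = head(x)  # runs, but the cache argument is unreachable

# Named submodules (the shape of this commit): each call site is explicit,
# so the conv can receive its cache.
act, conv_out = nn.SiLU(), CachedConv(dim)
y = conv_out(act(x), cache=None)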
@@ -560,18 +559,19 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
         x = self.middle(x, feat_cache, feat_idx)

         ## head
-        for layer in self.head:
-            if isinstance(layer, WanCausalConv3d) and feat_cache is not None:
-                idx = feat_idx[0]
-                cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                 # cache last frame of last two chunk
-                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-                x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
-                feat_idx[0] += 1
-            else:
-                x = layer(x)
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
         return x
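The rewritten head makes the control flow easier to follow: normalize, apply the nonlinearity, then run the causal conv either with or without the rolling feature cache. The bookkeeping around cache_x is the subtle part; a condensed sketch of its apparent semantics, inferred from the diff, with CACHE_T = 2 as in the module:

import torch

CACHE_T = 2  # matches the constant used by autoencoder_kl_wan.py

def update_cache(x, prev_cache):
    # Keep the last CACHE_T frames of the current chunk for the next call.
    cache_x = x[:, :, -CACHE_T:, :, :].clone()
    # If this chunk contributed only one frame, borrow the last frame of
    # the previous cache so two frames of temporal context survive
    # ("cache last frame of last two chunk" in the source comment).
    if cache_x.shape[2] < 2 and prev_cache is not None:
        cache_x = torch.cat(
            [prev_cache[:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x],
            dim=2,
        )
    return cache_x

one_frame_chunk = torch.randn(1, 4, 1, 8, 8)
prev = torch.randn(1, 4, 2, 8, 8)
assert update_cache(one_frame_chunk, prev).shape[2] == 2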

@@ -719,11 +719,10 @@ def __init__(
         self.upsamples = upsamples

         # output blocks
-        self.head = nn.Sequential(
-            WanRMS_norm(out_dim, images=False),
-            self.nonlinearity,
-            WanCausalConv3d(out_dim, 3, 3, padding=1)
-        )
+        self.norm_out = WanRMS_norm(out_dim, images=False)
+        self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+
+        self.gradient_checkpointing = False

     def forward(self, x, feat_cache=None, feat_idx=[0]):
         ## conv1
@@ -747,18 +746,19 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
             x = up_block(x, feat_cache, feat_idx)

         ## head
-        for layer in self.head:
-            if isinstance(layer, WanCausalConv3d) and feat_cache is not None:
-                idx = feat_idx[0]
-                cache_x = x[:, :, -CACHE_T:, :, :].clone()
-                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
-                    # cache last frame of last two chunk
-                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
-                x = layer(x, feat_cache[idx])
-                feat_cache[idx] = cache_x
-                feat_idx[0] += 1
-            else:
-                x = layer(x)
+        x = self.norm_out(x)
+        x = self.nonlinearity(x)
+        if feat_cache is not None:
+            idx = feat_idx[0]
+            cache_x = x[:, :, -CACHE_T:, :, :].clone()
+            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # cache last frame of last two chunk
+                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+            x = self.conv_out(x, feat_cache[idx])
+            feat_cache[idx] = cache_x
+            feat_idx[0] += 1
+        else:
+            x = self.conv_out(x)
         return x
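One side effect of the refactor, not visible in the diff itself but implied by how nn.Sequential names its children, is that parameters previously registered under head.0.* (the WanRMS_norm) and head.2.* (the WanCausalConv3d) now live under norm_out.* and conv_out.*, so checkpoints saved against the old layout would need their keys remapped. A hypothetical remap, for illustration only (the actual conversion logic, if any, is not shown here):

# Hypothetical key remapping for checkpoints saved before this commit;
# the prefixes follow nn.Sequential's integer child names. The
# nonlinearity (child 1) is stateless, so it has no keys to move.
RENAMES = {
    "head.0.": "norm_out.",  # WanRMS_norm was child 0
    "head.2.": "conv_out.",  # WanCausalConv3d was child 2
}

def remap_state_dict(state_dict):
    remapped = {}
    for key, value in state_dict.items():
        for old, new in RENAMES.items():
            if old in key:
                key = key.replace(old, new)
                break
        remapped[key] = value
    return remapped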
