Skip to content

Commit 59d7793

Browse files
committed
update
1 parent de925be commit 59d7793

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
101101
return hidden_states
102102

103103

104-
class CosmosPatcher3d(nn.Module):
104+
class CosmosPatchEmbed3d(nn.Module):
105105
def __init__(self, patch_size: int = 1, patch_method: str = "haar") -> None:
106106
super().__init__()
107107

@@ -255,7 +255,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
255255
raise ValueError("Unknown patch method: " + self.patch_method)
256256

257257

258-
class CosmosConvProj3d(nn.Module):
258+
class CosmosConvProjection3d(nn.Module):
259259
def __init__(self, in_channels: int, out_channels: int) -> None:
260260
super().__init__()
261261

@@ -280,11 +280,11 @@ def __init__(
280280
out_channels = out_channels or in_channels
281281

282282
self.norm1 = CosmosCausalGroupNorm(in_channels, num_groups)
283-
self.conv1 = CosmosConvProj3d(in_channels, out_channels)
283+
self.conv1 = CosmosConvProjection3d(in_channels, out_channels)
284284

285285
self.norm2 = CosmosCausalGroupNorm(out_channels, num_groups)
286286
self.dropout = nn.Dropout(dropout)
287-
self.conv2 = CosmosConvProj3d(out_channels, out_channels)
287+
self.conv2 = CosmosConvProjection3d(out_channels, out_channels)
288288

289289
if in_channels != out_channels:
290290
self.conv_shortcut = CosmosCausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
@@ -673,7 +673,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
673673
return hidden_states
674674

675675

676-
class CosmosEncoder(nn.Module):
676+
class CosmosEncoder3d(nn.Module):
677677
def __init__(
678678
self,
679679
in_channels: int = 3,
@@ -694,9 +694,9 @@ def __init__(
694694
num_temporal_layers = int(math.log2(temporal_compression_ratio)) - int(math.log2(patch_size))
695695

696696
# 1. Input patching & projection
697-
self.patch_embed = CosmosPatcher3d(patch_size, patch_type)
697+
self.patch_embed = CosmosPatchEmbed3d(patch_size, patch_type)
698698

699-
self.conv_in = CosmosConvProj3d(inner_dim, block_out_channels[0])
699+
self.conv_in = CosmosConvProjection3d(inner_dim, block_out_channels[0])
700700

701701
# 2. Down blocks
702702
current_resolution = resolution // patch_size
@@ -734,7 +734,7 @@ def __init__(
734734

735735
# 4. Output norm & projection
736736
self.norm_out = CosmosCausalGroupNorm(block_out_channels[-1], num_groups=1)
737-
self.conv_out = CosmosConvProj3d(block_out_channels[-1], out_channels)
737+
self.conv_out = CosmosConvProjection3d(block_out_channels[-1], out_channels)
738738

739739
self.gradient_checkpointing = False
740740

@@ -757,7 +757,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
757757
return hidden_states
758758

759759

760-
class CosmosDecoder(nn.Module):
760+
class CosmosDecoder3d(nn.Module):
761761
def __init__(
762762
self,
763763
in_channels: int = 16,
@@ -779,7 +779,7 @@ def __init__(
779779
reversed_block_out_channels = list(reversed(block_out_channels))
780780

781781
# 1. Input projection
782-
self.conv_in = CosmosConvProj3d(in_channels, reversed_block_out_channels[0])
782+
self.conv_in = CosmosConvProjection3d(in_channels, reversed_block_out_channels[0])
783783

784784
# 2. Mid block
785785
self.mid_block = CosmosMidBlock3d(reversed_block_out_channels[0], num_layers=1, dropout=dropout, num_groups=1)
@@ -819,7 +819,7 @@ def __init__(
819819

820820
# 4. Output norm & projection & unpatching
821821
self.norm_out = CosmosCausalGroupNorm(reversed_block_out_channels[-1], num_groups=1)
822-
self.conv_out = CosmosConvProj3d(reversed_block_out_channels[-1], inner_dim)
822+
self.conv_out = CosmosConvProjection3d(reversed_block_out_channels[-1], inner_dim)
823823

824824
self.unpatch_embed = CosmosUnpatcher3d(patch_size, patch_type)
825825

@@ -906,7 +906,7 @@ def __init__(
906906
) -> None:
907907
super().__init__()
908908

909-
self.encoder = CosmosEncoder(
909+
self.encoder = CosmosEncoder3d(
910910
in_channels=in_channels,
911911
out_channels=latent_channels,
912912
block_out_channels=encoder_block_out_channels,
@@ -918,7 +918,7 @@ def __init__(
918918
spatial_compression_ratio=spatial_compression_ratio,
919919
temporal_compression_ratio=temporal_compression_ratio,
920920
)
921-
self.decoder = CosmosDecoder(
921+
self.decoder = CosmosDecoder3d(
922922
in_channels=latent_channels,
923923
out_channels=out_channels,
924924
block_out_channels=decode_block_out_channels,

0 commit comments

Comments
 (0)