@@ -101,7 +101,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
101101 return hidden_states
102102
103103
104- class CosmosPatcher3d (nn .Module ):
104+ class CosmosPatchEmbed3d (nn .Module ):
105105 def __init__ (self , patch_size : int = 1 , patch_method : str = "haar" ) -> None :
106106 super ().__init__ ()
107107
@@ -255,7 +255,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
255255 raise ValueError ("Unknown patch method: " + self .patch_method )
256256
257257
258- class CosmosConvProj3d (nn .Module ):
258+ class CosmosConvProjection3d (nn .Module ):
259259 def __init__ (self , in_channels : int , out_channels : int ) -> None :
260260 super ().__init__ ()
261261
@@ -280,11 +280,11 @@ def __init__(
280280 out_channels = out_channels or in_channels
281281
282282 self .norm1 = CosmosCausalGroupNorm (in_channels , num_groups )
283- self .conv1 = CosmosConvProj3d (in_channels , out_channels )
283+ self .conv1 = CosmosConvProjection3d (in_channels , out_channels )
284284
285285 self .norm2 = CosmosCausalGroupNorm (out_channels , num_groups )
286286 self .dropout = nn .Dropout (dropout )
287- self .conv2 = CosmosConvProj3d (out_channels , out_channels )
287+ self .conv2 = CosmosConvProjection3d (out_channels , out_channels )
288288
289289 if in_channels != out_channels :
290290 self .conv_shortcut = CosmosCausalConv3d (in_channels , out_channels , kernel_size = 1 , stride = 1 , padding = 0 )
@@ -673,7 +673,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
673673 return hidden_states
674674
675675
676- class CosmosEncoder (nn .Module ):
676+ class CosmosEncoder3d (nn .Module ):
677677 def __init__ (
678678 self ,
679679 in_channels : int = 3 ,
@@ -694,9 +694,9 @@ def __init__(
694694 num_temporal_layers = int (math .log2 (temporal_compression_ratio )) - int (math .log2 (patch_size ))
695695
696696 # 1. Input patching & projection
697- self .patch_embed = CosmosPatcher3d (patch_size , patch_type )
697+ self .patch_embed = CosmosPatchEmbed3d (patch_size , patch_type )
698698
699- self .conv_in = CosmosConvProj3d (inner_dim , block_out_channels [0 ])
699+ self .conv_in = CosmosConvProjection3d (inner_dim , block_out_channels [0 ])
700700
701701 # 2. Down blocks
702702 current_resolution = resolution // patch_size
@@ -734,7 +734,7 @@ def __init__(
734734
735735 # 4. Output norm & projection
736736 self .norm_out = CosmosCausalGroupNorm (block_out_channels [- 1 ], num_groups = 1 )
737- self .conv_out = CosmosConvProj3d (block_out_channels [- 1 ], out_channels )
737+ self .conv_out = CosmosConvProjection3d (block_out_channels [- 1 ], out_channels )
738738
739739 self .gradient_checkpointing = False
740740
@@ -757,7 +757,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
757757 return hidden_states
758758
759759
760- class CosmosDecoder (nn .Module ):
760+ class CosmosDecoder3d (nn .Module ):
761761 def __init__ (
762762 self ,
763763 in_channels : int = 16 ,
@@ -779,7 +779,7 @@ def __init__(
779779 reversed_block_out_channels = list (reversed (block_out_channels ))
780780
781781 # 1. Input projection
782- self .conv_in = CosmosConvProj3d (in_channels , reversed_block_out_channels [0 ])
782+ self .conv_in = CosmosConvProjection3d (in_channels , reversed_block_out_channels [0 ])
783783
784784 # 2. Mid block
785785 self .mid_block = CosmosMidBlock3d (reversed_block_out_channels [0 ], num_layers = 1 , dropout = dropout , num_groups = 1 )
@@ -819,7 +819,7 @@ def __init__(
819819
820820 # 4. Output norm & projection & unpatching
821821 self .norm_out = CosmosCausalGroupNorm (reversed_block_out_channels [- 1 ], num_groups = 1 )
822- self .conv_out = CosmosConvProj3d (reversed_block_out_channels [- 1 ], inner_dim )
822+ self .conv_out = CosmosConvProjection3d (reversed_block_out_channels [- 1 ], inner_dim )
823823
824824 self .unpatch_embed = CosmosUnpatcher3d (patch_size , patch_type )
825825
@@ -906,7 +906,7 @@ def __init__(
906906 ) -> None :
907907 super ().__init__ ()
908908
909- self .encoder = CosmosEncoder (
909+ self .encoder = CosmosEncoder3d (
910910 in_channels = in_channels ,
911911 out_channels = latent_channels ,
912912 block_out_channels = encoder_block_out_channels ,
@@ -918,7 +918,7 @@ def __init__(
918918 spatial_compression_ratio = spatial_compression_ratio ,
919919 temporal_compression_ratio = temporal_compression_ratio ,
920920 )
921- self .decoder = CosmosDecoder (
921+ self .decoder = CosmosDecoder3d (
922922 in_channels = latent_channels ,
923923 out_channels = out_channels ,
924924 block_out_channels = decode_block_out_channels ,
0 commit comments