1313# See the License for the specific language governing permissions and
1414# limitations under the License.
1515
16- from typing import Dict , Optional , Tuple , Union
16+ from typing import Optional , Tuple , Union
1717
18- import numpy as np
1918import torch
2019import torch .nn as nn
2120import torch .nn .functional as F
2221
23- from ...configuration_utils import ConfigMixin , register_to_config
24- from ...loaders .single_file_model import FromOriginalModelMixin
2522from ...utils import logging
26- from ...utils .accelerate_utils import apply_forward_hook
2723from ..activations import get_activation
28- from ..downsampling import CogVideoXDownsample3D
29- from ..modeling_outputs import AutoencoderKLOutput
30- from ..modeling_utils import ModelMixin
31- from ..upsampling import CogVideoXUpsample3D
32- from .vae import DecoderOutput , DiagonalGaussianDistribution
3324
3425
3526logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
3627
3728
38- import torch
39- import torch .nn as nn
40- import torch .nn .functional as F
41-
42-
4329# YiYi to-do: replace this with nn.Conv3d
4430class Conv1x1 (nn .Linear ):
4531 """*1x1 Conv implemented with a linear layer."""
@@ -60,17 +46,18 @@ def forward(self, x: torch.Tensor):
6046 x = super ().forward (x )
6147 x = x .movedim (- 1 , 1 )
6248 return x
63-
49+
6450
6551class MochiChunkedCausalConv3d (nn .Module ):
66- r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
52+ r"""A 3D causal convolution layer that pads the input tensor to ensure causality in Mochi Model.
53+ It also supports memory-efficient chunked 3D convolutions.
6754
6855 Args:
6956 in_channels (`int`): Number of channels in the input tensor.
7057 out_channels (`int`): Number of output channels produced by the convolution.
7158 kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
7259 stride (`int` or `Tuple[int, int, int]`, defaults to `1`): Stride of the convolution.
73- pad_mode (`str`, defaults to `"constant "`): Padding mode.
60+ padding_mode (`str`, defaults to `"replicate "`): Padding mode.
7461 """
7562
7663 def __init__ (
@@ -88,7 +75,7 @@ def __init__(
8875 if isinstance (stride , int ):
8976 stride = (stride ,) * 3
9077
91- time_kernel_size , height_kernel_size , width_kernel_size = kernel_size
78+ _ , height_kernel_size , width_kernel_size = kernel_size
9279
9380 self .padding_mode = padding_mode
9481 height_pad = (height_kernel_size - 1 ) // 2
@@ -104,18 +91,17 @@ def __init__(
10491 padding_mode = padding_mode ,
10592 )
10693
107-
108-
109- def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
94+ def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
11095 time_kernel_size = self .conv .kernel_size [0 ]
11196 context_size = time_kernel_size - 1
11297 time_casual_padding = (0 , 0 , 0 , 0 , context_size , 0 )
11398 hidden_states = F .pad (hidden_states , time_casual_padding , mode = self .padding_mode )
114-
99+
115100 # Memory-efficient chunked operation
116101 memory_count = torch .prod (torch .tensor (hidden_states .shape )).item () * 2 / 1024 ** 3
117102 # YiYI Notes: testing only!! please remove
118103 memory_count = 3
104+ # YiYI Notes: this number 2 should be a config: max_memory_chunk_size (2 is 2GB)
119105 if memory_count > 2 :
120106 part_num = int (memory_count / 2 ) + 1
121107 num_frames = hidden_states .shape [2 ]
@@ -131,7 +117,7 @@ def forward(self, hidden_states : torch.Tensor) -> torch.Tensor:
131117 output_chunks .append (output_chunk ) # Append each output chunk to the list
132118
133119 # Concatenate all output chunks along the temporal dimension
134- hidden_states = torch .cat (output_chunks , dim = 2 )
120+ hidden_states = torch .cat (output_chunks , dim = 2 )
135121
136122 return hidden_states
137123 else :
@@ -140,9 +126,14 @@ def forward(self, hidden_states : torch.Tensor) -> torch.Tensor:
140126
141127class MochiChunkedGroupNorm3D (nn .Module ):
142128 r"""
143- Group normalization applied per-frame.
129+ Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group
130+ normalization.
144131
145132 Args:
133+ num_channels (int): Number of channels expected in input
134+ num_groups (int, optional): Number of groups to separate the channels into. Default: 32
135+ affine (bool, optional): If True, this module has learnable affine parameters. Default: True
136+ chunk_size (int, optional): Size of each chunk for processing. Default: 8
146137
147138 """
148139
@@ -157,49 +148,27 @@ def __init__(
157148 self .norm_layer = nn .GroupNorm (num_channels = num_channels , num_groups = num_groups , affine = affine )
158149 self .chunk_size = chunk_size
159150
def forward(self, x: torch.Tensor = None) -> torch.Tensor:
    r"""Apply `self.norm_layer` to a 5D input one frame at a time, in chunks.

    The input is unpacked as `(batch_size, channels, num_frames, height, width)`. The batch and frame axes are
    folded together so every frame is normalized as an independent 2D sample, and the folded tensor is processed
    `self.chunk_size` samples at a time.

    Args:
        x (`torch.Tensor`):
            5D input tensor. The signature default is `None`, but the body reads `x.shape` unconditionally, so a
            tensor must always be passed.

    Returns:
        `torch.Tensor`: Normalized tensor with the same 5D layout as the input.
    """
    batch_size, channels, num_frames, height, width = x.shape

    # Fold frames into the batch axis: (B, C, T, H, W) -> (B * T, C, H, W).
    frames = x.transpose(1, 2).reshape(batch_size * num_frames, channels, height, width)

    # Normalize `chunk_size` folded samples at a time.
    normalized_chunks = []
    for frame_chunk in torch.split(frames, self.chunk_size, dim=0):
        normalized_chunks.append(self.norm_layer(frame_chunk))
    normalized = torch.cat(normalized_chunks, dim=0)

    # Unfold back to (B, T, C, H, W), then restore the channel-first layout.
    normalized = normalized.view(batch_size, num_frames, channels, height, width)
    return normalized.transpose(1, 2)
176159
177160
178161class MochiResnetBlock3D (nn .Module ):
179162 r"""
180- A 3D ResNet block used in the CogVideoX model.
163+ A 3D ResNet block used in the Mochi model.
181164
182165 Args:
183166 in_channels (`int`):
184167 Number of input channels.
185168 out_channels (`int`, *optional*):
186169 Number of output channels. If None, defaults to `in_channels`.
187- dropout (`float`, defaults to `0.0`):
188- Dropout rate.
189- temb_channels (`int`, defaults to `512`):
190- Number of time embedding channels.
191- groups (`int`, defaults to `32`):
192- Number of groups to separate the channels into for group normalization.
193- eps (`float`, defaults to `1e-6`):
194- Epsilon value for normalization layers.
195170 non_linearity (`str`, defaults to `"swish"`):
196171 Activation function to use.
197- conv_shortcut (bool, defaults to `False`):
198- Whether or not to use a convolution shortcut.
199- spatial_norm_dim (`int`, *optional*):
200- The dimension to use for spatial norm if it is to be used instead of group norm.
201- pad_mode (str, defaults to `"first"`):
202- Padding mode.
203172 """
204173
205174 def __init__ (
@@ -225,14 +194,12 @@ def __init__(
225194 in_channels = out_channels , out_channels = out_channels , kernel_size = 3 , stride = 1
226195 )
227196
228-
229197 def forward (
230198 self ,
231199 inputs : torch .Tensor ,
232200 ) -> torch .Tensor :
233-
234201 hidden_states = inputs
235-
202+
236203 hidden_states = self .norm1 (hidden_states )
237204 hidden_states = self .nonlinearity (hidden_states )
238205 hidden_states = self .conv1 (hidden_states )
@@ -254,6 +221,12 @@ class MochiUpBlock3D(nn.Module):
254221 Number of input channels.
255222 out_channels (`int`, *optional*):
256223 Number of output channels. If None, defaults to `in_channels`.
224+ num_layers (`int`, defaults to `1`):
225+ Number of resnet blocks in the block.
226+ temporal_expansion (`int`, defaults to `2`):
227+ Temporal expansion factor.
228+ spatial_expansion (`int`, defaults to `2`):
229+ Spatial expansion factor.
257230 """
258231
259232 def __init__ (
@@ -290,8 +263,7 @@ def forward(
290263 ) -> torch .Tensor :
291264 r"""Forward method of the `MochiUpBlock3D` class."""
292265
293- for i , resnet in enumerate (self .resnets ):
294-
266+ for resnet in self .resnets :
295267 if self .training and self .gradient_checkpointing :
296268
297269 def create_custom_forward (module ):
@@ -322,10 +294,8 @@ def create_forward(*inputs):
322294 hidden_states = hidden_states .contiguous ().view (B , new_C , T * st , H * sh , W * sw )
323295
324296 if self .temporal_expansion > 1 :
325- print (f"x: { hidden_states .shape } " )
326297 # Drop the first self.temporal_expansion - 1 frames.
327298 hidden_states = hidden_states [:, :, self .temporal_expansion - 1 :]
328- print (f"x: { hidden_states .shape } " )
329299
330300 return hidden_states
331301
@@ -337,22 +307,20 @@ class MochiMidBlock3D(nn.Module):
337307 Args:
338308 in_channels (`int`):
339309 Number of input channels.
310+ num_layers (`int`, defaults to `3`):
311+ Number of resnet blocks in the block.
340312 """
341313
342- _supports_gradient_checkpointing = True
343-
def __init__(
    self,
    in_channels: int,  # 768 (value noted by the original author)
    num_layers: int = 3,
):
    r"""Build the mid block as a stack of `MochiResnetBlock3D` layers.

    Args:
        in_channels (`int`):
            Channel count passed to every resnet block as `in_channels`.
        num_layers (`int`, defaults to `3`):
            Number of resnet blocks to create.
    """
    super().__init__()

    # Each block is constructed with `in_channels` only, so all blocks share the same constructor arguments.
    self.resnets = nn.ModuleList([MochiResnetBlock3D(in_channels=in_channels) for _ in range(num_layers)])

    # Checkpointing is off by default; `forward` reads this flag.
    self.gradient_checkpointing = False
@@ -363,7 +331,7 @@ def forward(
363331 ) -> torch .Tensor :
364332 r"""Forward method of the `MochiMidBlock3D` class."""
365333
366- for i , resnet in enumerate ( self .resnets ) :
334+ for resnet in self .resnets :
367335 if self .training and self .gradient_checkpointing :
368336
369337 def create_custom_forward (module ):
@@ -372,22 +340,39 @@ def create_forward(*inputs):
372340
373341 return create_forward
374342
375- hidden_states = torch .utils .checkpoint .checkpoint (
376- create_custom_forward (resnet ), hidden_states
377- )
343+ hidden_states = torch .utils .checkpoint .checkpoint (create_custom_forward (resnet ), hidden_states )
378344 else :
379345 hidden_states = resnet (hidden_states )
380346
381347 return hidden_states
382348
383349
384350class MochiDecoder3D (nn .Module ):
385- _supports_gradient_checkpointing = True
351+ r"""
352+ The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
353+ sample.
354+
355+ Args:
356+ in_channels (`int`, *optional*):
357+ The number of input channels.
358+ out_channels (`int`, *optional*):
359+ The number of output channels.
360+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
361+ The number of output channels for each block.
362+ layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
363+ The number of resnet blocks for each block.
364+ temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
365+ The temporal expansion factor for each of the up blocks.
366+ spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
367+ The spatial expansion factor for each of the up blocks.
368+ non_linearity (`str`, *optional*, defaults to `"swish"`):
369+ The non-linearity to use in the decoder.
370+ """
386371
387372 def __init__ (
388373 self ,
389- in_channels : int , # 12
390- out_channels : int , # 3
374+ in_channels : int , # 12
375+ out_channels : int , # 3
391376 block_out_channels : Tuple [int , ...] = (128 , 256 , 512 , 768 ),
392377 layers_per_block : Tuple [int , ...] = (3 , 3 , 4 , 6 , 3 ),
393378 temporal_expansions : Tuple [int , ...] = (1 , 2 , 3 ),
@@ -418,29 +403,36 @@ def __init__(
418403 num_layers = layers_per_block [0 ],
419404 )
420405 self .conv_out = Conv1x1 (block_out_channels [0 ], out_channels )
421-
406+
407+ self .gradient_checkpointing = False
408+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    r"""Forward method of the `MochiDecoder3D` class.

    Runs, in order: `conv_in`, `block_in`, every module in `up_blocks`, `block_out`, `nonlinearity`, `conv_out`.
    When the module is in training mode and `gradient_checkpointing` is set, `block_in` and each up block are
    executed through `torch.utils.checkpoint.checkpoint`; the remaining layers are always called directly.

    Args:
        hidden_states (`torch.Tensor`): Input passed to `conv_in`.

    Returns:
        `torch.Tensor`: Output of `conv_out`.
    """
    hidden_states = self.conv_in(hidden_states)

    use_checkpointing = self.training and self.gradient_checkpointing

    def run_stage(module, tensor):
        # Call the stage directly unless checkpointing is active.
        if not use_checkpointing:
            return module(tensor)

        def checkpointed_call(*inputs):
            return module(*inputs)

        return torch.utils.checkpoint.checkpoint(checkpointed_call, tensor)

    # Mid block first, then the up blocks in their stored order.
    for stage in (self.block_in, *self.up_blocks):
        hidden_states = run_stage(stage, hidden_states)

    # Post-processing layers are never checkpointed.
    hidden_states = self.block_out(hidden_states)
    hidden_states = self.nonlinearity(hidden_states)
    return self.conv_out(hidden_states)
446-
0 commit comments