import torch.nn.functional as F


-class MochiCausalConv3d(nn.Module):
+# YiYi to-do: replace this with nn.Conv3d
+class Conv1x1(nn.Linear):
+    """A 1x1 Conv implemented with a linear layer."""
+
+    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
+        super().__init__(in_features, out_features, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor):
+        """Forward pass.
+
+        Args:
+            x: Input tensor. Shape: [B, C, *] or [B, *, C].
+
+        Returns:
+            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
+        """
+        x = x.movedim(1, -1)
+        x = super().forward(x)
+        x = x.movedim(-1, 1)
+        return x
+
+
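# Usage sketch (illustrative only; the channel counts and tensor sizes below are assumed):
# a pointwise (1x1x1) convolution is just a linear map applied independently at every
# spatio-temporal position, so Conv1x1 should match nn.Conv3d with kernel_size=1 once the
# weights are shared between the two layers.
import torch
from torch import nn

proj = Conv1x1(8, 16)
conv = nn.Conv3d(8, 16, kernel_size=1)
conv.weight.data = proj.weight.data.view(16, 8, 1, 1, 1)
conv.bias.data = proj.bias.data

x = torch.randn(2, 8, 5, 4, 4)  # [B, C, T, H, W]
torch.testing.assert_close(proj(x), conv(x))
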
+class MochiChunkedCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in the Mochi model.

    Args:
@@ -81,50 +103,42 @@ def __init__(
            padding=(0, height_pad, width_pad),
            padding_mode=padding_mode,
        )
-        self.time_kernel_size = time_kernel_size



-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        context_size = self.time_kernel_size - 1
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        time_kernel_size = self.conv.kernel_size[0]
+        context_size = time_kernel_size - 1
        time_causal_padding = (0, 0, 0, 0, context_size, 0)
-
-        inputs = F.pad(inputs, time_causal_padding, mode=self.padding_mode)
+        hidden_states = F.pad(hidden_states, time_causal_padding, mode=self.padding_mode)

        # Memory-efficient chunked operation
-        memory_count = torch.prod(torch.tensor(inputs.shape)).item() * 2 / 1024**3
+        memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3
        if memory_count > 2:
            part_num = int(memory_count / 2) + 1
-            k = self.time_kernel_size
-            input_idx = torch.arange(context_size, inputs.size(2))
-            input_chunks_idx = torch.split(input_idx, input_idx.size(0) // part_num)
-
-            # Compute output size
-            B, _, T_in, H_in, W_in = inputs.shape
-            output_size = (
-                B,
-                self.conv.out_channels,
-                T_in - k + 1,
-                H_in // self.conv.stride[1],
-                W_in // self.conv.stride[2],
-            )
-            output = torch.empty(output_size, dtype=inputs.dtype, device=inputs.device)
-            for input_chunk_idx in input_chunks_idx:
-                input_s = input_chunk_idx[0] - k + 1
-                input_e = input_chunk_idx[-1] + 1
-                input_chunk = inputs[:, :, input_s:input_e, :, :]
-                output_chunk = self.conv(input_chunk)
-
-                output_s = input_s
-                output_e = output_s + output_chunk.size(2)
-                output[:, :, output_s:output_e, :, :] = output_chunk
-
-            return output
+            num_frames = hidden_states.shape[2]
+            frames_idx = torch.arange(context_size, num_frames)
+            frames_chunks_idx = torch.chunk(frames_idx, part_num, dim=0)
+
+            output_chunks = []
+            for frames_chunk_idx in frames_chunks_idx:
+                frames_s = frames_chunk_idx[0] - context_size
+                frames_e = frames_chunk_idx[-1] + 1
+                frames_chunk = hidden_states[:, :, frames_s:frames_e, :, :]
+                output_chunk = self.conv(frames_chunk)
+                output_chunks.append(output_chunk)  # Append each output chunk to the list
+
+            # Concatenate all output chunks along the temporal dimension
+            hidden_states = torch.cat(output_chunks, dim=2)
+
+            return hidden_states
        else:
-            return self.conv(inputs)
+            return self.conv(hidden_states)


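# Causality sketch (illustrative; channel count, frame count, and spatial size are assumed):
# only the front of the time axis is padded before the convolution, so an output frame t
# should depend only on input frames <= t. Perturbing the last input frame therefore leaves
# all earlier output frames unchanged.
import torch

conv = MochiChunkedCausalConv3d(in_channels=4, out_channels=4, kernel_size=3, stride=1)
x = torch.randn(1, 4, 8, 16, 16)  # [B, C, T, H, W]
x_perturbed = x.clone()
x_perturbed[:, :, -1] += 1.0  # change only the final frame

y = conv(x)
y_perturbed = conv(x_perturbed)
torch.testing.assert_close(y[:, :, :-1], y_perturbed[:, :, :-1])
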
-class MochiGroupNorm3D(nn.Module):
+class MochiChunkedGroupNorm3D(nn.Module):
    r"""
    Group normalization applied per-frame.

@@ -134,10 +148,13 @@ class MochiGroupNorm3D(nn.Module):

    def __init__(
        self,
+        num_channels: int,
+        num_groups: int = 32,
+        affine: bool = True,
        chunk_size: int = 8,
    ):
        super().__init__()
-        self.norm_layer = nn.GroupNorm()
+        self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
        self.chunk_size = chunk_size

    def forward(
@@ -158,3 +175,272 @@ def forward(
        return output


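# Usage sketch (illustrative; all sizes are assumed): per the class docstring, normalization
# is applied per-frame, and the chunking is assumed here only to bound peak memory, so the
# output shape matches the input shape.
import torch

norm = MochiChunkedGroupNorm3D(num_channels=64, num_groups=32, chunk_size=8)
x = torch.randn(1, 64, 28, 32, 32)  # [B, C, T, H, W]
assert norm(x).shape == x.shape
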
+class MochiResnetBlock3D(nn.Module):
+    r"""
+    A 3D ResNet block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        non_linearity (`str`, defaults to `"swish"`):
+            Activation function to use.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        non_linearity: str = "swish",
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
+        self.conv1 = MochiChunkedCausalConv3d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1
+        )
+        self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels)
+        self.conv2 = MochiChunkedCausalConv3d(
+            in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1
+        )
+
+    def forward(
+        self,
+        inputs: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = inputs
+
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        hidden_states = hidden_states + inputs
+        return hidden_states
+
+
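# Shape-check sketch (illustrative; the sizes are assumed): the block is a pre-norm residual
# unit (norm -> activation -> causal conv, applied twice) with an identity skip connection,
# so the output shape matches the input shape when out_channels keeps its default.
import torch

block = MochiResnetBlock3D(in_channels=128)
x = torch.randn(1, 128, 8, 16, 16)  # [B, C, T, H, W]
assert block(x).shape == x.shape
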
+class MochiUpBlock3D(nn.Module):
+    r"""
+    An upsampling block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`):
+            Number of output channels.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet blocks in the block.
+        temporal_expansion (`int`, defaults to `2`):
+            Temporal expansion factor.
+        spatial_expansion (`int`, defaults to `2`):
+            Spatial expansion factor.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        temporal_expansion: int = 2,
+        spatial_expansion: int = 2,
+    ):
+        super().__init__()
+        self.temporal_expansion = temporal_expansion
+        self.spatial_expansion = spatial_expansion
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(
+                MochiResnetBlock3D(
+                    in_channels=in_channels,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+
+        self.proj = Conv1x1(
+            in_channels,
+            out_channels * temporal_expansion * (spatial_expansion**2),
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        r"""Forward method of the `MochiUpBlock3D` class."""
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                )
+            else:
+                hidden_states = resnet(hidden_states)
+
+        hidden_states = self.proj(hidden_states)
+
+        # Calculate the new shape after the channel-to-space/time expansion
+        B, C, T, H, W = hidden_states.shape
+        st = self.temporal_expansion
+        sh = self.spatial_expansion
+        sw = self.spatial_expansion
+        new_C = C // (st * sh * sw)
+
+        # Reshape and permute: trade channels for temporal and spatial resolution
+        hidden_states = hidden_states.view(B, new_C, st, sh, sw, T, H, W)
+        hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)
+        hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw)
+
+        if self.temporal_expansion > 1:
+            # Drop the first (temporal_expansion - 1) frames.
+            hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :]
+
+        return hidden_states
+
+
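# Shape sketch (illustrative; all sizes are assumed): the 1x1 projection multiplies the
# channel count by temporal_expansion * spatial_expansion**2, and the reshape/permute trades
# those channels for time and space (a depth-to-space step extended to the temporal axis),
# after which the leading (temporal_expansion - 1) frames are dropped.
import torch

up = MochiUpBlock3D(in_channels=512, out_channels=256, temporal_expansion=2, spatial_expansion=2)
x = torch.randn(1, 512, 4, 8, 8)  # [B, C, T, H, W]
y = up(x)
assert y.shape == (1, 256, 4 * 2 - 1, 8 * 2, 8 * 2)  # time 4*2=8, minus 1 dropped frame -> 7
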
+class MochiMidBlock3D(nn.Module):
+    r"""
+    A middle block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        num_layers (`int`, defaults to `3`):
+            Number of resnet blocks in the block.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,  # 768
+        num_layers: int = 3,
+    ):
+        super().__init__()
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(MochiResnetBlock3D(in_channels=in_channels))
+        self.resnets = nn.ModuleList(resnets)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        r"""Forward method of the `MochiMidBlock3D` class."""
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states)
+
+        return hidden_states
+
+
+class MochiDecoder3D(nn.Module):
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,  # 12
+        out_channels: int,  # 3
+        block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
+        layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
+        temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+        spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+        non_linearity: str = "swish",
+    ):
+        super().__init__()
+
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
+        self.block_in = MochiMidBlock3D(
+            in_channels=block_out_channels[-1],
+            num_layers=layers_per_block[-1],
+        )
+        self.up_blocks = nn.ModuleList([])
+        for i in range(len(block_out_channels) - 1):
+            up_block = MochiUpBlock3D(
+                in_channels=block_out_channels[-i - 1],
+                out_channels=block_out_channels[-i - 2],
+                num_layers=layers_per_block[-i - 2],
+                temporal_expansion=temporal_expansions[-i - 1],
+                spatial_expansion=spatial_expansions[-i - 1],
+            )
+            self.up_blocks.append(up_block)
+        self.block_out = MochiMidBlock3D(
+            in_channels=block_out_channels[0],
+            num_layers=layers_per_block[0],
+        )
+        self.conv_out = Conv1x1(block_out_channels[0], out_channels)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""Forward method of the `MochiDecoder3D` class."""
+
+        hidden_states = self.conv_in(hidden_states)
+
+        # 1. Mid
+        hidden_states = self.block_in(hidden_states)
+
+        # 2. Up
+        for up_block in self.up_blocks:
+            hidden_states = up_block(hidden_states)
+
+        # 3. Post-process
+        hidden_states = self.block_out(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
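# End-to-end sketch (illustrative; the latent shape is assumed): the decoder maps a
# 12-channel latent video to RGB frames, expanding the spatial size by
# prod(spatial_expansions) = 8 per side and the temporal size by the temporal expansions
# of the up blocks, minus the leading frames each expanding block drops.
import torch

decoder = MochiDecoder3D(in_channels=12, out_channels=3)
z = torch.randn(1, 12, 3, 8, 8)  # [B, C, T, H, W] latent
with torch.no_grad():
    frames = decoder(z)
print(frames.shape)  # with the default config this should be torch.Size([1, 3, 13, 64, 64])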