Commit 85a9825

committed
init
1 parent 76c00c7 commit 85a9825

1 file changed: +160 -0 lines changed

@@ -0,0 +1,160 @@
# Copyright 2024 The Mochi team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..downsampling import CogVideoXDownsample3D
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from ..upsampling import CogVideoXUpsample3D
from .vae import DecoderOutput, DiagonalGaussianDistribution


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class MochiCausalConv3d(nn.Module):
    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in the Mochi model.

    Args:
        in_channels (`int`): Number of channels in the input tensor.
        out_channels (`int`): Number of output channels produced by the convolution.
        kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
        stride (`int` or `Tuple[int, int, int]`): Stride of the convolution.
        padding_mode (`str`, defaults to `"replicate"`): Padding mode.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        padding_mode: str = "replicate",
    ):
        super().__init__()

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3
        if isinstance(stride, int):
            stride = (stride,) * 3

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        self.padding_mode = padding_mode
        height_pad = (height_kernel_size - 1) // 2
        width_pad = (width_kernel_size - 1) // 2

        # Spatial padding is handled by the convolution itself; temporal padding is applied
        # causally in `forward` so that no future frames leak into the current output frame.
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, height_pad, width_pad),
            padding_mode=padding_mode,
        )
        self.time_kernel_size = time_kernel_size

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        context_size = self.time_kernel_size - 1
        # F.pad takes pairs from the last dimension backwards: width (0, 0), height (0, 0),
        # time (context_size, 0), i.e. only past frames are prepended (causal padding).
        time_causal_padding = (0, 0, 0, 0, context_size, 0)

        inputs = F.pad(inputs, time_causal_padding, mode=self.padding_mode)

        # Memory-efficient chunked operation: when the estimated input size (2 bytes per element)
        # exceeds roughly 2 GiB, convolve the temporal axis chunk-by-chunk.
        memory_count = torch.prod(torch.tensor(inputs.shape)).item() * 2 / 1024**3
        if memory_count > 2:
            part_num = int(memory_count / 2) + 1
            k = self.time_kernel_size
            input_idx = torch.arange(context_size, inputs.size(2))
            input_chunks_idx = torch.split(input_idx, input_idx.size(0) // part_num)

            # Compute output size (assumes a temporal stride of 1)
            B, _, T_in, H_in, W_in = inputs.shape
            output_size = (
                B,
                self.conv.out_channels,
                T_in - k + 1,
                H_in // self.conv.stride[1],
                W_in // self.conv.stride[2],
            )
            output = torch.empty(output_size, dtype=inputs.dtype, device=inputs.device)
            for input_chunk_idx in input_chunks_idx:
                # Extend each chunk backwards by (k - 1) frames of context so the convolution
                # produces exactly one output frame per index in the chunk.
                input_s = input_chunk_idx[0] - k + 1
                input_e = input_chunk_idx[-1] + 1
                input_chunk = inputs[:, :, input_s:input_e, :, :]
                output_chunk = self.conv(input_chunk)

                output_s = input_s
                output_e = output_s + output_chunk.size(2)
                output[:, :, output_s:output_e, :, :] = output_chunk

            return output
        else:
            return self.conv(inputs)
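
# Illustrative usage sketch (not part of the original commit; shapes are hypothetical).
# Because (time_kernel_size - 1) frames of replicate padding are prepended before the
# convolution, the layer is causal and, with stride 1, preserves the frame count. For
# inputs estimated above ~2 GiB, the same convolution is applied chunk-by-chunk along
# the temporal axis and produces an identical result.
#
#     conv = MochiCausalConv3d(in_channels=3, out_channels=8, kernel_size=3, stride=1)
#     video = torch.randn(1, 3, 16, 64, 64)  # (batch, channels, frames, height, width)
#     out = conv(video)                       # -> (1, 8, 16, 64, 64)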


class MochiGroupNorm3D(nn.Module):
    r"""
    Group normalization applied per-frame, processing frames in chunks to limit peak memory.

    Args:
        num_channels (`int`): Number of channels expected in the input.
        num_groups (`int`, defaults to `32`): Number of groups to separate the channels into.
        chunk_size (`int`, defaults to `8`): Number of frames normalized per chunk.
    """

    def __init__(
        self,
        num_channels: int,
        num_groups: int = 32,
        chunk_size: int = 8,
    ):
        super().__init__()
        # nn.GroupNorm requires both num_groups and num_channels, so they are exposed as arguments here.
        self.norm_layer = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)
        self.chunk_size = chunk_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, num_frames, height, width = x.shape
        # Fold the temporal axis into the batch axis so normalization statistics are per-frame.
        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)

        # Normalize `chunk_size` frames at a time to keep memory usage bounded.
        output = torch.cat(
            [self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)],
            dim=0,
        )
        output = output.view(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)

        return output
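
# Illustrative usage sketch (not part of the original commit; values are hypothetical and
# assume the constructor arguments shown above). Statistics are computed per frame because
# the temporal axis is folded into the batch axis, and frames are normalized `chunk_size`
# at a time, so the result matches applying nn.GroupNorm frame by frame.
#
#     norm = MochiGroupNorm3D(num_channels=8, num_groups=4, chunk_size=4)
#     video = torch.randn(2, 8, 6, 32, 32)  # (batch, channels, frames, height, width)
#     out = norm(video)                      # -> (2, 8, 6, 32, 32), same shape as the input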
