
Commit 0a6189e

add
1 parent 85a9825 commit 0a6189e

File tree: 1 file changed (+321 −35 lines changed)


src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 321 additions & 35 deletions
@@ -40,7 +40,29 @@
 import torch.nn.functional as F
 
 
-class MochiCausalConv3d(nn.Module):
+# YiYi to-do: replace this with nn.Conv3d
+class Conv1x1(nn.Linear):
+    """1x1 Conv implemented with a linear layer."""
+
+    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
+        super().__init__(in_features, out_features, *args, **kwargs)
+
+    def forward(self, x: torch.Tensor):
+        """Forward pass.
+
+        Args:
+            x: Input tensor. Shape: [B, C, *] or [B, *, C].
+
+        Returns:
+            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
+        """
+        x = x.movedim(1, -1)
+        x = super().forward(x)
+        x = x.movedim(-1, 1)
+        return x
+
+
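Note: a quick equivalence check for `Conv1x1` (illustrative only, not part of the commit). A linear layer applied over a moved channel axis computes the same function as a kernel-size-1 `nn.Conv3d` once the weights are reshaped:

import torch
import torch.nn as nn

conv1x1 = Conv1x1(64, 128)
conv3d = nn.Conv3d(64, 128, kernel_size=1)

# Copy weights so both modules compute the same function
with torch.no_grad():
    conv3d.weight.copy_(conv1x1.weight.view(128, 64, 1, 1, 1))
    conv3d.bias.copy_(conv1x1.bias)

x = torch.randn(2, 64, 4, 8, 8)  # [B, C, T, H, W]
assert torch.allclose(conv1x1(x), conv3d(x), atol=1e-5)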
+class MochiChunkedCausalConv3d(nn.Module):
     r"""A 3D causal convolution layer that pads the input tensor to ensure causality in the Mochi model.
 
     Args:
@@ -81,50 +103,42 @@ def __init__(
             padding=(0, height_pad, width_pad),
             padding_mode=padding_mode,
         )
-        self.time_kernel_size = time_kernel_size
 
 
 
-    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        context_size = self.time_kernel_size - 1
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        time_kernel_size = self.conv.kernel_size[0]
+        context_size = time_kernel_size - 1
         time_causal_padding = (0, 0, 0, 0, context_size, 0)
-
-        inputs = F.pad(inputs, time_causal_padding, mode=self.padding_mode)
+        hidden_states = F.pad(hidden_states, time_causal_padding, mode=self.padding_mode)
 
         # Memory-efficient chunked operation
-        memory_count = torch.prod(torch.tensor(inputs.shape)).item() * 2 / 1024**3
+        memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3
+        # YiYi Notes: testing only!! please remove
+        memory_count = 3
         if memory_count > 2:
             part_num = int(memory_count / 2) + 1
-            k = self.time_kernel_size
-            input_idx = torch.arange(context_size, inputs.size(2))
-            input_chunks_idx = torch.split(input_idx, input_idx.size(0) // part_num)
-
-            # Compute output size
-            B, _, T_in, H_in, W_in = inputs.shape
-            output_size = (
-                B,
-                self.conv.out_channels,
-                T_in - k + 1,
-                H_in // self.conv.stride[1],
-                W_in // self.conv.stride[2],
-            )
-            output = torch.empty(output_size, dtype=inputs.dtype, device=inputs.device)
-            for input_chunk_idx in input_chunks_idx:
-                input_s = input_chunk_idx[0] - k + 1
-                input_e = input_chunk_idx[-1] + 1
-                input_chunk = inputs[:, :, input_s:input_e, :, :]
-                output_chunk = self.conv(input_chunk)
-
-                output_s = input_s
-                output_e = output_s + output_chunk.size(2)
-                output[:, :, output_s:output_e, :, :] = output_chunk
-
-            return output
+            num_frames = hidden_states.shape[2]
+            frames_idx = torch.arange(context_size, num_frames)
+            frames_chunks_idx = torch.chunk(frames_idx, part_num, dim=0)
+
+            output_chunks = []
+            for frames_chunk_idx in frames_chunks_idx:
+                # Re-read `context_size` frames of left context so each chunk stays causal
+                frames_s = frames_chunk_idx[0] - context_size
+                frames_e = frames_chunk_idx[-1] + 1
+                frames_chunk = hidden_states[:, :, frames_s:frames_e, :, :]
+                output_chunk = self.conv(frames_chunk)
+                output_chunks.append(output_chunk)
+
+            # Concatenate all output chunks along the temporal dimension
+            hidden_states = torch.cat(output_chunks, dim=2)
+
+            return hidden_states
         else:
-            return self.conv(inputs)
+            return self.conv(hidden_states)
 
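Note: because every chunk re-reads `context_size` frames of left context before convolving, the chunked branch should be numerically identical to a single pass over the padded tensor. A minimal sanity sketch (mine, not part of the commit), assuming the constructor arguments used by the resnet blocks later in this diff and a `padding_mode` that `F.pad` accepts (e.g. "replicate"):

import torch
import torch.nn.functional as F

# Hypothetical construction; the debug override above forces the chunked path.
conv = MochiChunkedCausalConv3d(in_channels=4, out_channels=4, kernel_size=3, stride=1)

x = torch.randn(1, 4, 11, 8, 8)  # [B, C, T, H, W]
out_chunked = conv(x)

# Reference: causal-pad once, convolve once.
context_size = conv.conv.kernel_size[0] - 1
x_padded = F.pad(x, (0, 0, 0, 0, context_size, 0), mode=conv.padding_mode)
out_single = conv.conv(x_padded)

assert torch.allclose(out_chunked, out_single, atol=1e-5)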

-class MochiGroupNorm3D(nn.Module):
+class MochiChunkedGroupNorm3D(nn.Module):
     r"""
     Group normalization applied per-frame.
 
@@ -134,10 +148,13 @@ class MochiGroupNorm3D(nn.Module):
 
     def __init__(
         self,
+        num_channels: int,
+        num_groups: int = 32,
+        affine: bool = True,
         chunk_size: int = 8,
     ):
         super().__init__()
-        self.norm_layer = nn.GroupNorm()
+        self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
         self.chunk_size = chunk_size
 
     def forward(
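Note: the `forward` body falls outside this hunk. Conceptually, per-frame group norm folds the temporal axis into the batch so statistics never mix frames, normalizing `chunk_size` frames at a time to bound peak memory. A rough sketch of that idea (an assumption about the truncated code, not a copy of it):

import torch
import torch.nn as nn

def chunked_per_frame_group_norm(norm_layer, hidden_states, chunk_size=8):
    # hidden_states: [B, C, T, H, W] -> [B*T, C, H, W] so GroupNorm sees each frame alone
    B, C, T, H, W = hidden_states.shape
    frames = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, C, H, W)
    # Normalize a bounded number of frames per call to limit peak memory
    normed = torch.cat([norm_layer(chunk) for chunk in frames.split(chunk_size, dim=0)], dim=0)
    return normed.reshape(B, T, C, H, W).permute(0, 2, 1, 3, 4)

norm = nn.GroupNorm(num_groups=32, num_channels=64)
x = torch.randn(2, 64, 16, 8, 8)
y = chunked_per_frame_group_norm(norm, x)  # same shape, per-frame statistics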
@@ -158,3 +175,272 @@ def forward(
         return output
 
 
+class MochiResnetBlock3D(nn.Module):
+    r"""
+    A 3D ResNet block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`, *optional*):
+            Number of output channels. If None, defaults to `in_channels`.
+        non_linearity (`str`, defaults to `"swish"`):
+            Activation function to use.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        non_linearity: str = "swish",
+    ):
+        super().__init__()
+
+        out_channels = out_channels or in_channels
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
+        self.conv1 = MochiChunkedCausalConv3d(
+            in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1
+        )
+        self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels)
+        self.conv2 = MochiChunkedCausalConv3d(
+            in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1
+        )
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        hidden_states = inputs
+
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        # Residual connection; assumes out_channels == in_channels
+        hidden_states = hidden_states + inputs
+        return hidden_states
+
+
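Note: a small smoke test (illustrative only; it assumes the imports at the top of this file, e.g. `get_activation`). The block preserves shape because the residual add requires `out_channels == in_channels` when `out_channels` is left unset:

import torch

block = MochiResnetBlock3D(in_channels=64)
x = torch.randn(1, 64, 8, 16, 16)  # [B, C, T, H, W]
assert block(x).shape == x.shape  # norm/act/conv twice, then residual add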
+class MochiUpBlock3D(nn.Module):
+    r"""
+    An upsampling block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        out_channels (`int`):
+            Number of output channels.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet blocks in the block.
+        temporal_expansion (`int`, defaults to `2`):
+            Temporal expansion factor.
+        spatial_expansion (`int`, defaults to `2`):
+            Spatial expansion factor.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        temporal_expansion: int = 2,
+        spatial_expansion: int = 2,
+    ):
+        super().__init__()
+        self.temporal_expansion = temporal_expansion
+        self.spatial_expansion = spatial_expansion
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(MochiResnetBlock3D(in_channels=in_channels))
+        self.resnets = nn.ModuleList(resnets)
+
+        self.proj = Conv1x1(
+            in_channels,
+            out_channels * temporal_expansion * (spatial_expansion**2),
+        )
+
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""Forward method of the `MochiUpBlock3D` class."""
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                )
+            else:
+                hidden_states = resnet(hidden_states)
+
+        hidden_states = self.proj(hidden_states)
+
+        # Calculate new shape
+        B, C, T, H, W = hidden_states.shape
+        st = self.temporal_expansion
+        sh = self.spatial_expansion
+        sw = self.spatial_expansion
+        new_C = C // (st * sh * sw)
+
+        # Reshape and permute: redistribute channel groups onto (T, H, W)
+        hidden_states = hidden_states.view(B, new_C, st, sh, sw, T, H, W)
+        hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)
+        hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw)
+
+        if self.temporal_expansion > 1:
+            # Drop the first self.temporal_expansion - 1 frames.
+            hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :]
+
+        return hidden_states
+
+
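Note: the `view -> permute -> view` sequence above is a 3D depth-to-space (pixel-shuffle) step: each group of `st * sh * sw` channels becomes an `st x sh x sw` block of output voxels. A standalone numeric check of just that reshape (mine, not from the commit), with st = sh = sw = 2 so 8 channels collapse into a 2x2x2 block:

import torch

B, new_C, st, sh, sw, T, H, W = 1, 1, 2, 2, 2, 1, 1, 1
x = torch.arange(8.0).view(B, new_C * st * sh * sw, T, H, W)

y = x.view(B, new_C, st, sh, sw, T, H, W)
y = y.permute(0, 1, 5, 2, 6, 3, 7, 4)
y = y.contiguous().view(B, new_C, T * st, H * sh, W * sw)

print(y.shape)  # torch.Size([1, 1, 2, 2, 2])
print(y[0, 0])  # channels 0..3 fill output frame t=0, channels 4..7 fill t=1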
+class MochiMidBlock3D(nn.Module):
+    r"""
+    A middle block used in the Mochi model.
+
+    Args:
+        in_channels (`int`):
+            Number of input channels.
+        num_layers (`int`, defaults to `3`):
+            Number of resnet blocks in the block.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,  # 768
+        num_layers: int = 3,
+    ):
+        super().__init__()
+
+        resnets = []
+        for _ in range(num_layers):
+            resnets.append(MochiResnetBlock3D(in_channels=in_channels))
+        self.resnets = nn.ModuleList(resnets)
+
+        self.gradient_checkpointing = False
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""Forward method of the `MochiMidBlock3D` class."""
+
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def create_forward(*inputs):
+                        return module(*inputs)
+
+                    return create_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet), hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states)
+
+        return hidden_states
+
+
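Note: both the mid and up blocks gate `torch.utils.checkpoint` behind a `gradient_checkpointing` flag that only takes effect in training mode. A hedged usage sketch:

import torch

block = MochiMidBlock3D(in_channels=768, num_layers=3)
block.gradient_checkpointing = True  # recompute resnet activations in backward
block.train()

x = torch.randn(1, 768, 4, 8, 8, requires_grad=True)
block(x).mean().backward()  # saves activation memory at the cost of extra compute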
+class MochiDecoder3D(nn.Module):
+    r"""
+    The 3D decoder used in the Mochi VAE.
+
+    Args:
+        in_channels (`int`):
+            Number of latent input channels.
+        out_channels (`int`):
+            Number of output channels.
+        block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 768)`):
+            Output channels of each block, ordered from the last up block to the mid block.
+        layers_per_block (`Tuple[int, ...]`, defaults to `(3, 3, 4, 6, 3)`):
+            Number of resnet layers in each block.
+        temporal_expansions (`Tuple[int, ...]`, defaults to `(1, 2, 3)`):
+            Temporal expansion factor used by each up block.
+        spatial_expansions (`Tuple[int, ...]`, defaults to `(2, 2, 2)`):
+            Spatial expansion factor used by each up block.
+        non_linearity (`str`, defaults to `"swish"`):
+            Activation function to use.
+    """
+
+    _supports_gradient_checkpointing = True
+
+    def __init__(
+        self,
+        in_channels: int,  # 12
+        out_channels: int,  # 3
+        block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
+        layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
+        temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+        spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+        non_linearity: str = "swish",
+    ):
+        super().__init__()
+
+        self.nonlinearity = get_activation(non_linearity)
+
+        self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
+        self.block_in = MochiMidBlock3D(
+            in_channels=block_out_channels[-1],
+            num_layers=layers_per_block[-1],
+        )
+        self.up_blocks = nn.ModuleList([])
+        for i in range(len(block_out_channels) - 1):
+            up_block = MochiUpBlock3D(
+                in_channels=block_out_channels[-i - 1],
+                out_channels=block_out_channels[-i - 2],
+                num_layers=layers_per_block[-i - 2],
+                temporal_expansion=temporal_expansions[-i - 1],
+                spatial_expansion=spatial_expansions[-i - 1],
+            )
+            self.up_blocks.append(up_block)
+        self.block_out = MochiMidBlock3D(
+            in_channels=block_out_channels[0],
+            num_layers=layers_per_block[0],
+        )
+        self.conv_out = Conv1x1(block_out_channels[0], out_channels)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        r"""Forward method of the `MochiDecoder3D` class."""
+
+        hidden_states = self.conv_in(hidden_states)
+
+        # 1. Mid
+        hidden_states = self.block_in(hidden_states)
+
+        # 2. Up
+        for up_block in self.up_blocks:
+            hidden_states = up_block(hidden_states)
+
+        # 3. Post-process
+        hidden_states = self.block_out(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv_out(hidden_states)
+
+        return hidden_states
+
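Note: an end-to-end sketch (mine, not from the commit). With the default expansions, each up block multiplies the temporal axis by its `temporal_expansion` (then drops `temporal_expansion - 1` leading frames) and the spatial axes by 2, for an overall spatial factor of 2^3 = 8:

import torch

decoder = MochiDecoder3D(in_channels=12, out_channels=3)
decoder.eval()

latents = torch.randn(1, 12, 3, 8, 8)  # [B, C, T, H, W]
with torch.no_grad():
    video = decoder(latents)

# T: 3 -> 3*3 - 2 = 7 -> 7*2 - 1 = 13 -> 13*1 = 13; H, W: 8 -> 64
print(video.shape)  # expected: torch.Size([1, 3, 13, 64, 64])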
