
Commit e9e92d0: up
1 parent 0a6189e

File tree: 1 file changed (+79, -87)


src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 79 additions & 87 deletions
@@ -13,33 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
-from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
-from ..downsampling import CogVideoXDownsample3D
-from ..modeling_outputs import AutoencoderKLOutput
-from ..modeling_utils import ModelMixin
-from ..upsampling import CogVideoXUpsample3D
-from .vae import DecoderOutput, DiagonalGaussianDistribution
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
 # YiYi to-do: replace this with nn.Conv3d
 class Conv1x1(nn.Linear):
     """*1x1 Conv implemented with a linear layer."""
@@ -60,17 +46,18 @@ def forward(self, x: torch.Tensor):
         x = super().forward(x)
         x = x.movedim(-1, 1)
         return x
-
+
 
 class MochiChunkedCausalConv3d(nn.Module):
-    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
+    r"""A 3D causal convolution layer that pads the input tensor to ensure causality in Mochi Model.
+    It also supports memory-efficient chunked 3D convolutions.
 
     Args:
         in_channels (`int`): Number of channels in the input tensor.
         out_channels (`int`): Number of output channels produced by the convolution.
         kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
         stride (`int` or `Tuple[int, int, int]`, defaults to `1`): Stride of the convolution.
-        pad_mode (`str`, defaults to `"constant"`): Padding mode.
+        padding_mode (`str`, defaults to `"replicate"`): Padding mode.
     """
 
     def __init__(
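Reviewer note on `Conv1x1` above: running `nn.Linear` over a channels-last view (the `movedim` pair in its `forward`) is mathematically identical to a 1x1x1 `nn.Conv3d`, which is what the "replace this with nn.Conv3d" to-do refers to. A minimal standalone sketch of the equivalence (illustrative only, not part of this diff):

```python
import torch
import torch.nn as nn

linear = nn.Linear(8, 16)
conv = nn.Conv3d(8, 16, kernel_size=1)
# Copy the linear weights into the conv so both compute the same map.
conv.weight.data = linear.weight.data.view(16, 8, 1, 1, 1)
conv.bias.data = linear.bias.data

x = torch.randn(2, 8, 4, 5, 6)  # (batch, channels, frames, height, width)
y_linear = linear(x.movedim(1, -1)).movedim(-1, 1)  # what Conv1x1 does
y_conv = conv(x)
assert torch.allclose(y_linear, y_conv, atol=1e-5)
```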
@@ -88,7 +75,7 @@ def __init__(
         if isinstance(stride, int):
             stride = (stride,) * 3
 
-        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
+        _, height_kernel_size, width_kernel_size = kernel_size
 
         self.padding_mode = padding_mode
         height_pad = (height_kernel_size - 1) // 2
@@ -104,18 +91,17 @@ def __init__(
             padding_mode=padding_mode,
         )
 
-
-
-    def forward(self, hidden_states : torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         time_kernel_size = self.conv.kernel_size[0]
         context_size = time_kernel_size - 1
         time_casual_padding = (0, 0, 0, 0, context_size, 0)
         hidden_states = F.pad(hidden_states, time_casual_padding, mode=self.padding_mode)
-
+
         # Memory-efficient chunked operation
         memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3
         # YiYI Notes: testing only!! please remove
         memory_count = 3
+        # YiYI Notes: this number 2 should be a config: max_memory_chunk_size (2 is 2GB)
         if memory_count > 2:
             part_num = int(memory_count / 2) + 1
             num_frames = hidden_states.shape[2]
@@ -131,7 +117,7 @@ def forward(self, hidden_states : torch.Tensor) -> torch.Tensor:
                 output_chunks.append(output_chunk)  # Append each output chunk to the list
 
             # Concatenate all output chunks along the temporal dimension
-            hidden_states = torch.cat(output_chunks, dim=2)
+            hidden_states = torch.cat(output_chunks, dim=2)
 
             return hidden_states
         else:
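A self-contained sketch of what the two hunks above implement: all temporal padding is applied at the front of the clip, so output frame t depends only on input frames <= t, and the frame axis can then be convolved in chunks without changing the result as long as each chunk re-reads the preceding context frames. The split below is an assumption for illustration; the actual chunking lives in the elided middle of this hunk:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv3d(8, 8, kernel_size=(3, 3, 3), padding=(0, 1, 1))  # no temporal padding
x = torch.randn(1, 8, 10, 32, 32)  # (batch, channels, frames, height, width)

# Causal padding: put all temporal padding in front of the clip.
context_size = conv.kernel_size[0] - 1
x_padded = F.pad(x, (0, 0, 0, 0, context_size, 0), mode="replicate")
full = conv(x_padded)  # 10 output frames, frame t sees only frames <= t

# Chunked variant: each chunk re-reads `context_size` frames of left context,
# so concatenating the chunk outputs reproduces the unchunked result.
chunk_size = 4
outputs = []
for start in range(0, x.shape[2], chunk_size):
    chunk = x_padded[:, :, start : start + chunk_size + context_size]
    outputs.append(conv(chunk))
chunked = torch.cat(outputs, dim=2)
assert torch.allclose(full, chunked, atol=1e-5)
```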
@@ -140,9 +126,14 @@ def forward(self, hidden_states : torch.Tensor) -> torch.Tensor:
 
 class MochiChunkedGroupNorm3D(nn.Module):
     r"""
-    Group normalization applied per-frame.
+    Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group
+    normalization.
 
     Args:
+        num_channels (int): Number of channels expected in input
+        num_groups (int, optional): Number of groups to separate the channels into. Default: 32
+        affine (bool, optional): If True, this module has learnable affine parameters. Default: True
+        chunk_size (int, optional): Size of each chunk for processing. Default: 8
 
     """
@@ -157,49 +148,27 @@ def __init__(
         self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
         self.chunk_size = chunk_size
 
-    def forward(
-        self, x: torch.Tensor = None
-    ) -> torch.Tensor:
-
+    def forward(self, x: torch.Tensor = None) -> torch.Tensor:
         batch_size, channels, num_frames, height, width = x.shape
         x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
-
-        num_chunks = (batch_size * num_frames + self.chunk_size - 1) // self.chunk_size
-
-        output = torch.cat(
-            [self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)],
-            dim=0
-        )
+
+        output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0)
         output = output.view(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
-
+
         return output
 
 
 class MochiResnetBlock3D(nn.Module):
     r"""
-    A 3D ResNet block used in the CogVideoX model.
+    A 3D ResNet block used in the Mochi model.
 
     Args:
         in_channels (`int`):
             Number of input channels.
         out_channels (`int`, *optional*):
             Number of output channels. If None, defaults to `in_channels`.
-        dropout (`float`, defaults to `0.0`):
-            Dropout rate.
-        temb_channels (`int`, defaults to `512`):
-            Number of time embedding channels.
-        groups (`int`, defaults to `32`):
-            Number of groups to separate the channels into for group normalization.
-        eps (`float`, defaults to `1e-6`):
-            Epsilon value for normalization layers.
         non_linearity (`str`, defaults to `"swish"`):
             Activation function to use.
-        conv_shortcut (bool, defaults to `False`):
-            Whether or not to use a convolution shortcut.
-        spatial_norm_dim (`int`, *optional*):
-            The dimension to use for spatial norm if it is to be used instead of group norm.
-        pad_mode (str, defaults to `"first"`):
-            Padding mode.
     """
 
     def __init__(
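Sketch of the `MochiChunkedGroupNorm3D.forward` shown above: frames are folded into the batch dimension, so normalization statistics are computed per (sample, frame) rather than across time, and the folded batch is normalized a few frames at a time to cap peak memory. Since `nn.GroupNorm` is per-sample, the chunked result is identical to the unchunked one:

```python
import torch
import torch.nn as nn

norm = nn.GroupNorm(num_groups=8, num_channels=32)
x = torch.randn(2, 32, 7, 16, 16)  # (batch, channels, frames, height, width)

b, c, t, h, w = x.shape
frames = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)  # one 2D "image" per frame
# Normalize 8 frames at a time; GroupNorm is per-sample, so this matches
# norm(frames) exactly, just with a smaller working set.
out = torch.cat([norm(chunk) for chunk in frames.split(8, dim=0)], dim=0)
out = out.view(b, t, c, h, w).permute(0, 2, 1, 3, 4)
assert torch.allclose(out, norm(frames).view(b, t, c, h, w).permute(0, 2, 1, 3, 4))
```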
@@ -225,14 +194,12 @@ def __init__(
             in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1
         )
 
-
     def forward(
         self,
         inputs: torch.Tensor,
     ) -> torch.Tensor:
-
         hidden_states = inputs
-
+
         hidden_states = self.norm1(hidden_states)
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.conv1(hidden_states)
@@ -254,6 +221,12 @@ class MochiUpBlock3D(nn.Module):
             Number of input channels.
         out_channels (`int`, *optional*):
             Number of output channels. If None, defaults to `in_channels`.
+        num_layers (`int`, defaults to `1`):
+            Number of resnet blocks in the block.
+        temporal_expansion (`int`, defaults to `2`):
+            Temporal expansion factor.
+        spatial_expansion (`int`, defaults to `2`):
+            Spatial expansion factor.
     """
 
     def __init__(
@@ -290,8 +263,7 @@ def forward(
     ) -> torch.Tensor:
         r"""Forward method of the `MochiUpBlock3D` class."""
 
-        for i, resnet in enumerate(self.resnets):
-
+        for resnet in self.resnets:
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
@@ -322,10 +294,8 @@ def create_forward(*inputs):
         hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw)
 
         if self.temporal_expansion > 1:
-            print(f"x: {hidden_states.shape}")
             # Drop the first self.temporal_expansion - 1 frames.
             hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :]
-            print(f"x: {hidden_states.shape}")
 
         return hidden_states
 
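On the `view(B, new_C, T * st, H * sh, W * sw)` and frame-drop lines above: the up block trades channels for temporal and spatial resolution, pixel-shuffle style, then drops the first `temporal_expansion - 1` frames so the frame count works out for the causal decoder. The reshape ordering below is an assumed depth-to-space layout for illustration; the actual permutation is in the unshown part of the hunk:

```python
import torch

B, C, T, H, W = 1, 24, 5, 8, 8
st, sh, sw = 2, 2, 2  # temporal_expansion, spatial expansion factors
x = torch.randn(B, C, T, H, W)

# Fold channel groups out into extra time/space resolution (depth-to-space).
new_C = C // (st * sh * sw)
x = x.view(B, new_C, st, sh, sw, T, H, W)
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4)  # (B, new_C, T, st, H, sh, W, sw)
x = x.contiguous().view(B, new_C, T * st, H * sh, W * sw)

if st > 1:
    x = x[:, :, st - 1 :]  # drop the first st - 1 frames
print(x.shape)  # torch.Size([1, 3, 9, 16, 16])
```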
@@ -337,22 +307,20 @@ class MochiMidBlock3D(nn.Module):
     Args:
         in_channels (`int`):
             Number of input channels.
+        num_layers (`int`, defaults to `3`):
+            Number of resnet blocks in the block.
     """
 
-    _supports_gradient_checkpointing = True
-
     def __init__(
         self,
-        in_channels: int, # 768
+        in_channels: int,  # 768
         num_layers: int = 3,
     ):
         super().__init__()
 
         resnets = []
         for _ in range(num_layers):
-            resnets.append(
-                MochiResnetBlock3D(in_channels=in_channels)
-            )
+            resnets.append(MochiResnetBlock3D(in_channels=in_channels))
         self.resnets = nn.ModuleList(resnets)
 
         self.gradient_checkpointing = False
@@ -363,7 +331,7 @@ def forward(
     ) -> torch.Tensor:
         r"""Forward method of the `MochiMidBlock3D` class."""
 
-        for i, resnet in enumerate(self.resnets):
+        for resnet in self.resnets:
             if self.training and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
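The `create_custom_forward` indirection above (and in the next hunk, where it is passed to `torch.utils.checkpoint.checkpoint`) is the standard recipe for activation checkpointing: the wrapped module's forward is re-run during backward instead of keeping its activations, trading compute for memory. A minimal standalone illustration:

```python
import torch
import torch.nn as nn

block = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.SiLU())
x = torch.randn(2, 8, 16, 16, requires_grad=True)

def create_custom_forward(module):
    def create_forward(*inputs):
        return module(*inputs)

    return create_forward

# Activations inside `block` are freed after the forward pass and recomputed
# during backward.
out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
out.sum().backward()
```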
@@ -372,22 +340,39 @@ def create_forward(*inputs):
 
                     return create_forward
 
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet), hidden_states
-                )
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
             else:
                 hidden_states = resnet(hidden_states)
 
         return hidden_states
 
 
 class MochiDecoder3D(nn.Module):
-    _supports_gradient_checkpointing = True
+    r"""
+    The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
+    sample.
+
+    Args:
+        in_channels (`int`, *optional*):
+            The number of input channels.
+        out_channels (`int`, *optional*):
+            The number of output channels.
+        block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
+            The number of output channels for each block.
+        layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
+            The number of resnet blocks for each block.
+        temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
+            The temporal expansion factor for each of the up blocks.
+        spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
+            The spatial expansion factor for each of the up blocks.
+        non_linearity (`str`, *optional*, defaults to `"swish"`):
+            The non-linearity to use in the decoder.
+    """
 
     def __init__(
         self,
-        in_channels: int, # 12
-        out_channels: int, # 3
+        in_channels: int,  # 12
+        out_channels: int,  # 3
         block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
         layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
         temporal_expansions: Tuple[int, ...] = (1, 2, 3),
@@ -418,29 +403,36 @@ def __init__(
             num_layers=layers_per_block[0],
         )
         self.conv_out = Conv1x1(block_out_channels[0], out_channels)
-
+
+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         r"""Forward method of the `MochiDecoder3D` class."""
 
-        print(f"hidden_states: {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}")
         hidden_states = self.conv_in(hidden_states)
-        print(f"hidden_states (after conv_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}")
-
 
         # 1. Mid
-        hidden_states = self.block_in(hidden_states)
-        print(f"hidden_states (after block_in): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}")
-        # 2. Up
-        for i, up_block in enumerate(self.up_blocks):
-            hidden_states = up_block(hidden_states)
-            print(f"hidden_states (after up_block {i}): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}")
-        # 3. Post-process
+        if self.training and self.gradient_checkpointing:
+
+            def create_custom_forward(module):
+                def create_forward(*inputs):
+                    return module(*inputs)
+
+                return create_forward
+
+            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.block_in), hidden_states)
+
+            for up_block in self.up_blocks:
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
+        else:
+            hidden_states = self.block_in(hidden_states)
+
+            for up_block in self.up_blocks:
+                hidden_states = up_block(hidden_states)
+
         hidden_states = self.block_out(hidden_states)
-        print(f"hidden_states (after block_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}")
-
+
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.conv_out(hidden_states)
-        print(f"hidden_states (after conv_out): {hidden_states.shape}, {hidden_states[0,:3,0,:2,:2]}")
 
         return hidden_states
-
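Hedged usage sketch for the restructured `MochiDecoder3D.forward` above. The import path follows this commit's file location; the latent shape is made up for illustration, and construction relies on the defaults shown in the diff:

```python
import torch

from diffusers.models.autoencoders.autoencoder_kl_mochi import MochiDecoder3D

decoder = MochiDecoder3D(in_channels=12, out_channels=3)
decoder.gradient_checkpointing = True  # only takes effect in training mode

latents = torch.randn(1, 12, 3, 8, 8)  # (batch, latent channels, frames, height, width)
with torch.no_grad():
    video = decoder(latents)  # conv_in -> mid block -> up blocks -> block_out -> conv_out
print(video.shape)
```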
