
Commit 0b76fea

up
1 parent e9e92d0 commit 0b76fea

File tree

1 file changed: +266 −4 lines changed


src/diffusers/models/autoencoders/autoencoder_kl_mochi.py

Lines changed: 266 additions & 4 deletions
@@ -19,8 +19,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
+from ...utils.accelerate_utils import apply_forward_hook
 from ..activations import get_activation
+from ..modeling_utils import ModelMixin
+from .vae import DecoderOutput
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -175,15 +179,15 @@ def __init__(
         self,
         in_channels: int,
         out_channels: Optional[int] = None,
-        non_linearity: str = "swish",
+        act_fn: str = "swish",
     ):
         super().__init__()
 
         out_channels = out_channels or in_channels
 
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.nonlinearity = get_activation(non_linearity)
+        self.nonlinearity = get_activation(act_fn)
 
         self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
         self.conv1 = MochiChunkedCausalConv3d(
@@ -377,11 +381,11 @@ def __init__(
         layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
         temporal_expansions: Tuple[int, ...] = (1, 2, 3),
         spatial_expansions: Tuple[int, ...] = (2, 2, 2),
-        non_linearity: str = "swish",
+        act_fn: str = "swish",
    ):
         super().__init__()
 
-        self.nonlinearity = get_activation(non_linearity)
+        self.nonlinearity = get_activation(act_fn)
 
         self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
         self.block_in = MochiMidBlock3D(
@@ -436,3 +440,261 @@ def create_forward(*inputs):
         hidden_states = self.conv_out(hidden_states)
 
         return hidden_states
+
+
+class AutoencoderKLMochi(ModelMixin, ConfigMixin):
+    r"""
+    A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
+    [Mochi 1 preview](https://github.com/genmoai/models).
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        scaling_factor (`float`, *optional*, defaults to `1.15258426`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+    """
+
+    _supports_gradient_checkpointing = True
+    _no_split_modules = ["MochiResnetBlock3D"]
+
+    @register_to_config
+    def __init__(
+        self,
+        out_channels: int = 3,
+        block_out_channels: Tuple[int] = (128, 256, 256, 512),
+        latent_channels: int = 12,
+        layers_per_block: int = 3,
+        act_fn: str = "silu",
+        temporal_expansions: Tuple[int, ...] = (1, 2, 3),
+        spatial_expansions: Tuple[int, ...] = (2, 2, 2),
+        latents_mean: Tuple[float, ...] = (
+            -0.06730895953510081,
+            -0.038011381506090416,
+            -0.07477820912866141,
+            -0.05565264470995561,
+            0.012767231469026969,
+            -0.04703542746246419,
+            0.043896967884726704,
+            -0.09346305707025976,
+            -0.09918314763016893,
+            -0.008729793427399178,
+            -0.011931556316503654,
+            -0.0321993391887285,
+        ),
+        latents_std: Tuple[float, ...] = (
+            0.9263795028493863,
+            0.9248894543193766,
+            0.9393059390890617,
+            0.959253732819592,
+            0.8244560132752793,
+            0.917259975397747,
+            0.9294154431013696,
+            1.3720942357788521,
+            0.881393668867029,
+            0.9168315692124348,
+            0.9185249279345552,
+            0.9274757570805041,
+        ),
+        scaling_factor: float = 1.0,
+    ):
+        super().__init__()
+
+        self.decoder = MochiDecoder3D(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            temporal_expansions=temporal_expansions,
+            spatial_expansions=spatial_expansions,
+            act_fn=act_fn,
+        )
+
+        self.use_slicing = False
+        self.use_tiling = False
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MochiDecoder3D):
+            module.gradient_checkpointing = value
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_overlap_factor_height: Optional[float] = None,
+        tile_overlap_factor_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_overlap_factor_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
+                no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
+                value might cause more tiles to be processed leading to slow down of the decoding process.
+            tile_overlap_factor_width (`int`, *optional*):
+                The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
+                are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
+                value might cause more tiles to be processed leading to slow down of the decoding process.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_latent_min_height = int(
+            self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
+        )
+        self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
+        self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
+        self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
+    @apply_forward_hook
+    def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        """
+        Decode a batch of images.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+        if self.use_slicing and z.shape[0] > 1:
+            decoded_slices = [self.decoder(z_slice) for z_slice in z.split(1)]
+            decoded = torch.cat(decoded_slices)
+        else:
+            decoded = self.decoder(z)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
+
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
+                y / blend_extent
+            )
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
+                x / blend_extent
+            )
+        return b
+
+    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
+        r"""
+        Decode a batch of images using a tiled decoder.
+
+        Args:
+            z (`torch.Tensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vae.DecoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
+                returned.
+        """
+
+        batch_size, num_channels, num_frames, height, width = z.shape
+
+        overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
+        overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
+        blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
+        blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
+        row_limit_height = self.tile_sample_min_height - blend_extent_height
+        row_limit_width = self.tile_sample_min_width - blend_extent_width
+        frame_batch_size = self.num_latent_frames_batch_size
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, overlap_height):
+            row = []
+            for j in range(0, width, overlap_width):
+                num_batches = max(num_frames // frame_batch_size, 1)
+                conv_cache = None
+                time = []
+
+                for k in range(num_batches):
+                    remaining_frames = num_frames % frame_batch_size
+                    start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
+                    end_frame = frame_batch_size * (k + 1) + remaining_frames
+                    tile = z[
+                        :,
+                        :,
+                        start_frame:end_frame,
+                        i : i + self.tile_latent_min_height,
+                        j : j + self.tile_latent_min_width,
+                    ]
+                    if self.post_quant_conv is not None:
+                        tile = self.post_quant_conv(tile)
+                    tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
+                    time.append(tile)
+
+                row.append(torch.cat(time, dim=2))
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_extent_width)
+                result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
+            result_rows.append(torch.cat(result_row, dim=4))
+
+        dec = torch.cat(result_rows, dim=3)
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
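The new config registers `latents_mean`, `latents_std`, and `scaling_factor`, and the class docstring states the scaling rule: `z = z * scaling_factor` before the diffusion model and `z = 1 / scaling_factor * z` before decoding. The sketch below illustrates that rescaling in plain PyTorch. The per-channel denormalization with `latents_mean` / `latents_std` is an assumption about how such statistics are typically applied (this diff only stores them in the config), and the tensor shape and placeholder statistics are illustrative, not taken from the commit.

```python
import torch

scaling_factor = 1.0              # default registered in this commit
latents_mean = torch.zeros(12)    # placeholders; the config stores 12 per-channel means
latents_std = torch.ones(12)      # placeholders; the config stores 12 per-channel stds

# Illustrative latent batch: (batch, channels, frames, height, width)
z = torch.randn(1, 12, 7, 60, 106)

# Undo the scaling applied for the diffusion model before handing latents to the decoder,
# per the docstring formula z = 1 / scaling_factor * z.
z = z / scaling_factor

# Hypothetical per-channel denormalization (not part of this diff).
z = z * latents_std.view(1, -1, 1, 1, 1) + latents_mean.view(1, -1, 1, 1, 1)
```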
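`blend_v` and `blend_h` are what keep `tiled_decode` seam-free: each tile's leading edge is replaced by a linear cross-fade between the trailing edge of the previous tile and the tile itself. Below is a minimal, self-contained sketch of the horizontal blend on dummy 5D tensors so the ramp is easy to inspect; the helper mirrors the method added in this commit, while the tensor shapes and `blend_extent` value are illustrative.

```python
import torch


def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Weight ramps linearly from 1 (all of tile `a`) to 0 (all of tile `b`)
    # across the overlapping columns, exactly as in the commit's blend_h.
    blend_extent = min(a.shape[4], b.shape[4], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
            x / blend_extent
        )
    return b


left = torch.zeros(1, 3, 2, 8, 8)   # decoded left tile (all zeros)
right = torch.ones(1, 3, 2, 8, 8)   # decoded right tile (all ones)
blended = blend_h(left, right.clone(), blend_extent=4)
print(blended[0, 0, 0, 0, :4])      # ramps 0.00, 0.25, 0.50, 0.75 across the overlap
```

With a left tile of zeros and a right tile of ones, the first `blend_extent` columns of the result ramp smoothly from 0 toward 1, which is what hides the tile boundary after `tiled_decode` crops each tile to `row_limit_height` × `row_limit_width`.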
