diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 87ff9b1fb81a..c0d571a5864d 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -252,6 +252,8 @@
         title: SparseControlNetModel
       title: ControlNets
     - sections:
+      - local: api/models/allegro_transformer3d
+        title: AllegroTransformer3DModel
       - local: api/models/aura_flow_transformer2d
         title: AuraFlowTransformer2DModel
       - local: api/models/cogvideox_transformer3d
@@ -300,6 +302,8 @@
     - sections:
       - local: api/models/autoencoderkl
         title: AutoencoderKL
+      - local: api/models/autoencoderkl_allegro
+        title: AutoencoderKLAllegro
       - local: api/models/autoencoderkl_cogvideox
         title: AutoencoderKLCogVideoX
       - local: api/models/asymmetricautoencoderkl
@@ -318,6 +322,8 @@
   sections:
   - local: api/pipelines/overview
     title: Overview
+  - local: api/pipelines/allegro
+    title: Allegro
   - local: api/pipelines/amused
     title: aMUSEd
   - local: api/pipelines/animatediff
diff --git a/docs/source/en/api/models/allegro_transformer3d.md b/docs/source/en/api/models/allegro_transformer3d.md
new file mode 100644
index 000000000000..e70026fe4bfc
--- /dev/null
+++ b/docs/source/en/api/models/allegro_transformer3d.md
@@ -0,0 +1,30 @@
+
+
+# AllegroTransformer3DModel
+
+A Diffusion Transformer model for 3D video-like data from [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import AllegroTransformer3DModel
+
+transformer = AllegroTransformer3DModel.from_pretrained("rhymes-ai/Allegro", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
+```
+
+## AllegroTransformer3DModel
+
+[[autodoc]] AllegroTransformer3DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/models/autoencoderkl_allegro.md b/docs/source/en/api/models/autoencoderkl_allegro.md
new file mode 100644
index 000000000000..fd9d10d5724b
--- /dev/null
+++ b/docs/source/en/api/models/autoencoderkl_allegro.md
@@ -0,0 +1,37 @@
+
+
+# AutoencoderKLAllegro
+
+The 3D variational autoencoder (VAE) model with KL loss used in [Allegro](https://github.com/rhymes-ai/Allegro) was introduced in [Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) by RhymesAI.
+
+The model can be loaded with the following code snippet.
+
+```python
+import torch
+
+from diffusers import AutoencoderKLAllegro
+
+vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32).to("cuda")
+```
+
+## AutoencoderKLAllegro
+
+[[autodoc]] AutoencoderKLAllegro
+    - decode
+    - encode
+    - all
+
+## AutoencoderKLOutput
+
+[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
+
+## DecoderOutput
+
+[[autodoc]] models.autoencoders.vae.DecoderOutput
diff --git a/docs/source/en/api/pipelines/allegro.md b/docs/source/en/api/pipelines/allegro.md
new file mode 100644
index 000000000000..e13e339944e5
--- /dev/null
+++ b/docs/source/en/api/pipelines/allegro.md
@@ -0,0 +1,34 @@
+
+
+# Allegro
+
+[Allegro: Open the Black Box of Commercial-Level Video Generation Model](https://huggingface.co/papers/2410.15458) is from RhymesAI, by Yuan Zhou, Qiuyue Wang, Yuxuan Cai, and Huan Yang.
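+
+A minimal text-to-video sketch with [`AllegroPipeline`] is shown below. The prompt and generation arguments are illustrative assumptions; see the [`AllegroPipeline`] reference below for the full argument list.
+
+```python
+import torch
+
+from diffusers import AllegroPipeline
+from diffusers.utils import export_to_video
+
+# Load the pipeline in bfloat16 and move it to the GPU
+pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16).to("cuda")
+
+prompt = "A sea turtle gliding over a coral reef, sunlight filtering through the water."
+
+# Sample a short clip and export the frames as an MP4 file
+video = pipe(prompt=prompt, guidance_scale=7.5, num_inference_steps=100).frames[0]
+export_to_video(video, "allegro_output.mp4", fps=15)
+```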
+ +The abstract from the paper is: + +*Significant advancements have been made in the field of video generation, with the open-source community contributing a wealth of research papers and tools for training high-quality models. However, despite these efforts, the available information and resources remain insufficient for achieving commercial-level performance. In this report, we open the black box and introduce Allegro, an advanced video generation model that excels in both quality and temporal consistency. We also highlight the current limitations in the field and present a comprehensive methodology for training high-performance, commercial-level video generation models, addressing key aspects such as data, model architecture, training pipeline, and evaluation. Our user study shows that Allegro surpasses existing open-source models and most commercial models, ranking just behind Hailuo and Kling. Code: https://github.com/rhymes-ai/Allegro , Model: https://huggingface.co/rhymes-ai/Allegro , Gallery: https://rhymes.ai/allegro_gallery .* + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. + + + +## AllegroPipeline + +[[autodoc]] AllegroPipeline + - all + - __call__ + +## AllegroPipelineOutput + +[[autodoc]] pipelines.allegro.pipeline_output.AllegroPipelineOutput diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 789458a26299..ff59a3839552 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -77,9 +77,11 @@ else: _import_structure["models"].extend( [ + "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", "AuraFlowTransformer2DModel", "AutoencoderKL", + "AutoencoderKLAllegro", "AutoencoderKLCogVideoX", "AutoencoderKLTemporalDecoder", "AutoencoderOobleck", @@ -237,6 +239,7 @@ else: _import_structure["pipelines"].extend( [ + "AllegroPipeline", "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", "AmusedImg2ImgPipeline", @@ -556,9 +559,11 @@ from .utils.dummy_pt_objects import * # noqa F403 else: from .models import ( + AllegroTransformer3DModel, AsymmetricAutoencoderKL, AuraFlowTransformer2DModel, AutoencoderKL, + AutoencoderKLAllegro, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -697,6 +702,7 @@ from .utils.dummy_torch_and_transformers_objects import * # noqa F403 else: from .pipelines import ( + AllegroPipeline, AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, AmusedImg2ImgPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4dda8c36ba1c..38dd2819133d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -28,6 +28,7 @@ _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] + _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] @@ -54,6 +55,7 @@ 
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] + _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] _import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -81,6 +83,7 @@ from .autoencoders import ( AsymmetricAutoencoderKL, AutoencoderKL, + AutoencoderKLAllegro, AutoencoderKLCogVideoX, AutoencoderKLTemporalDecoder, AutoencoderOobleck, @@ -97,6 +100,7 @@ from .embeddings import ImageProjection from .modeling_utils import ModelMixin from .transformers import ( + AllegroTransformer3DModel, AuraFlowTransformer2DModel, CogVideoXTransformer3DModel, CogView3PlusTransformer2DModel, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e735c4ee7d17..db88ecbbb9d3 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -1521,6 +1521,100 @@ def __call__( return hidden_states, encoder_hidden_states +class AllegroAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Allegro model. It applies a normalization layer and rotary embedding on the query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "AllegroAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+ ) + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # Apply RoPE if needed + if image_rotary_emb is not None and not attn.is_cross_attention: + from .embeddings import apply_rotary_emb_allegro + + query = apply_rotary_emb_allegro(query, image_rotary_emb[0], image_rotary_emb[1]) + key = apply_rotary_emb_allegro(key, image_rotary_emb[0], image_rotary_emb[1]) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class AuraFlowAttnProcessor2_0: """Attention processor used typically in processing Aura Flow.""" diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index ccf4552b2a5e..9628fe7f21b0 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,5 +1,6 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL +from .autoencoder_kl_allegro import AutoencoderKLAllegro from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX from 
.autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder from .autoencoder_oobleck import AutoencoderOobleck diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py new file mode 100644 index 000000000000..4836de7e16ab --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_kl_allegro.py @@ -0,0 +1,1155 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils.accelerate_utils import apply_forward_hook +from ..attention_processor import Attention, SpatialNorm +from ..autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from ..downsampling import Downsample2D +from ..modeling_outputs import AutoencoderKLOutput +from ..modeling_utils import ModelMixin +from ..resnet import ResnetBlock2D +from ..upsampling import Upsample2D + + +class AllegroTemporalConvLayer(nn.Module): + r""" + Temporal convolutional layer that can be used for video (sequence of images) input. Code adapted from: + https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/models/multi_modal/video_synthesis/unet_sd.py#L1016 + """ + + def __init__( + self, + in_dim: int, + out_dim: Optional[int] = None, + dropout: float = 0.0, + norm_num_groups: int = 32, + up_sample: bool = False, + down_sample: bool = False, + stride: int = 1, + ) -> None: + super().__init__() + + out_dim = out_dim or in_dim + pad_h = pad_w = int((stride - 1) * 0.5) + pad_t = 0 + + self.down_sample = down_sample + self.up_sample = up_sample + + if down_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (2, stride, stride), stride=(2, 1, 1), padding=(0, pad_h, pad_w)), + ) + elif up_sample: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim * 2, (1, stride, stride), padding=(0, pad_h, pad_w)), + ) + else: + self.conv1 = nn.Sequential( + nn.GroupNorm(norm_num_groups, in_dim), + nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), + ) + self.conv2 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_w)), + ) + self.conv3 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), + ) + self.conv4 = nn.Sequential( + nn.GroupNorm(norm_num_groups, out_dim), + nn.SiLU(), + nn.Conv3d(out_dim, in_dim, (3, stride, stride), padding=(pad_t, pad_h, pad_h)), + ) + + @staticmethod + def _pad_temporal_dim(hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = 
torch.cat((hidden_states[:, :, 0:1], hidden_states), dim=2) + hidden_states = torch.cat((hidden_states, hidden_states[:, :, -1:]), dim=2) + return hidden_states + + def forward(self, hidden_states: torch.Tensor, batch_size: int) -> torch.Tensor: + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + if self.down_sample: + identity = hidden_states[:, :, ::2] + elif self.up_sample: + identity = hidden_states.repeat_interleave(2, dim=2) + else: + identity = hidden_states + + if self.down_sample or self.up_sample: + hidden_states = self.conv1(hidden_states) + else: + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv1(hidden_states) + + if self.up_sample: + hidden_states = hidden_states.unflatten(1, (2, -1)).permute(0, 2, 3, 1, 4, 5).flatten(2, 3) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv2(hidden_states) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv3(hidden_states) + + hidden_states = self._pad_temporal_dim(hidden_states) + hidden_states = self.conv4(hidden_states) + + hidden_states = identity + hidden_states + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + return hidden_states + + +class AllegroDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + spatial_downsample: bool = True, + temporal_downsample: bool = False, + downsample_padding: int = 1, + ): + super().__init__() + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + AllegroTemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if temporal_downsample: + self.temp_convs_down = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, down_sample=True, stride=3 + ) + self.add_temp_downsample = temporal_downsample + + if spatial_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + if self.add_temp_downsample: + hidden_states = self.temp_convs_down(hidden_states, batch_size=batch_size) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 
2, 1, 3, 4) + return hidden_states + + +class AllegroUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + temb_channels: Optional[int] = None, + ): + super().__init__() + + resnets = [] + temp_convs = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + AllegroTemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + self.add_temp_upsample = temporal_upsample + if temporal_upsample: + self.temp_conv_up = AllegroTemporalConvLayer( + out_channels, out_channels, dropout=0.1, norm_num_groups=resnet_groups, up_sample=True, stride=3 + ) + + if spatial_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + if self.add_temp_upsample: + hidden_states = self.temp_conv_up(hidden_states, batch_size=batch_size) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroMidBlock3DConv(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + AllegroTemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + 
rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift == "default" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + temp_convs.append( + AllegroTemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size = hidden_states.shape[0] + + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.resnets[0](hidden_states, temb=None) + + hidden_states = self.temp_convs[0](hidden_states, batch_size=batch_size) + + for attn, resnet, temp_conv in zip(self.attentions, self.resnets[1:], self.temp_convs[1:]): + hidden_states = attn(hidden_states) + hidden_states = resnet(hidden_states, temb=None) + hidden_states = temp_conv(hidden_states, batch_size=batch_size) + + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return hidden_states + + +class AllegroEncoder3D(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + block_out_channels: Tuple[int, ...] = (128, 256, 512, 512), + temporal_downsample_blocks: Tuple[bool, ...] = [True, True, False, False], + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + ): + super().__init__() + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d( + in_channels=block_out_channels[0], + out_channels=block_out_channels[0], + kernel_size=(3, 1, 1), + padding=(1, 0, 0), + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + if down_block_type == "AllegroDownBlock3D": + down_block = AllegroDownBlock3D( + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + spatial_downsample=not is_final_block, + temporal_downsample=temporal_downsample_blocks[i], + resnet_eps=1e-6, + downsample_padding=0, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + ) + else: + raise ValueError("Invalid `down_block_type` encountered. 
Must be `AllegroDownBlock3D`") + + self.down_blocks.append(down_block) + + # mid + self.mid_block = AllegroMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default", + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=None, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + + self.temp_conv_out = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_in(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_in(sample) + sample = sample + residual + + if self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # Down blocks + for down_block in self.down_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample) + + # Mid block + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + else: + # Down blocks + for down_block in self.down_blocks: + sample = down_block(sample) + + # Mid block + sample = self.mid_block(sample) + + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_out(sample) + sample = sample + residual + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_out(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample + + +class AllegroDecoder3D(nn.Module): + def __init__( + self, + in_channels: int = 4, + out_channels: int = 3, + up_block_types: Tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + temporal_upsample_blocks: Tuple[bool, ...] = [False, True, True, False], + block_out_channels: Tuple[int, ...] 
= (128, 256, 512, 512), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + + self.conv_in = nn.Conv2d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.temp_conv_in = nn.Conv3d(block_out_channels[-1], block_out_channels[-1], (3, 1, 1), padding=(1, 0, 0)) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = AllegroMidBlock3DConv( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + if up_block_type == "AllegroUpBlock3D": + up_block = AllegroUpBlock3D( + num_layers=layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + spatial_upsample=not is_final_block, + temporal_upsample=temporal_upsample_blocks[i], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + resnet_time_scale_shift=norm_type, + ) + else: + raise ValueError("Invalid `UP_block_type` encountered. Must be `AllegroUpBlock3D`") + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + + self.conv_act = nn.SiLU() + + self.temp_conv_out = nn.Conv3d(block_out_channels[0], block_out_channels[0], (3, 1, 1), padding=(1, 0, 0)) + self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + batch_size = sample.shape[0] + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_in(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_in(sample) + sample = sample + residual + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + if self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + # Mid block + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample) + + # Up blocks + for up_block in self.up_blocks: + sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample) + + else: + # Mid block + sample = self.mid_block(sample) + sample = sample.to(upscale_dtype) + + # Up blocks + for up_block in self.up_blocks: + sample = up_block(sample) + + # Post process + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + residual = sample + sample = self.temp_conv_out(sample) + sample = sample + 
residual + + sample = sample.permute(0, 2, 1, 3, 4).flatten(0, 1) + sample = self.conv_out(sample) + + sample = sample.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return sample + + +class AutoencoderKLAllegro(ModelMixin, ConfigMixin): + r""" + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in + [Allegro](https://github.com/rhymes-ai/Allegro). + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + in_channels (int, defaults to `3`): + Number of channels in the input image. + out_channels (int, defaults to `3`): + Number of channels in the output. + down_block_types (`Tuple[str, ...]`, defaults to `("AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D", "AllegroDownBlock3D")`): + Tuple of strings denoting which types of down blocks to use. + up_block_types (`Tuple[str, ...]`, defaults to `("AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D", "AllegroUpBlock3D")`): + Tuple of strings denoting which types of up blocks to use. + block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`): + Tuple of integers denoting number of output channels in each block. + temporal_downsample_blocks (`Tuple[bool, ...]`, defaults to `(True, True, False, False)`): + Tuple of booleans denoting which blocks to enable temporal downsampling in. + latent_channels (`int`, defaults to `4`): + Number of channels in latents. + layers_per_block (`int`, defaults to `2`): + Number of resnet or attention or temporal convolution layers per down/up block. + act_fn (`str`, defaults to `"silu"`): + The activation function to use. + norm_num_groups (`int`, defaults to `32`): + Number of groups to use in normalization layers. + temporal_compression_ratio (`int`, defaults to `4`): + Ratio by which temporal dimension of samples are compressed. + sample_size (`int`, defaults to `320`): + Default latent size. + scaling_factor (`float`, defaults to `0.13235`): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + force_upcast (`bool`, default to `True`): + If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE + can be fine-tuned / trained to a lower range without loosing too much precision in which case + `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str, ...] = ( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + up_block_types: Tuple[str, ...] = ( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + block_out_channels: Tuple[int, ...] 
= (128, 256, 512, 512), + temporal_downsample_blocks: Tuple[bool, ...] = (True, True, False, False), + temporal_upsample_blocks: Tuple[bool, ...] = (False, True, True, False), + latent_channels: int = 4, + layers_per_block: int = 2, + act_fn: str = "silu", + norm_num_groups: int = 32, + temporal_compression_ratio: float = 4, + sample_size: int = 320, + scaling_factor: float = 0.13, + force_upcast: bool = True, + ) -> None: + super().__init__() + + self.encoder = AllegroEncoder3D( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + temporal_downsample_blocks=temporal_downsample_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + self.decoder = AllegroDecoder3D( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + temporal_upsample_blocks=temporal_upsample_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + # TODO(aryan): For the 1.0.0 refactor, `temporal_compression_ratio` can be inferred directly and we don't need + # to use a specific parameter here or in other VAEs. + + self.use_slicing = False + self.use_tiling = False + + self.spatial_compression_ratio = 2 ** (len(block_out_channels) - 1) + self.tile_overlap_t = 8 + self.tile_overlap_h = 120 + self.tile_overlap_w = 80 + sample_frames = 24 + + self.kernel = (sample_frames, sample_size, sample_size) + self.stride = ( + sample_frames - self.tile_overlap_t, + sample_size - self.tile_overlap_h, + sample_size - self.tile_overlap_w, + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (AllegroEncoder3D, AllegroDecoder3D)): + module.gradient_checkpointing = value + + def enable_tiling(self) -> None: + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.use_tiling = True + + def disable_tiling(self) -> None: + r""" + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_tiling = False + + def enable_slicing(self) -> None: + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self) -> None: + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. 
+ """ + self.use_slicing = False + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + # TODO(aryan) + # if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height): + if self.use_tiling: + return self.tiled_encode(x) + + raise NotImplementedError("Encoding without tiling has not been implemented yet.") + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + r""" + Encode a batch of videos into latents. + + Args: + x (`torch.Tensor`): + Input batch of videos. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded videos. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self._encode(x) + + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor) -> torch.Tensor: + # TODO(aryan): refactor tiling implementation + # if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height): + if self.use_tiling: + return self.tiled_decode(z) + + raise NotImplementedError("Decoding without tiling has not been implemented yet.") + + @apply_forward_hook + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + """ + Decode a batch of videos. + + Args: + z (`torch.Tensor`): + Input batch of latent vectors. + return_dict (`bool`, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. 
+ """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z) + + if not return_dict: + return (decoded,) + return DecoderOutput(sample=decoded) + + def tiled_encode(self, x: torch.Tensor) -> torch.Tensor: + local_batch_size = 1 + rs = self.spatial_compression_ratio + rt = self.config.temporal_compression_ratio + + batch_size, num_channels, num_frames, height, width = x.shape + + output_num_frames = math.floor((num_frames - self.kernel[0]) / self.stride[0]) + 1 + output_height = math.floor((height - self.kernel[1]) / self.stride[1]) + 1 + output_width = math.floor((width - self.kernel[2]) / self.stride[2]) + 1 + + count = 0 + output_latent = x.new_zeros( + ( + output_num_frames * output_height * output_width, + 2 * self.config.latent_channels, + self.kernel[0] // rt, + self.kernel[1] // rs, + self.kernel[2] // rs, + ) + ) + vae_batch_input = x.new_zeros((local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2])) + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + + video_cube = x[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[count % local_batch_size] = video_cube + + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): + latent = self.encoder(vae_batch_input) + + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): + output_latent[count - count % local_batch_size :] = latent[: count % local_batch_size + 1] + else: + output_latent[count - local_batch_size + 1 : count + 1] = latent + + vae_batch_input = x.new_zeros( + (local_batch_size, num_channels, self.kernel[0], self.kernel[1], self.kernel[2]) + ) + + count += 1 + + latent = x.new_zeros( + (batch_size, 2 * self.config.latent_channels, num_frames // rt, height // rs, width // rs) + ) + output_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + output_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs + output_overlap = ( + output_kernel[0] - output_stride[0], + output_kernel[1] - output_stride[1], + output_kernel[2] - output_stride[2], + ) + + for i in range(output_num_frames): + n_start, n_end = i * output_stride[0], i * output_stride[0] + output_kernel[0] + for j in range(output_height): + h_start, h_end = j * output_stride[1], j * output_stride[1] + output_kernel[1] + for k in range(output_width): + w_start, w_end = k * output_stride[2], k * output_stride[2] + output_kernel[2] + latent_mean = _prepare_for_blend( + (i, output_num_frames, output_overlap[0]), + (j, output_height, output_overlap[1]), + (k, output_width, output_overlap[2]), + output_latent[i * output_height * output_width + j * output_width + k].unsqueeze(0), + ) + latent[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += latent_mean + + latent = latent.permute(0, 2, 1, 3, 4).flatten(0, 1) + latent = self.quant_conv(latent) + latent = latent.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + return latent + + def tiled_decode(self, z: torch.Tensor) -> torch.Tensor: + local_batch_size = 1 + rs = self.spatial_compression_ratio + 
rt = self.config.temporal_compression_ratio + + latent_kernel = self.kernel[0] // rt, self.kernel[1] // rs, self.kernel[2] // rs + latent_stride = self.stride[0] // rt, self.stride[1] // rs, self.stride[2] // rs + + batch_size, num_channels, num_frames, height, width = z.shape + + ## post quant conv (a mapping) + z = z.permute(0, 2, 1, 3, 4).flatten(0, 1) + z = self.post_quant_conv(z) + z = z.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4) + + output_num_frames = math.floor((num_frames - latent_kernel[0]) / latent_stride[0]) + 1 + output_height = math.floor((height - latent_kernel[1]) / latent_stride[1]) + 1 + output_width = math.floor((width - latent_kernel[2]) / latent_stride[2]) + 1 + + count = 0 + decoded_videos = z.new_zeros( + ( + output_num_frames * output_height * output_width, + self.config.out_channels, + self.kernel[0], + self.kernel[1], + self.kernel[2], + ) + ) + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) + ) + + for i in range(output_num_frames): + for j in range(output_height): + for k in range(output_width): + n_start, n_end = i * latent_stride[0], i * latent_stride[0] + latent_kernel[0] + h_start, h_end = j * latent_stride[1], j * latent_stride[1] + latent_kernel[1] + w_start, w_end = k * latent_stride[2], k * latent_stride[2] + latent_kernel[2] + + current_latent = z[:, :, n_start:n_end, h_start:h_end, w_start:w_end] + vae_batch_input[count % local_batch_size] = current_latent + + if ( + count % local_batch_size == local_batch_size - 1 + or count == output_num_frames * output_height * output_width - 1 + ): + current_video = self.decoder(vae_batch_input) + + if ( + count == output_num_frames * output_height * output_width - 1 + and count % local_batch_size != local_batch_size - 1 + ): + decoded_videos[count - count % local_batch_size :] = current_video[ + : count % local_batch_size + 1 + ] + else: + decoded_videos[count - local_batch_size + 1 : count + 1] = current_video + + vae_batch_input = z.new_zeros( + (local_batch_size, num_channels, latent_kernel[0], latent_kernel[1], latent_kernel[2]) + ) + + count += 1 + + video = z.new_zeros((batch_size, self.config.out_channels, num_frames * rt, height * rs, width * rs)) + video_overlap = ( + self.kernel[0] - self.stride[0], + self.kernel[1] - self.stride[1], + self.kernel[2] - self.stride[2], + ) + + for i in range(output_num_frames): + n_start, n_end = i * self.stride[0], i * self.stride[0] + self.kernel[0] + for j in range(output_height): + h_start, h_end = j * self.stride[1], j * self.stride[1] + self.kernel[1] + for k in range(output_width): + w_start, w_end = k * self.stride[2], k * self.stride[2] + self.kernel[2] + out_video_blend = _prepare_for_blend( + (i, output_num_frames, video_overlap[0]), + (j, output_height, video_overlap[1]), + (k, output_width, video_overlap[2]), + decoded_videos[i * output_height * output_width + j * output_width + k].unsqueeze(0), + ) + video[:, :, n_start:n_end, h_start:h_end, w_start:w_end] += out_video_blend + + video = video.permute(0, 2, 1, 3, 4).contiguous() + return video + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + encoder_local_batch_size: int = 2, + decoder_local_batch_size: int = 2, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. 
+ return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + PyTorch random number generator. + encoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the encoder's batch inference. + decoder_local_batch_size (`int`, *optional*, defaults to 2): + Local batch size for the decoder's batch inference. + """ + x = sample + posterior = self.encode(x, local_batch_size=encoder_local_batch_size).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, local_batch_size=decoder_local_batch_size).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + +def _prepare_for_blend(n_param, h_param, w_param, x): + # TODO(aryan): refactor + n, n_max, overlap_n = n_param + h, h_max, overlap_h = h_param + w, w_max, overlap_w = w_param + if overlap_n > 0: + if n > 0: # the head overlap part decays from 0 to 1 + x[:, :, 0:overlap_n, :, :] = x[:, :, 0:overlap_n, :, :] * ( + torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) + if n < n_max - 1: # the tail overlap part decays from 1 to 0 + x[:, :, -overlap_n:, :, :] = x[:, :, -overlap_n:, :, :] * ( + 1 - torch.arange(0, overlap_n).float().to(x.device) / overlap_n + ).reshape(overlap_n, 1, 1) + if h > 0: + x[:, :, :, 0:overlap_h, :] = x[:, :, :, 0:overlap_h, :] * ( + torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) + if h < h_max - 1: + x[:, :, :, -overlap_h:, :] = x[:, :, :, -overlap_h:, :] * ( + 1 - torch.arange(0, overlap_h).float().to(x.device) / overlap_h + ).reshape(overlap_h, 1) + if w > 0: + x[:, :, :, :, 0:overlap_w] = x[:, :, :, :, 0:overlap_w] * ( + torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) + if w < w_max - 1: + x[:, :, :, :, -overlap_w:] = x[:, :, :, :, -overlap_w:] * ( + 1 - torch.arange(0, overlap_w).float().to(x.device) / overlap_w + ) + return x diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 44f01c46ebe8..66917dce6107 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -564,6 +564,42 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w): return cos, sin +def get_3d_rotary_pos_embed_allegro( + embed_dim, + crops_coords, + grid_size, + temporal_size, + interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0), + theta: int = 10000, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + # TODO(aryan): docs + start, stop = crops_coords + grid_size_h, grid_size_w = grid_size + interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale + grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) + grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) + grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) + + # Compute dimensions for each axis + dim_t = embed_dim // 3 + dim_h = embed_dim // 3 + dim_w = embed_dim // 3 + + # Temporal frequencies + freqs_t = get_1d_rotary_pos_embed( + dim_t, grid_t / interpolation_scale_t, theta=theta, use_real=True, repeat_interleave_real=False + ) + # Spatial frequencies for height and width + freqs_h = get_1d_rotary_pos_embed( + dim_h, grid_h / interpolation_scale_h, theta=theta, use_real=True, repeat_interleave_real=False + ) + freqs_w = 
get_1d_rotary_pos_embed( + dim_w, grid_w / interpolation_scale_w, theta=theta, use_real=True, repeat_interleave_real=False + ) + + return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w + + def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True): """ RoPE for image tokens with 2d structure. @@ -684,7 +720,7 @@ def get_1d_rotary_pos_embed( freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: - # stable audio + # stable audio, allegro freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin @@ -743,6 +779,24 @@ def apply_rotary_emb( return x_out.type_as(x) +def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions): + # TODO(aryan): rewrite + def apply_1d_rope(tokens, pos, cos, sin): + cos = F.embedding(pos, cos)[:, None, :, :] + sin = F.embedding(pos, sin)[:, None, :, :] + x1, x2 = tokens[..., : tokens.shape[-1] // 2], tokens[..., tokens.shape[-1] // 2 :] + tokens_rotated = torch.cat((-x2, x1), dim=-1) + return (tokens.float() * cos + tokens_rotated.float() * sin).to(tokens.dtype) + + (t_cos, t_sin), (h_cos, h_sin), (w_cos, w_sin) = freqs_cis + t, h, w = x.chunk(3, dim=-1) + t = apply_1d_rope(t, positions[0], t_cos, t_sin) + h = apply_1d_rope(h, positions[1], h_cos, h_sin) + w = apply_1d_rope(w, positions[2], w_cos, w_sin) + x = torch.cat([t, h, w], dim=-1) + return x + + class FluxPosEmbed(nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 029c147fcbac..87dec66935da 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -22,10 +22,7 @@ from ..utils import is_torch_version from .activations import get_activation -from .embeddings import ( - CombinedTimestepLabelEmbeddings, - PixArtAlphaCombinedTimestepSizeEmbeddings, -) +from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings class AdaLayerNorm(nn.Module): @@ -266,6 +263,7 @@ def forward( hidden_dtype: Optional[torch.dtype] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # No modulation happening here. 
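+        # Callers such as the Allegro transformer may not provide `added_cond_kwargs`; default the
+        # resolution/aspect-ratio entries to None so the timestep embedder still receives the kwargs it expects.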
+ added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 58787c079ea8..873a2bbecf05 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -14,6 +14,7 @@ from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel + from .transformer_allegro import AllegroTransformer3DModel from .transformer_cogview3plus import CogView3PlusTransformer2DModel from .transformer_flux import FluxTransformer2DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/transformer_allegro.py b/src/diffusers/models/transformers/transformer_allegro.py new file mode 100644 index 000000000000..f756399a378a --- /dev/null +++ b/src/diffusers/models/transformers/transformer_allegro.py @@ -0,0 +1,422 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph +from ..attention import FeedForward +from ..attention_processor import AllegroAttnProcessor2_0, Attention +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle + + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class AllegroTransformerBlock(nn.Module): + r""" + Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model. + + Args: + dim (`int`): + The number of channels in the input and output. + num_attention_heads (`int`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`): + The number of channels in each head. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to be used in feed-forward. + attention_bias (`bool`, defaults to `False`): + Whether or not to use bias in attention projection layers. + only_cross_attention (`bool`, defaults to `False`): + norm_elementwise_affine (`bool`, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_eps (`float`, defaults to `1e-5`): + Epsilon value for normalization layers. 
+ final_dropout (`bool` defaults to `False`): + Whether to apply a final dropout after the last feed-forward layer. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + attention_bias: bool = False, + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + ): + super().__init__() + + # 1. Self Attention + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + processor=AllegroAttnProcessor2_0(), + ) + + # 2. Cross Attention + self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + processor=AllegroAttnProcessor2_0(), + ) + + # 3. Feed Forward + self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + ) + + # 4. Scale-shift + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + temb: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb=None, + ) -> torch.Tensor: + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + image_rotary_emb=image_rotary_emb, + ) + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 1. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = hidden_states + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + image_rotary_emb=None, + ) + hidden_states = attn_output + hidden_states + + # 2. Feed-forward + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + + # TODO(aryan): maybe following line is not required + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class AllegroTransformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + """ + A 3D Transformer model for video-like data. + + Args: + patch_size (`int`, defaults to `2`): + The size of spatial patches to use in the patch embedding layer. + patch_size_t (`int`, defaults to `1`): + The size of temporal patches to use in the patch embedding layer. 
+ num_attention_heads (`int`, defaults to `24`): + The number of heads to use for multi-head attention. + attention_head_dim (`int`, defaults to `96`): + The number of channels in each head. + in_channels (`int`, defaults to `4`): + The number of channels in the input. + out_channels (`int`, *optional*, defaults to `4`): + The number of channels in the output. + num_layers (`int`, defaults to `32`): + The number of layers of Transformer blocks to use. + dropout (`float`, defaults to `0.0`): + The dropout probability to use. + cross_attention_dim (`int`, defaults to `2304`): + The dimension of the cross attention features. + attention_bias (`bool`, defaults to `True`): + Whether or not to use bias in the attention projection layers. + sample_height (`int`, defaults to `90`): + The height of the input latents. + sample_width (`int`, defaults to `160`): + The width of the input latents. + sample_frames (`int`, defaults to `22`): + The number of frames in the input latents. + activation_fn (`str`, defaults to `"gelu-approximate"`): + Activation function to use in feed-forward. + norm_elementwise_affine (`bool`, defaults to `False`): + Whether or not to use elementwise affine in normalization layers. + norm_eps (`float`, defaults to `1e-6`): + The epsilon value to use in normalization layers. + caption_channels (`int`, defaults to `4096`): + Number of channels to use for projecting the caption embeddings. + interpolation_scale_h (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across height dimension. + interpolation_scale_w (`float`, defaults to `2.0`): + Scaling factor to apply in 3D positional embeddings across width dimension. + interpolation_scale_t (`float`, defaults to `2.2`): + Scaling factor to apply in 3D positional embeddings across time dimension. + """ + + @register_to_config + def __init__( + self, + patch_size: int = 2, + patch_size_t: int = 1, + num_attention_heads: int = 24, + attention_head_dim: int = 96, + in_channels: int = 4, + out_channels: int = 4, + num_layers: int = 32, + dropout: float = 0.0, + cross_attention_dim: int = 2304, + attention_bias: bool = True, + sample_height: int = 90, + sample_width: int = 160, + sample_frames: int = 22, + activation_fn: str = "gelu-approximate", + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + caption_channels: int = 4096, + interpolation_scale_h: float = 2.0, + interpolation_scale_w: float = 2.0, + interpolation_scale_t: float = 2.2, + ): + super().__init__() + + self.inner_dim = num_attention_heads * attention_head_dim + + interpolation_scale_t = ( + interpolation_scale_t + if interpolation_scale_t is not None + else ((sample_frames - 1) // 16 + 1) + if sample_frames % 2 == 1 + else sample_frames // 16 + ) + interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30 + interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40 + + # 1. Patch embedding + self.pos_embed = PatchEmbed( + height=sample_height, + width=sample_width, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=self.inner_dim, + pos_embed_type=None, + ) + + # 2. 
Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + AllegroTransformerBlock( + self.inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + ) + for _ in range(num_layers) + ] + ) + + # 3. Output projection & norm + self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5) + self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels) + + # 4. Timestep embeddings + self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False) + + # 5. Caption projection + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + self.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + return_dict: bool = True, + ): + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t = self.config.patch_size_t + p = self.config.patch_size + + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p + post_patch_width = width // p + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) attention_mask_vid, attention_mask_img = None, None + if attention_mask is not None and attention_mask.ndim == 4: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + # b, frame+use_image_num, h, w -> a video with images + # b, 1, h, w -> only images + attention_mask = attention_mask.to(hidden_states.dtype) + attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width] + + if attention_mask.numel() > 0: + attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width] + attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p)) + attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1) + + attention_mask = ( + (1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None + ) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Timestep embeddings + timestep, embedded_timestep = self.adaln_single( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Patch embeddings + hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1) + hidden_states = self.pos_embed(hidden_states) + hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1]) + + # 3. Transformer blocks + for i, block in enumerate(self.transformer_blocks): + # TODO(aryan): Implement gradient checkpointing + if self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + timestep, + attention_mask, + encoder_attention_mask, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=timestep, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + image_rotary_emb=image_rotary_emb, + ) + + # 4. Output normalization & projection + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # 5. 
Unpatchify + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.reshape(batch_size, -1, num_frames, height, width) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7366520f4692..634088f1b51a 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -116,6 +116,7 @@ "VersatileDiffusionTextToImagePipeline", ] ) + _import_structure["allegro"] = ["AllegroPipeline"] _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["animatediff"] = [ "AnimateDiffPipeline", @@ -454,6 +455,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * else: + from .allegro import AllegroPipeline from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline from .animatediff import ( AnimateDiffControlNetPipeline, diff --git a/src/diffusers/pipelines/allegro/__init__.py b/src/diffusers/pipelines/allegro/__init__.py new file mode 100644 index 000000000000..2162b825e0a2 --- /dev/null +++ b/src/diffusers/pipelines/allegro/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_allegro"] = ["AllegroPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_allegro import AllegroPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py new file mode 100644 index 000000000000..9314960f9618 --- /dev/null +++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py @@ -0,0 +1,918 @@ +# Copyright 2024 The RhymesAI and The HuggingFace Team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
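The pipeline implementation that follows resolves its default video size from the transformer config and the VAE compression factors, and `prepare_latents` then maps the requested pixel-space size to a latent video shape. The sketch below mirrors that bookkeeping; it is an illustration only (the helper `allegro_latent_shape` is hypothetical, not part of this file), and it assumes the published checkpoint's VAE compresses by 8x spatially and 4x temporally — values the pipeline actually reads from the VAE config at runtime.

```python
import math

# Illustrative helper (hypothetical, not part of diffusers): reproduce the latent-shape
# arithmetic used by AllegroPipeline.prepare_latents, assuming a VAE with a spatial
# compression factor of 8 and a temporal compression factor of 4.
def allegro_latent_shape(num_frames=88, height=720, width=1280, spatial=8, temporal=4, latent_channels=4):
    if num_frames % 2 == 0:
        latent_frames = math.ceil(num_frames / temporal)
    else:
        latent_frames = math.ceil((num_frames - 1) / temporal) + 1
    return (latent_channels, latent_frames, height // spatial, width // spatial)

# The documented defaults (88 frames at 720x1280) land exactly on the transformer's
# config defaults: sample_frames=22, sample_height=90, sample_width=160.
print(allegro_latent_shape())  # (4, 22, 90, 160)
```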
+ +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import T5EncoderModel, T5Tokenizer + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...models import AllegroTransformer3DModel, AutoencoderKLAllegro +from ...models.embeddings import get_3d_rotary_pos_embed_allegro +from ...pipelines.pipeline_utils import DiffusionPipeline +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from .pipeline_output import AllegroPipelineOutput + + +logger = logging.get_logger(__name__) + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AutoencoderKLAllegro, AllegroPipeline + >>> from diffusers.utils import export_to_video + + >>> vae = AutoencoderKLAllegro.from_pretrained("rhymes-ai/Allegro", subfolder="vae", torch_dtype=torch.float32) + >>> pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", vae=vae, torch_dtype=torch.bfloat16).to("cuda") + + >>> prompt = ( + ... "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, " + ... "the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this " + ... "location might be a popular spot for docking fishing boats." + ... ) + >>> video = pipe(prompt, guidance_scale=7.5, max_sequence_length=512).frames[0] + >>> export_to_video(video, "output.mp4", fps=15) + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class AllegroPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-video generation using Allegro. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKLAllegro`]): + Variational Auto-Encoder (VAE) model to encode and decode video to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. Allegro uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`AllegroTransformer3DModel`]): + A text conditioned `AllegroTransformer3DModel` to denoise the encoded video latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
+ """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + _callback_tensor_inputs = [ + "latents", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKLAllegro, + transformer: AllegroTransformer3DModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler + ) + self.vae_scale_factor_spatial = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 + ) + self.vae_scale_factor_temporal = ( + self.vae.config.temporal_compression_ratio if hasattr(self, "vae") and self.vae is not None else 4 + ) + + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->512, num_images_per_prompt->num_videos_per_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 512, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_videos_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the "" + string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 512): Maximum sequence length to use for the prompt. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." 
+ deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + max_length = max_sequence_length + + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because T5 can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt + uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to(device) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1) + 
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_videos_per_prompt, 1) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if num_frames <= 0: + raise ValueError(f"`num_frames` have to be positive but is {num_frames}.") + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
+ ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`")) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + 
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_latents( + self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if num_frames % 2 == 0: + num_frames = math.ceil(num_frames / self.vae_scale_factor_temporal) + else: + num_frames = math.ceil((num_frames - 1) / self.vae_scale_factor_temporal) + 1 + + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: + latents = 1 / self.vae.config.scaling_factor * latents + frames = self.vae.decode(latents).sample + frames = frames.permute(0, 2, 1, 3, 4) # [batch_size, channels, num_frames, height, width] + return frames + + def _prepare_rotary_positional_embeddings( + self, + batch_size: int, + height: int, + width: int, + num_frames: int, + device: torch.device, + ): + grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size) + + start, stop = (0, 0), (grid_height, grid_width) + freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w = get_3d_rotary_pos_embed_allegro( + embed_dim=self.transformer.config.attention_head_dim, + crops_coords=(start, stop), + grid_size=(grid_height, grid_width), + temporal_size=num_frames, + interpolation_scale=( + self.transformer.config.interpolation_scale_t, + self.transformer.config.interpolation_scale_h, + self.transformer.config.interpolation_scale_w, + ), + ) + + grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long) + grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long) + grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long) + + pos = torch.cartesian_prod(grid_t, grid_h, grid_w) + pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous() + grid_t, grid_h, grid_w = pos + + freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device)) + freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device)) + freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device)) + + return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 100, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + num_frames: Optional[int] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_videos_per_prompt: int = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + 
callback_on_step_end_tensor_inputs: List[str] = ["latents"], + clean_caption: bool = True, + max_sequence_length: int = 512, + ) -> Union[AllegroPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds` + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the video generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality video at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages the model to generate videos that are closely linked to the text `prompt`, + usually at the expense of lower video quality. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + num_frames (`int`, *optional*, defaults to 88): + The number of video frames to generate. + height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae_scale_factor_spatial`): + The height in pixels of the generated video. + width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae_scale_factor_spatial`): + The width in pixels of the generated video. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for video + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Allegro, this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video.
Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + max_sequence_length (`int` defaults to `512`): + Maximum sequence length to use with the `prompt`. + + Examples: + + Returns: + [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.allegro.pipeline_output.AllegroPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated videos. + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + num_videos_per_prompt = 1 + + # 1. Check inputs. Raise error if not correct + num_frames = num_frames or self.transformer.config.sample_frames * self.vae_scale_factor_temporal + height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial + width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial + + self.check_inputs( + prompt, + num_frames, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + self._guidance_scale = guidance_scale + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0) + if prompt_embeds.ndim == 3: + prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d + + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + latent_channels, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary embeddings + image_rotary_emb = self._prepare_rotary_positional_embeddings( + batch_size, height, width, latents.size(2), device + ) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # predict noise model_output + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if not output_type == "latent": + latents = 
latents.to(self.vae.dtype) + video = self.decode_latents(latents) + video = video[:, :, :num_frames, :height, :width] + video = self.video_processor.postprocess_video(video=video, output_type=output_type) + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return AllegroPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/allegro/pipeline_output.py b/src/diffusers/pipelines/allegro/pipeline_output.py new file mode 100644 index 000000000000..6a721783ca86 --- /dev/null +++ b/src/diffusers/pipelines/allegro/pipeline_output.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import List, Union + +import numpy as np +import PIL +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class AllegroPipelineOutput(BaseOutput): + r""" + Output class for Allegro pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. + """ + + frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]] diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 10d0399a6761..8a87b04a66cb 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AllegroTransformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AsymmetricAutoencoderKL(metaclass=DummyObject): _backends = ["torch"] @@ -47,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderKLAllegro(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderKLCogVideoX(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9046a4f73533..83d160b08df4 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class AllegroPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AltDiffusionImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/transformers/test_models_transformer_allegro.py 
b/tests/models/transformers/test_models_transformer_allegro.py new file mode 100644 index 000000000000..ad8b7a3824ba --- /dev/null +++ b/tests/models/transformers/test_models_transformer_allegro.py @@ -0,0 +1,79 @@ +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch + +from diffusers import AllegroTransformer3DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class AllegroTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = AllegroTransformer3DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + batch_size = 2 + num_channels = 4 + num_frames = 2 + height = 8 + width = 8 + embedding_dim = 16 + sequence_length = 16 + + hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim // 2)).to(torch_device) + timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + } + + @property + def input_shape(self): + return (4, 2, 8, 8) + + @property + def output_shape(self): + return (4, 2, 8, 8) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + # Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings. + "num_attention_heads": 2, + "attention_head_dim": 8, + "in_channels": 4, + "out_channels": 4, + "num_layers": 1, + "cross_attention_dim": 16, + "sample_width": 8, + "sample_height": 8, + "sample_frames": 8, + "caption_channels": 8, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict diff --git a/tests/pipelines/allegro/__init__.py b/tests/pipelines/allegro/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/allegro/test_allegro.py b/tests/pipelines/allegro/test_allegro.py new file mode 100644 index 000000000000..d09fc0488378 --- /dev/null +++ b/tests/pipelines/allegro/test_allegro.py @@ -0,0 +1,337 @@ +# Copyright 2024 The HuggingFace Team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
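The fast pipeline tests below build tiny components, so it is worth spelling out the shape arithmetic they implicitly rely on. This is an illustrative sketch only; the numbers come from `get_dummy_components` and `get_dummy_inputs` further down, not from any public API.

```python
import math

# Dummy-config values used by the fast tests below (see get_dummy_components / get_dummy_inputs).
block_out_channels = (8, 8, 8, 8)       # -> spatial scale factor 2**(4 - 1) = 8
temporal_compression_ratio = 4          # -> temporal scale factor 4
num_frames, height, width = 8, 16, 16   # requested video size in get_dummy_inputs

spatial = 2 ** (len(block_out_channels) - 1)
latent_frames = math.ceil(num_frames / temporal_compression_ratio)  # num_frames is even here

# Latent video handed to the tiny transformer: (channels, frames, height, width)
assert (4, latent_frames, height // spatial, width // spatial) == (4, 2, 2, 2)

# After decoding and cropping, test_inference checks the generated video against this shape.
assert (num_frames, 3, height, width) == (8, 3, 16, 16)
```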
+ +import gc +import inspect +import unittest + +import numpy as np +import torch +from transformers import AutoTokenizer, T5Config, T5EncoderModel + +from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler +from diffusers.utils.testing_utils import ( + enable_full_determinism, + numpy_cosine_similarity_distance, + require_torch_gpu, + slow, + torch_device, +) + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin, to_np + + +enable_full_determinism() + + +class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AllegroPipeline + params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = TEXT_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + "callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + test_xformers_attention = False + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = AllegroTransformer3DModel( + num_attention_heads=2, + attention_head_dim=12, + in_channels=4, + out_channels=4, + num_layers=1, + cross_attention_dim=24, + sample_width=8, + sample_height=8, + sample_frames=8, + caption_channels=24, + ) + + torch.manual_seed(0) + vae = AutoencoderKLAllegro( + in_channels=3, + out_channels=3, + down_block_types=( + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + "AllegroDownBlock3D", + ), + up_block_types=( + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + "AllegroUpBlock3D", + ), + block_out_channels=(8, 8, 8, 8), + latent_channels=4, + layers_per_block=1, + norm_num_groups=2, + temporal_compression_ratio=4, + ) + + # TODO(aryan): Only for now, since VAE decoding without tiling is not yet implemented here + vae.enable_tiling() + + torch.manual_seed(0) + scheduler = DDIMScheduler() + + text_encoder_config = T5Config( + **{ + "d_ff": 37, + "d_kv": 8, + "d_model": 24, + "num_decoder_layers": 2, + "num_heads": 4, + "num_layers": 2, + "relative_attention_num_buckets": 8, + "vocab_size": 1103, + } + ) + text_encoder = T5EncoderModel(text_encoder_config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5") + + components = { + "transformer": transformer, + "vae": vae, + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + inputs = { + "prompt": "dance monkey", + "negative_prompt": "", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "height": 16, + "width": 16, + "num_frames": 8, + "max_sequence_length": 16, + "output_type": "pt", + } + + return inputs + + @unittest.skip("Decoding without tiling is not yet implemented") + def test_save_load_local(self): + pass + + @unittest.skip("Decoding without tiling is not yet implemented") + def test_save_load_optional_components(self): + pass + + def test_inference(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = 
self.get_dummy_inputs(device) + video = pipe(**inputs).frames + generated_video = video[0] + + self.assertEqual(generated_video.shape, (8, 3, 16, 16)) + expected_video = torch.randn(8, 3, 16, 16) + max_diff = np.abs(generated_video - expected_video).max() + self.assertLessEqual(max_diff, 1e10) + + def test_callback_inputs(self): + sig = inspect.signature(self.pipeline_class.__call__) + has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters + has_callback_step_end = "callback_on_step_end" in sig.parameters + + if not (has_callback_tensor_inputs and has_callback_step_end): + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + self.assertTrue( + hasattr(pipe, "_callback_tensor_inputs"), + f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs", + ) + + def callback_inputs_subset(pipe, i, t, callback_kwargs): + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + def callback_inputs_all(pipe, i, t, callback_kwargs): + for tensor_name in pipe._callback_tensor_inputs: + assert tensor_name in callback_kwargs + + # iterate over callback args + for tensor_name, tensor_value in callback_kwargs.items(): + # check that we're only passing in allowed tensor inputs + assert tensor_name in pipe._callback_tensor_inputs + + return callback_kwargs + + inputs = self.get_dummy_inputs(torch_device) + + # Test passing in a subset + inputs["callback_on_step_end"] = callback_inputs_subset + inputs["callback_on_step_end_tensor_inputs"] = ["latents"] + output = pipe(**inputs)[0] + + # Test passing in a everything + inputs["callback_on_step_end"] = callback_inputs_all + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + + def callback_inputs_change_tensor(pipe, i, t, callback_kwargs): + is_last = i == (pipe.num_timesteps - 1) + if is_last: + callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"]) + return callback_kwargs + + inputs["callback_on_step_end"] = callback_inputs_change_tensor + inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs + output = pipe(**inputs)[0] + assert output.abs().sum() < 1e10 + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3) + + def test_attention_slicing_forward_pass( + self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3 + ): + if not self.test_attention_slicing: + return + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + generator_device = "cpu" + inputs = self.get_dummy_inputs(generator_device) + output_without_slicing = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=1) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing1 = pipe(**inputs)[0] + + pipe.enable_attention_slicing(slice_size=2) + inputs = self.get_dummy_inputs(generator_device) + output_with_slicing2 = 
pipe(**inputs)[0] + + if test_max_difference: + max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max() + max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max() + self.assertLess( + max(max_diff1, max_diff2), + expected_max_diff, + "Attention slicing should not affect the inference results", + ) + + # TODO(aryan) + @unittest.skip("Decoding without tiling is not yet implemented.") + def test_vae_tiling(self, expected_diff_max: float = 0.2): + generator_device = "cpu" + components = self.get_dummy_components() + + pipe = self.pipeline_class(**components) + pipe.to("cpu") + pipe.set_progress_bar_config(disable=None) + + # Without tiling + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_without_tiling = pipe(**inputs)[0] + + # With tiling + pipe.vae.enable_tiling( + tile_sample_min_height=96, + tile_sample_min_width=96, + tile_overlap_factor_height=1 / 12, + tile_overlap_factor_width=1 / 12, + ) + inputs = self.get_dummy_inputs(generator_device) + inputs["height"] = inputs["width"] = 128 + output_with_tiling = pipe(**inputs)[0] + + self.assertLess( + (to_np(output_without_tiling) - to_np(output_with_tiling)).max(), + expected_diff_max, + "VAE tiling should not affect the inference results", + ) + + +@slow +@require_torch_gpu +class AllegroPipelineIntegrationTests(unittest.TestCase): + prompt = "A painting of a squirrel eating a burger." + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def test_allegro(self): + generator = torch.Generator("cpu").manual_seed(0) + + pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16) + pipe.enable_model_cpu_offload() + prompt = self.prompt + + videos = pipe( + prompt=prompt, + height=720, + width=1280, + num_frames=88, + generator=generator, + num_inference_steps=2, + output_type="pt", + ).frames + + video = videos[0] + expected_video = torch.randn(1, 88, 720, 1280, 3).numpy() + + max_diff = numpy_cosine_similarity_distance(video, expected_video) + assert max_diff < 1e-3, f"Max diff is too high. got {video}" diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 3e6f9d1278e8..295a94c1d2e4 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1103,6 +1103,7 @@ def _test_inference_batch_consistent( logger.setLevel(level=diffusers.logging.WARNING) for batch_size, batched_input in zip(batch_sizes, batched_inputs): + print(batch_size, batched_input) output = pipe(**batched_input) assert len(output[0]) == batch_size
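
Taken together, the integration test above shows how the new `AllegroPipeline` is meant to be driven end to end. The following is a minimal, non-authoritative usage sketch derived from that test: the checkpoint id, prompt, resolution, frame count, CPU offload, and fp16 dtype come from the test itself, while `export_to_video`, the higher step count, and the 15 fps value are assumptions layered on top of what this patch shows.

```python
import torch

from diffusers import AllegroPipeline
from diffusers.utils import export_to_video

# Load the full pipeline; offload to keep GPU memory manageable, as in the test.
pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

# VAE decoding currently relies on tiling (see the TODO in the fast tests above).
pipe.vae.enable_tiling()

video = pipe(
    prompt="A painting of a squirrel eating a burger.",
    height=720,
    width=1280,
    num_frames=88,
    num_inference_steps=100,  # assumption: the test uses 2 steps only to keep CI fast
    generator=torch.Generator("cpu").manual_seed(0),
).frames[0]  # default output_type returns a list of PIL images per prompt

# fps is an assumed value; adjust to taste.
export_to_video(video, "allegro.mp4", fps=15)
```
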