diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index b1d1ffb63423..8e0326e0c334 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -107,6 +107,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin +## KandinskyLoraLoaderMixin +[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin + ## LoraBaseMixin [[autodoc]] loaders.lora_base.LoraBaseMixin \ No newline at end of file diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 95d559ff758b..aa500b149441 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -220,6 +220,7 @@ "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", + "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", @@ -474,6 +475,7 @@ "ImageTextPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", + "Kandinsky5T2VPipeline", "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", "KandinskyImg2ImgPipeline", @@ -912,6 +914,7 @@ HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, + Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -1136,6 +1139,7 @@ ImageTextPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, + Kandinsky5T2VPipeline, KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyImg2ImgPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 742548653800..48507aae038c 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "KandinskyLoraLoaderMixin", "HiDreamImageLoraLoaderMixin", "SkyReelsV2LoraLoaderMixin", "QwenImageLoraLoaderMixin", @@ -115,6 +116,7 @@ def text_encoder_attn_modules(text_encoder): FluxLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, + KandinskyLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e25a29e1c00e..2bb6c0ea026e 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3639,6 +3639,291 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): super().unfuse_lora(components=components, **kwargs) +class KandinskyLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Kandinsky5Transformer3DModel`], + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* of a pretrained model hosted on the Hub. + - A path to a *directory* containing the model weights. + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached. 
+ force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + use_safetensors (`bool`, *optional*): + Whether to use safetensors for loading. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata. + """ + # Load the main state dict first which has the LoRA layers + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." + logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + hotswap (`bool`, *optional*): + Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. 
+ low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. + kwargs (`dict`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + # Load LoRA into transformer + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + Load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. + transformer (`Kandinsky5Transformer3DModel`): + The transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights. + hotswap (`bool`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. + logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + ): + r""" + Save the LoRA parameters corresponding to the transformer and text encoders. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. 
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process. + save_function (`Callable`): + The function to use to save the state dictionary. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError("You must pass at least one of `transformer_lora_layers`") + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. + + Example: + ```py + from diffusers import Kandinsky5T2VPipeline + + pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + pipeline.load_lora_weights("path/to/lora.safetensors") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of [`pipe.fuse_lora()`]. + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. + """ + super().unfuse_lora(components=components, **kwargs) + + class WanLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`]. 
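A quick usage sketch of the new `KandinskyLoraLoaderMixin` as surfaced on `Kandinsky5T2VPipeline`, before the diff moves on to model registration. The LoRA file path and adapter name below are placeholders for illustration; the checkpoint id is one of the repos listed later in the pipeline's example docstring:

```python
import torch
from diffusers import Kandinsky5T2VPipeline

pipe = Kandinsky5T2VPipeline.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers", torch_dtype=torch.bfloat16
).to("cuda")

# load_lora_weights() fetches and validates the state dict via lora_state_dict()
# and injects it into the transformer through load_lora_into_transformer().
# The LoRA path and adapter name are hypothetical.
pipe.load_lora_weights("path/to/kandinsky_lora.safetensors", adapter_name="style")

# Optionally bake the adapter into the base weights for inference, then undo it.
pipe.fuse_lora(components=["transformer"], lora_scale=0.8)
pipe.unfuse_lora(components=["transformer"])
```

`save_lora_weights` is the inverse path: it only accepts transformer LoRA layers for this mixin, matching `_lora_loadable_modules = ["transformer"]`.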
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..8d029bf5d31c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -91,6 +91,7 @@ _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] + _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] @@ -182,6 +183,7 @@ HunyuanDiT2DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, + Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..6b80ea6c82a5 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -27,6 +27,7 @@ from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel + from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py new file mode 100644 index 000000000000..d4ba92acaf6e --- /dev/null +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -0,0 +1,667 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
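The `transformer_kandinsky.py` module that follows defines `Kandinsky5Transformer3DModel`. As a quick orientation, a minimal shape-check sketch with a deliberately tiny, hypothetical config (the shipped checkpoints use `model_dim=2048`, 32 visual blocks, and `axes_dims=(16, 24, 24)`); the RoPE position arguments are per-axis index tensors used to index the frequency tables:

```python
import torch
from diffusers import Kandinsky5Transformer3DModel

# Tiny illustrative config; model_dim must be divisible by head_dim = sum(axes_dims).
model = Kandinsky5Transformer3DModel(
    in_visual_dim=4,
    in_text_dim=32,
    in_text_dim2=16,
    time_dim=32,
    out_visual_dim=4,
    patch_size=(1, 2, 2),
    model_dim=32,
    ff_dim=64,
    num_text_blocks=1,
    num_visual_blocks=1,
    axes_dims=(8, 4, 4),  # head_dim = 16
)

batch, frames, height, width = 1, 2, 8, 8            # latent-space sizes
hidden_states = torch.randn(batch, frames, height, width, 4)
encoder_hidden_states = torch.randn(batch, 10, 32)    # Qwen-style text tokens
pooled_projections = torch.randn(batch, 16)           # CLIP pooled embedding
timestep = torch.tensor([500.0])

out = model(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    timestep=timestep,
    pooled_projections=pooled_projections,
    # positions per latent axis after patchification: 2 x 4 x 4
    visual_rope_pos=(torch.arange(2), torch.arange(4), torch.arange(4)),
    text_rope_pos=torch.arange(10),
    return_dict=True,
)
print(out.sample.shape)  # torch.Size([1, 2, 8, 8, 4]) — unpatchified latents
```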
+ +import inspect +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import ( + logging, +) +from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_dispatch import _CAN_USE_FLEX_ATTN, dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin + + +logger = logging.get_logger(__name__) + + +def get_freqs(dim, max_period=10000.0): + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) + return freqs + + +def fractal_flatten(x, rope, shape, block_mask=False): + if block_mask: + pixel_size = 8 + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1) + x = x.flatten(1, 2) + rope = rope.flatten(1, 2) + else: + x = x.flatten(1, 3) + rope = rope.flatten(1, 3) + return x, rope + + +def fractal_unflatten(x, shape, block_mask=False): + if block_mask: + pixel_size = 8 + x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1) + else: + x = x.reshape(*shape, *x.shape[2:]) + return x + + +def local_patching(x, shape, group_size, dim=0): + batch_size, duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + g1, + height // g2, + g2, + width // g3, + g3, + *x.shape[dim + 3 :], + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 2, + dim + 4, + dim + 1, + dim + 3, + dim + 5, + *range(dim + 6, len(x.shape)), + ) + x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) + return x + + +def local_merge(x, shape, group_size, dim=0): + batch_size, duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + height // g2, + width // g3, + g1, + g2, + g3, + *x.shape[dim + 2 :], + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 3, + dim + 1, + dim + 4, + dim + 2, + dim + 5, + *range(dim + 6, len(x.shape)), + ) + x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) + return x + + +def nablaT_v2( + q: Tensor, + k: Tensor, + sta: Tensor, + thr: float = 0.9, +): + if _CAN_USE_FLEX_ATTN: + from torch.nn.attention.flex_attention import BlockMask + else: + raise ValueError("Nabla attention is not supported with this version of PyTorch") + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + + # Map estimation + B, h, S, D = q.shape + s1 = S // 64 + qa = q.reshape(B, h, s1, 64, D).mean(-2) + ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1) + map = qa @ ka + + map = torch.softmax(map / math.sqrt(D), dim=-1) + # Map binarization + vals, inds = map.sort(-1) + cvals = vals.cumsum_(-1) + mask = (cvals >= 1 - thr).int() + mask = mask.gather(-1, inds.argsort(-1)) + + mask = torch.logical_or(mask, sta) + + # BlockMask creation + kv_nb = mask.sum(-1).to(torch.int32) + kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32) + return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None) + + +class Kandinsky5TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0): + 
super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.freqs = get_freqs(self.model_dim // 2, self.max_period) + self.in_layer = nn.Linear(model_dim, time_dim, bias=True) + self.activation = nn.SiLU() + self.out_layer = nn.Linear(time_dim, time_dim, bias=True) + + @torch.autocast(device_type="cuda", dtype=torch.float32) + def forward(self, time): + args = torch.outer(time, self.freqs.to(device=time.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class Kandinsky5TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim): + super().__init__() + self.in_layer = nn.Linear(text_dim, model_dim, bias=True) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=True) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class Kandinsky5VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) + + def forward(self, x): + batch_size, duration, height, width, dim = x.shape + x = ( + x.view( + batch_size, + duration // self.patch_size[0], + self.patch_size[0], + height // self.patch_size[1], + self.patch_size[1], + width // self.patch_size[2], + self.patch_size[2], + dim, + ) + .permute(0, 1, 3, 5, 2, 4, 6, 7) + .flatten(4, 7) + ) + return self.in_layer(x) + + +class Kandinsky5RoPE1D(nn.Module): + def __init__(self, dim, max_pos=1024, max_period=10000.0): + super().__init__() + self.max_period = max_period + self.dim = dim + self.max_pos = max_pos + freq = get_freqs(dim // 2, max_period) + pos = torch.arange(max_pos, dtype=freq.dtype) + self.register_buffer("args", torch.outer(pos, freq), persistent=False) + + def forward(self, pos): + args = self.args[pos] + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + + +class Kandinsky5RoPE3D(nn.Module): + def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): + super().__init__() + self.axes_dims = axes_dims + self.max_pos = max_pos + self.max_period = max_period + + for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)): + freq = get_freqs(axes_dim // 2, max_period) + pos = torch.arange(ax_max_pos, dtype=freq.dtype) + self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False) + + def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): + batch_size, duration, height, width = shape + args_t = self.args_0[pos[0]] / scale_factor[0] + args_h = self.args_1[pos[1]] / scale_factor[1] + args_w = self.args_2[pos[2]] / scale_factor[2] + + args = torch.cat( + [ + args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), + args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), + args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), + ], + dim=-1, + ) + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + + +class Kandinsky5Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = 
nn.Linear(time_dim, num_params * model_dim) + self.out_layer.weight.data.zero_() + self.out_layer.bias.data.zero_() + + @torch.autocast(device_type="cuda", dtype=torch.float32) + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class Kandinsky5AttnProcessor: + _attention_backend = None + _parallel_config = None + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") + + def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): + # query, key, value = self.get_qkv(x) + query = attn.to_query(hidden_states) + + if encoder_hidden_states is not None: + key = attn.to_key(encoder_hidden_states) + value = attn.to_value(encoder_hidden_states) + + shape, cond_shape = query.shape[:-1], key.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*cond_shape, attn.num_heads, -1) + value = value.reshape(*cond_shape, attn.num_heads, -1) + + else: + key = attn.to_key(hidden_states) + value = attn.to_value(hidden_states) + + shape = query.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*shape, attn.num_heads, -1) + value = value.reshape(*shape, attn.num_heads, -1) + + # query, key = self.norm_qk(query, key) + query = attn.query_norm(query.float()).type_as(query) + key = attn.key_norm(key.float()).type_as(key) + + def apply_rotary(x, rope): + x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) + x_out = (rope * x_).sum(dim=-1) + return x_out.reshape(*x.shape).to(torch.bfloat16) + + if rotary_emb is not None: + query = apply_rotary(query, rotary_emb).type_as(query) + key = apply_rotary(key, rotary_emb).type_as(key) + + if sparse_params is not None: + attn_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) + else: + attn_mask = None + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(-2, -1) + + attn_out = attn.out_layer(hidden_states) + return attn_out + + +class Kandinsky5Attention(nn.Module, AttentionModuleMixin): + _default_processor_cls = Kandinsky5AttnProcessor + _available_processors = [ + Kandinsky5AttnProcessor, + ] + + def __init__(self, num_channels, head_dim, processor=None): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + sparse_params: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> torch.Tensor: + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + 
logger.warning( + f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + sparse_params=sparse_params, + rotary_emb=rotary_emb, + **kwargs, + ) + + +class Kandinsky5FeedForward(nn.Module): + def __init__(self, dim, ff_dim): + super().__init__() + self.in_layer = nn.Linear(dim, ff_dim, bias=False) + self.activation = nn.GELU() + self.out_layer = nn.Linear(ff_dim, dim, bias=False) + + def forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + +class Kandinsky5OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True) + + def forward(self, visual_embed, text_embed, time_embed): + shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + + visual_embed = ( + self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None] + ).type_as(visual_embed) + + x = self.out_layer(visual_embed) + + batch_size, duration, height, width, _ = x.shape + x = ( + x.view( + batch_size, + duration, + height, + width, + -1, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(1, 2) + .flatten(2, 3) + .flatten(3, 4) + ) + return x + + +class Kandinsky5TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) + + self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) + + def forward(self, x, time_embed, rope): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) + out = self.self_attention(out, rotary_emb=rope) + x = (x.float() + gate.float() * out.float()).type_as(x) + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) + out = self.feed_forward(out) + x = (x.float() + gate.float() * out.float()).type_as(x) + + return x + + +class Kandinsky5TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) + + self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) + + self.feed_forward_norm = nn.LayerNorm(model_dim, 
elementwise_affine=False) + self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) + + def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): + self_attn_params, cross_attn_params, ff_params = torch.chunk( + self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 + ) + + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) + visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params) + visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) + + shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) + visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) + visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed) + visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) + visual_out = self.feed_forward(visual_out) + visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) + + return visual_embed + + +class Kandinsky5Transformer3DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, + AttentionMixin, +): + """ + A 3D Diffusion Transformer model for video-like data. + """ + + _repeated_blocks = [ + "Kandinsky5TransformerEncoderBlock", + "Kandinsky5TransformerDecoderBlock", + ] + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_visual_dim=4, + in_text_dim=3584, + in_text_dim2=768, + time_dim=512, + out_visual_dim=4, + patch_size=(1, 2, 2), + model_dim=2048, + ff_dim=5120, + num_text_blocks=2, + num_visual_blocks=32, + axes_dims=(16, 24, 24), + visual_cond=False, + attention_type: str = "regular", + attention_causal: bool = None, + attention_local: bool = None, + attention_glob: bool = None, + attention_window: int = None, + attention_P: float = None, + attention_wT: int = None, + attention_wW: int = None, + attention_wH: int = None, + attention_add_sta: bool = None, + attention_method: str = None, + ): + super().__init__() + + head_dim = sum(axes_dims) + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_cond = visual_cond + self.attention_type = attention_type + + visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim + + # Initialize embeddings + self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) + self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) + self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) + self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + + # Initialize positional embeddings + self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) + self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims) + + # Initialize transformer blocks + self.text_transformer_blocks = nn.ModuleList( + [Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks)] + ) + + self.visual_transformer_blocks = nn.ModuleList( + [ + Kandinsky5TransformerDecoderBlock(model_dim, time_dim, 
ff_dim, head_dim) + for _ in range(num_visual_blocks) + ] + ) + + # Initialize output layer + self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, # x + encoder_hidden_states: torch.Tensor, # text_embed + timestep: torch.Tensor, # time + pooled_projections: torch.Tensor, # pooled_text_embed + visual_rope_pos: Tuple[int, int, int], + text_rope_pos: torch.LongTensor, + scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), + sparse_params: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Transformer2DModelOutput, torch.FloatTensor]: + """ + Forward pass of the Kandinsky5 3D Transformer. + + Args: + hidden_states (`torch.FloatTensor`): Input visual states + encoder_hidden_states (`torch.FloatTensor`): Text embeddings + timestep (`torch.Tensor` or `float` or `int`): Current timestep + pooled_projections (`torch.FloatTensor`): Pooled text embeddings + visual_rope_pos (`Tuple[int, int, int]`): Position for visual RoPE + text_rope_pos (`torch.LongTensor`): Position for text RoPE + scale_factor (`Tuple[float, float, float]`, optional): Scale factor for RoPE + sparse_params (`Dict[str, Any]`, optional): Parameters for sparse attention + return_dict (`bool`, optional): Whether to return a dictionary + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: The output of the transformer + """ + x = hidden_states + text_embed = encoder_hidden_states + time = timestep + pooled_text_embed = pooled_projections + + text_embed = self.text_embeddings(text_embed) + time_embed = self.time_embeddings(time) + time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) + visual_embed = self.visual_embeddings(x) + text_rope = self.text_rope_embeddings(text_rope_pos) + text_rope = text_rope.unsqueeze(dim=0) + + for text_transformer_block in self.text_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + text_embed = self._gradient_checkpointing_func( + text_transformer_block, text_embed, time_embed, text_rope + ) + else: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) + + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False + visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, block_mask=to_fractal) + + for visual_transformer_block in self.visual_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + visual_embed = self._gradient_checkpointing_func( + visual_transformer_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + sparse_params, + ) + else: + visual_embed = visual_transformer_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + + visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) + x = self.out_layer(visual_embed, text_embed, time_embed) + + if not return_dict: + return x + + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d270..c438caed571f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -382,6 +382,7 @@ "WuerstchenPriorPipeline", ] _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", 
"WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -671,6 +672,7 @@ Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, ) + from .kandinsky5 import Kandinsky5T2VPipeline from .latent_consistency_models import ( LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py new file mode 100644 index 000000000000..a7975bdce926 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky import Kandinsky5T2VPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py new file mode 100644 index 000000000000..a1122a82565e --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -0,0 +1,893 @@ +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import html +from typing import Callable, Dict, List, Optional, Union + +import regex as re +import torch +from torch.nn import functional as F +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline + >>> from diffusers.utils import export_to_video + + >>> # Available models: + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers + + >>> model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers" + >>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=121, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + + >>> export_to_video(output, "output.mp4", fps=24, quality=9) + ``` +""" + + +def basic_clean(text): + """Clean text using ftfy if available and unescape HTML entities.""" + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + """Normalize whitespace in text by replacing multiple spaces with single space.""" + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + """Apply both basic cleaning and whitespace normalization to prompts.""" + text = whitespace_clean(basic_clean(text)) + return text + + +class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. 
+ text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder (Qwen2.5-VL). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen CLIP text encoder. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. + """ + + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_qwen", + "prompt_embeds_clip", + "negative_prompt_embeds_qwen", + "negative_prompt_embeds_clip", + ] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKLHunyuanVideo, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + + self.prompt_template = "\n".join( + [ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ] + ) + self.prompt_template_encode_start_idx = 129 + + self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio + self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + @staticmethod + def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor: + """ + Create a sparse temporal attention (STA) mask for efficient video generation. + + This method generates a mask that limits attention to nearby frames and spatial positions, reducing + computational complexity for video generation. 
+ + Args: + T (int): Number of temporal frames + H (int): Height in latent space + W (int): Width in latent space + wT (int): Temporal attention window size + wH (int): Height attention window size + wW (int): Width attention window size + device (str): Device to create tensor on + + Returns: + torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W) + """ + l = torch.Tensor([T, H, W]).amax() + r = torch.arange(0, l, 1, dtype=torch.int16, device=device) + mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() + sta_t, sta_h, sta_w = ( + mat[:T, :T].flatten(), + mat[:H, :H].flatten(), + mat[:W, :W].flatten(), + ) + sta_t = sta_t <= wT // 2 + sta_h = sta_h <= wH // 2 + sta_w = sta_w <= wW // 2 + sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten() + sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2) + return sta.reshape(T * H * W, T * H * W) + + def get_sparse_params(self, sample, device): + """ + Generate sparse attention parameters for the transformer based on sample dimensions. + + This method computes the sparse attention configuration needed for efficient video processing in the + transformer model. + + Args: + sample (torch.Tensor): Input sample tensor + device (torch.device): Device to place tensors on + + Returns: + Dict: Dictionary containing sparse attention parameters + """ + assert self.transformer.config.patch_size[0] == 1 + B, T, H, W, _ = sample.shape + T, H, W = ( + T // self.transformer.config.patch_size[0], + H // self.transformer.config.patch_size[1], + W // self.transformer.config.patch_size[2], + ) + if self.transformer.config.attention_type == "nabla": + sta_mask = self.fast_sta_nabla( + T, + H // 8, + W // 8, + self.transformer.config.attention_wT, + self.transformer.config.attention_wH, + self.transformer.config.attention_wW, + device=device, + ) + + sparse_params = { + "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0), + "attention_type": self.transformer.config.attention_type, + "to_fractal": True, + "P": self.transformer.config.attention_P, + "wT": self.transformer.config.attention_wT, + "wW": self.transformer.config.attention_wW, + "wH": self.transformer.config.attention_wH, + "add_sta": self.transformer.config.attention_add_sta, + "visual_shape": (T, H, W), + "method": self.transformer.config.attention_method, + } + else: + sparse_params = None + + return sparse_params + + def _encode_prompt_qwen( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + max_sequence_length: int = 256, + dtype: Optional[torch.dtype] = None, + ): + """ + Encode prompt using Qwen2.5-VL text encoder. + + This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for + video generation. 
+ + Args: + prompt (Union[str, List[str]]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + max_sequence_length (int): Maximum sequence length for tokenization + dtype (torch.dtype): Data type for embeddings + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + full_texts = [self.prompt_template.format(p) for p in prompt] + + inputs = self.tokenizer( + text=full_texts, + images=None, + videos=None, + max_length=max_sequence_length + self.prompt_template_encode_start_idx, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] + + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) + + return embeds.to(dtype), cu_seqlens + + def _encode_prompt_clip( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + Encode prompt using CLIP text encoder. + + This method processes the input prompt through the CLIP model to generate pooled embeddings that capture + semantic information. + + Args: + prompt (Union[str, List[str]]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + dtype (torch.dtype): Data type for embeddings + + Returns: + torch.Tensor: Pooled text embeddings from CLIP + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + return pooled_embed.to(dtype) + + def encode_prompt( + self, + prompt: Union[str, List[str]], + num_videos_per_prompt: int = 1, + max_sequence_length: int = 512, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + r""" + Encodes a single prompt (positive or negative) into text encoder hidden states. + + This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text + representations for video generation. + + Args: + prompt (`str` or `List[str]`): + Prompt to be encoded. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos to generate per prompt. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for text encoding. + device (`torch.device`, *optional*): + Torch device. + dtype (`torch.dtype`, *optional*): + Torch dtype. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim) + - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * + num_videos_per_prompt + 1,) + """ + device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype + + batch_size = len(prompt) + + prompt = [prompt_clean(p) for p in prompt] + + # Encode with Qwen2.5-VL + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] + + # Encode with CLIP + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + dtype=dtype, + ) + # prompt_embeds_clip shape: [batch_size, clip_embed_dim] + + # Repeat embeddings for num_videos_per_prompt + # Qwen embeddings: repeat sequence for each video, then reshape + prompt_embeds_qwen = prompt_embeds_qwen.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.view( + batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1] + ) + + # CLIP embeddings: repeat for each video + prompt_embeds_clip = prompt_embeds_clip.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, num_videos_per_prompt, clip_embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim] + prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) + + # Repeat cumulative sequence lengths for num_videos_per_prompt + # Original cu_seqlens: [0, len1, len1+len2, ...] + # Need to repeat the differences and reconstruct for repeated prompts + # Original differences (lengths) for each prompt in the batch + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + # Repeat the lengths for num_videos_per_prompt + repeated_lengths = original_lengths.repeat_interleave( + num_videos_per_prompt + ) # [len1, len1, ..., len2, len2, ...] + # Reconstruct the cumulative lengths + repeated_cu_seqlens = torch.cat( + [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] + ) + + return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds_qwen=None, + prompt_embeds_clip=None, + negative_prompt_embeds_qwen=None, + negative_prompt_embeds_clip=None, + prompt_cu_seqlens=None, + negative_prompt_cu_seqlens=None, + callback_on_step_end_tensor_inputs=None, + ): + """ + Validate input parameters for the pipeline. 
+ + Args: + prompt: Input prompt + negative_prompt: Negative prompt for guidance + height: Video height + width: Video width + prompt_embeds_qwen: Pre-computed Qwen prompt embeddings + prompt_embeds_clip: Pre-computed CLIP prompt embeddings + negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings + negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings + prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt + negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt + callback_on_step_end_tensor_inputs: Callback tensor inputs + + Raises: + ValueError: If inputs are invalid + """ + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + # Check for consistency within positive prompt embeddings and sequence lengths + if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: + if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: + raise ValueError( + "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check for consistency within negative prompt embeddings and sequence lengths + if ( + negative_prompt_embeds_qwen is not None + or negative_prompt_embeds_clip is not None + or negative_prompt_cu_seqlens is not None + ): + if ( + negative_prompt_embeds_qwen is None + or negative_prompt_embeds_clip is None + or negative_prompt_cu_seqlens is None + ): + raise ValueError( + "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + "all three must be provided." + ) + + # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) + if prompt is None and prompt_embeds_qwen is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." + ) + + # Validate types for prompt and negative_prompt if provided + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + if negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Prepare initial latent variables for video generation. + + This method creates random noise latents or uses provided latents as starting point for the denoising process. 
+ + Args: + batch_size (int): Number of videos to generate + num_channels_latents (int): Number of channels in latent space + height (int): Height of generated video + width (int): Width of generated video + num_frames (int): Number of frames in video + dtype (torch.dtype): Data type for latents + device (torch.device): Device to create latents on + generator (torch.Generator): Random number generator + latents (torch.Tensor): Pre-existing latents to use + + Returns: + torch.Tensor: Prepared latent tensor + """ + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if self.transformer.visual_cond: + # For visual conditioning, concatenate with zeros and mask + visual_cond = torch.zeros_like(latents) + visual_cond_mask = torch.zeros( + [ + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + 1, + ], + dtype=latents.dtype, + device=latents.device, + ) + latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) + + return latents + + @property + def guidance_scale(self): + """Get the current guidance scale value.""" + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + """Check if classifier-free guidance is enabled.""" + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + """Get the number of denoising timesteps.""" + return self._num_timesteps + + @property + def interrupt(self): + """Check if generation has been interrupted.""" + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 121, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds_qwen: Optional[torch.Tensor] = None, + prompt_embeds_clip: Optional[torch.Tensor] = None, + negative_prompt_embeds_qwen: Optional[torch.Tensor] = None, + negative_prompt_embeds_clip: Optional[torch.Tensor] = None, + prompt_cu_seqlens: Optional[torch.Tensor] = None, + negative_prompt_cu_seqlens: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. 
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to avoid during video generation. If not defined, pass
+                `negative_prompt_embeds_qwen` and `negative_prompt_embeds_clip` instead. Ignored when not using
+                guidance (`guidance_scale <= 1`).
+            height (`int`, defaults to `512`):
+                The height in pixels of the generated video.
+            width (`int`, defaults to `768`):
+                The width in pixels of the generated video.
+            num_frames (`int`, defaults to `121`):
+                The number of frames in the generated video. `num_frames - 1` should be divisible by the temporal
+                VAE scale factor; otherwise it is rounded to the nearest valid value.
+            num_inference_steps (`int`, defaults to `50`):
+                The number of denoising steps.
+            guidance_scale (`float`, defaults to `5.0`):
+                Guidance scale as defined in classifier-free guidance. Higher values encourage generation that is
+                more closely aligned with the prompt; guidance is only applied when `guidance_scale > 1`.
+            num_videos_per_prompt (`int`, *optional*, defaults to `1`):
+                The number of videos to generate per prompt.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                A torch generator to make generation deterministic.
+            latents (`torch.Tensor`, *optional*):
+                Pre-generated noisy latents. If not provided, latents are sampled from a random normal distribution.
+            prompt_embeds_qwen (`torch.Tensor`, *optional*):
+                Pre-generated Qwen text embeddings. Must be provided together with `prompt_embeds_clip` and
+                `prompt_cu_seqlens`.
+            prompt_embeds_clip (`torch.Tensor`, *optional*):
+                Pre-generated pooled CLIP text embeddings.
+            negative_prompt_embeds_qwen (`torch.Tensor`, *optional*):
+                Pre-generated negative Qwen text embeddings. Must be provided together with
+                `negative_prompt_embeds_clip` and `negative_prompt_cu_seqlens`.
+            negative_prompt_embeds_clip (`torch.Tensor`, *optional*):
+                Pre-generated negative pooled CLIP text embeddings.
+            prompt_cu_seqlens (`torch.Tensor`, *optional*):
+                Pre-computed cumulative sequence lengths for the Qwen prompt embeddings.
+            negative_prompt_cu_seqlens (`torch.Tensor`, *optional*):
+                Pre-computed cumulative sequence lengths for the Qwen negative prompt embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated video.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`KandinskyPipelineOutput`] instead of a plain tuple.
+            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
+                A function or callback that is called at the end of each denoising step.
+            callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
+                The list of tensor inputs passed to the `callback_on_step_end` function.
+            max_sequence_length (`int`, defaults to `512`):
+                The maximum sequence length for text encoding.
+
+        Examples:
+
+        Returns:
+            [`KandinskyPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is
+                returned where the first element is a list with the generated video frames.
+        """
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height,
+            width=width,
+            prompt_embeds_qwen=prompt_embeds_qwen,
+            prompt_embeds_clip=prompt_embeds_clip,
+            negative_prompt_embeds_qwen=negative_prompt_embeds_qwen,
+            negative_prompt_embeds_clip=negative_prompt_embeds_clip,
+            prompt_cu_seqlens=prompt_cu_seqlens,
+            negative_prompt_cu_seqlens=negative_prompt_cu_seqlens,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+        )
+
+        if num_frames % self.vae_scale_factor_temporal != 1:
+            logger.warning(
+                f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+            )
+            num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+        num_frames = max(num_frames, 1)
+
+        self._guidance_scale = guidance_scale
+        self._interrupt = False
+
+        device = self._execution_device
+        dtype = self.transformer.dtype
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+            prompt = [prompt]
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds_qwen.shape[0]
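+        # Text conditioning below is two-fold: Qwen2.5-VL supplies per-token embeddings together with
+        # cumulative sequence lengths that mark the boundary of each variable-length prompt, while CLIP
+        # supplies one pooled vector per prompt that conditions the transformer globally.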
+        # 3. Encode input prompt
+        if prompt_embeds_qwen is None:
+            prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt(
+                prompt=prompt,
+                num_videos_per_prompt=num_videos_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+                dtype=dtype,
+            )
+
+        if self.do_classifier_free_guidance:
+            if negative_prompt is None:
+                negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards"
+
+            if isinstance(negative_prompt, str):
+                negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt]
+            elif len(negative_prompt) != len(prompt):
+                raise ValueError(
+                    f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}."
+                )
+
+            if negative_prompt_embeds_qwen is None:
+                negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = self.encode_prompt(
+                    prompt=negative_prompt,
+                    num_videos_per_prompt=num_videos_per_prompt,
+                    max_sequence_length=max_sequence_length,
+                    device=device,
+                    dtype=dtype,
+                )
+
+        # 4. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 5. Prepare latent variables
+        num_channels_latents = self.transformer.config.in_visual_dim
+        latents = self.prepare_latents(
+            batch_size * num_videos_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            num_frames,
+            dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # 6. Prepare rope positions for positional encoding
+        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+        visual_rope_pos = [
+            torch.arange(num_latent_frames, device=device),
+            torch.arange(height // self.vae_scale_factor_spatial // 2, device=device),
+            torch.arange(width // self.vae_scale_factor_spatial // 2, device=device),
+        ]
+
+        text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device)
+
+        negative_text_rope_pos = (
+            torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device)
+            if negative_prompt_cu_seqlens is not None
+            else None
+        )
+
+        # 7. Sparse Params for efficient attention
+        sparse_params = self.get_sparse_params(latents, device)
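+        # The transformer predicts a velocity (`pred_velocity`) rather than noise; the scheduler integrates
+        # it step by step toward clean latents. With classifier-free guidance the conditional and
+        # unconditional predictions are combined as
+        #   v = v_uncond + guidance_scale * (v_cond - v_uncond)
+        # so e.g. guidance_scale=5.0 extrapolates 5x along the direction separating the two predictions.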
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        self._num_timesteps = len(timesteps)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt)
+
+                # Predict the conditional velocity
+                pred_velocity = self.transformer(
+                    hidden_states=latents.to(dtype),
+                    encoder_hidden_states=prompt_embeds_qwen.to(dtype),
+                    pooled_projections=prompt_embeds_clip.to(dtype),
+                    timestep=timestep.to(dtype),
+                    visual_rope_pos=visual_rope_pos,
+                    text_rope_pos=text_rope_pos,
+                    scale_factor=(1, 2, 2),
+                    sparse_params=sparse_params,
+                    return_dict=True,
+                ).sample
+
+                if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None:
+                    uncond_pred_velocity = self.transformer(
+                        hidden_states=latents.to(dtype),
+                        encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype),
+                        pooled_projections=negative_prompt_embeds_clip.to(dtype),
+                        timestep=timestep.to(dtype),
+                        visual_rope_pos=visual_rope_pos,
+                        text_rope_pos=negative_text_rope_pos,
+                        scale_factor=(1, 2, 2),
+                        sparse_params=sparse_params,
+                        return_dict=True,
+                    ).sample
+
+                    pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity)
+
+                # Compute the previous sample using the scheduler
+                latents[:, :, :, :, :num_channels_latents] = self.scheduler.step(
+                    pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False
+                )[0]
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen)
+                    prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip)
+                    negative_prompt_embeds_qwen = callback_outputs.pop(
+                        "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen
+                    )
+                    negative_prompt_embeds_clip = callback_outputs.pop(
+                        "negative_prompt_embeds_clip", negative_prompt_embeds_clip
+                    )
+
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        # 9. Post-processing - extract the main latents
+        latents = latents[:, :, :, :, :num_channels_latents]
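+        # Latents are kept channels-last throughout the loop:
+        # [batch, frames, height // vae_scale_factor_spatial, width // vae_scale_factor_spatial, channels].
+        # The slice above drops the visual-conditioning channels appended in `prepare_latents` when
+        # `self.transformer.visual_cond` is set; decoding below permutes to the channels-first layout
+        # expected by the VAE.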
+        # 10. Decode latents to video
+        if output_type != "latent":
+            latents = latents.to(self.vae.dtype)
+            # Reshape the latents into the [batch, channels, frames, height, width] layout the VAE expects
+            video = latents.reshape(
+                batch_size,
+                num_videos_per_prompt,
+                (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+                height // self.vae_scale_factor_spatial,
+                width // self.vae_scale_factor_spatial,
+                num_channels_latents,
+            )
+            video = video.permute(0, 1, 5, 2, 3, 4)  # [batch, num_videos, channels, frames, height, width]
+            video = video.reshape(
+                batch_size * num_videos_per_prompt,
+                num_channels_latents,
+                (num_frames - 1) // self.vae_scale_factor_temporal + 1,
+                height // self.vae_scale_factor_spatial,
+                width // self.vae_scale_factor_spatial,
+            )
+
+            # Normalize and decode through the VAE
+            video = video / self.vae.config.scaling_factor
+            video = self.vae.decode(video).sample
+            video = self.video_processor.postprocess_video(video, output_type=output_type)
+        else:
+            video = latents
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (video,)
+
+        return KandinskyPipelineOutput(frames=video)
diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_output.py b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
new file mode 100644
index 000000000000..ed77d42a9a83
--- /dev/null
+++ b/src/diffusers/pipelines/kandinsky5/pipeline_output.py
@@ -0,0 +1,20 @@
+from dataclasses import dataclass
+
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class KandinskyPipelineOutput(BaseOutput):
+    r"""
+    Output class for Kandinsky video pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or `List[List[PIL.Image.Image]]`):
+            List of video outputs - it can be a nested list of length `batch_size`, with each sub-list containing
+            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+ """ + + frames: torch.Tensor diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6e7d22797902..5d62709c28fd 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -918,6 +918,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Kandinsky5Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LatteTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9ed625045261..3244ef12ef87 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1247,6 +1247,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Kandinsky5T2VPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class KandinskyCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"]
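For reference, a minimal usage sketch of the new pipeline as wired up in this diff. The repository id below is a placeholder (the published Kandinsky 5 checkpoint name is not part of this change), and the argument values simply mirror the defaults in `__call__`; `export_to_video` is the existing helper in `diffusers.utils`.

```python
import torch

from diffusers import Kandinsky5T2VPipeline
from diffusers.utils import export_to_video

# Placeholder repository id -- substitute the actual Kandinsky 5 T2V checkpoint.
pipe = Kandinsky5T2VPipeline.from_pretrained(
    "your-org/kandinsky-5-t2v", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

output = pipe(
    prompt="A red panda rolling down a grassy hill, cinematic lighting",
    negative_prompt="low quality, deformed",
    height=512,
    width=768,
    num_frames=121,  # num_frames - 1 should be divisible by the temporal VAE factor
    num_inference_steps=50,
    guidance_scale=5.0,
)

# `frames` is a nested list when output_type="pil"; take the first (and only) video.
video = output.frames[0]
export_to_video(video, "kandinsky5_sample.mp4", fps=24)  # fps chosen arbitrarily here
```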