From d53f848720a03423bb9998e75a30b4c3cd04e96d Mon Sep 17 00:00:00 2001 From: leffff Date: Sat, 4 Oct 2025 10:10:23 +0000 Subject: [PATCH 01/70] add transformer pipeline first version --- src/diffusers/__init__.py | 4 + src/diffusers/loaders/__init__.py | 2 + src/diffusers/loaders/lora_pipeline.py | 288 +++++++- src/diffusers/models/__init__.py | 2 + src/diffusers/models/transformers/__init__.py | 1 + .../transformers/transformer_kandinsky.py | 630 ++++++++++++++++++ src/diffusers/pipelines/__init__.py | 2 + .../pipelines/kandinsky5/__init__.py | 48 ++ .../kandinsky5/pipeline_kandinsky.py | 545 +++++++++++++++ .../pipelines/kandinsky5/pipeline_output.py | 20 + 10 files changed, 1541 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/models/transformers/transformer_kandinsky.py create mode 100644 src/diffusers/pipelines/kandinsky5/__init__.py create mode 100644 src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py create mode 100644 src/diffusers/pipelines/kandinsky5/pipeline_output.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8867250deda8..19670053a3c5 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -260,6 +260,7 @@ "VQModel", "WanTransformer3DModel", "WanVACETransformer3DModel", + "Kandinsky5Transformer3DModel", "attention_backend", ] ) @@ -618,6 +619,7 @@ "WanPipeline", "WanVACEPipeline", "WanVideoToVideoPipeline", + "Kandinsky5T2VPipeline", "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", @@ -947,6 +949,7 @@ VQModel, WanTransformer3DModel, WanVACETransformer3DModel, + Kandinsky5Transformer3DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1275,6 +1278,7 @@ WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline, + Kandinsky5T2VPipeline, WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 742548653800..6a48ac1b0deb 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder): "SanaLoraLoaderMixin", "Lumina2LoraLoaderMixin", "WanLoraLoaderMixin", + "KandinskyLoraLoaderMixin", "HiDreamImageLoraLoaderMixin", "SkyReelsV2LoraLoaderMixin", "QwenImageLoraLoaderMixin", @@ -126,6 +127,7 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, + KandinskyLoraLoaderMixin ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index e25a29e1c00e..ea1b92c68b59 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3638,6 +3638,292 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) + +class KandinskyLoraLoaderMixin(LoraBaseMixin): + r""" + Load LoRA layers into [`Kandinsky5Transformer3DModel`], + """ + + _lora_loadable_modules = ["transformer"] + transformer_name = TRANSFORMER_NAME + + @classmethod + @validate_hf_hub_args + def lora_state_dict( + cls, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + **kwargs, + ): + r""" + Return state dict for lora weights and the network alphas. 
+ + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + Can be either: + - A string, the *model id* of a pretrained model hosted on the Hub. + - A path to a *directory* containing the model weights. + - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory where a downloaded pretrained model configuration is cached. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint. + local_files_only (`bool`, *optional*, defaults to `False`): + Whether to only load local model weights and configuration files. + token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. + subfolder (`str`, *optional*, defaults to `""`): + The subfolder location of a model file within a larger model repository. + weight_name (`str`, *optional*, defaults to None): + Name of the serialized state dict file. + use_safetensors (`bool`, *optional*): + Whether to use safetensors for loading. + return_lora_metadata (`bool`, *optional*, defaults to False): + When enabled, additionally return the LoRA adapter metadata. + """ + # Load the main state dict first which has the LoRA layers + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + weight_name = kwargs.pop("weight_name", None) + use_safetensors = kwargs.pop("use_safetensors", None) + return_lora_metadata = kwargs.pop("return_lora_metadata", False) + + allow_pickle = False + if use_safetensors is None: + use_safetensors = True + allow_pickle = True + + user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} + + state_dict, metadata = _fetch_state_dict( + pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, + weight_name=weight_name, + use_safetensors=use_safetensors, + local_files_only=local_files_only, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + allow_pickle=allow_pickle, + ) + + is_dora_scale_present = any("dora_scale" in k for k in state_dict) + if is_dora_scale_present: + warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." 
+ logger.warning(warn_msg) + state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} + + out = (state_dict, metadata) if return_lora_metadata else state_dict + return out + + def load_lora_weights( + self, + pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], + adapter_name: Optional[str] = None, + hotswap: bool = False, + **kwargs, + ): + """ + Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` + + Parameters: + pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + hotswap (`bool`, *optional*): + Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + kwargs (`dict`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. + """ + if not USE_PEFT_BACKEND: + raise ValueError("PEFT backend is required for this method.") + + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # if a dict is passed, copy it instead of modifying it inplace + if isinstance(pretrained_model_name_or_path_or_dict, dict): + pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() + + # First, ensure that the checkpoint is a compatible one and can be successfully loaded. + kwargs["return_lora_metadata"] = True + state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) + + is_correct_format = all("lora" in key for key in state_dict.keys()) + if not is_correct_format: + raise ValueError("Invalid LoRA checkpoint.") + + # Load LoRA into transformer + self.load_lora_into_transformer( + state_dict, + transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=self, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + @classmethod + def load_lora_into_transformer( + cls, + state_dict, + transformer, + adapter_name=None, + _pipeline=None, + low_cpu_mem_usage=False, + hotswap: bool = False, + metadata=None, + ): + """ + Load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. + transformer (`Kandinsky5Transformer3DModel`): + The transformer model to load the LoRA layers into. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. + low_cpu_mem_usage (`bool`, *optional*): + Speed up model loading by only loading the pretrained LoRA weights. + hotswap (`bool`, *optional*): + See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`]. + metadata (`dict`): + Optional LoRA adapter metadata. + """ + if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): + raise ValueError( + "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." + ) + + # Load the layers corresponding to transformer. 
+ logger.info(f"Loading {cls.transformer_name}.") + transformer.load_lora_adapter( + state_dict, + network_alphas=None, + adapter_name=adapter_name, + metadata=metadata, + _pipeline=_pipeline, + low_cpu_mem_usage=low_cpu_mem_usage, + hotswap=hotswap, + ) + + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + transformer_lora_adapter_metadata=None, + ): + r""" + Save the LoRA parameters corresponding to the transformer and text encoders. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to save LoRA parameters to. + transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): + State dict of the LoRA layers corresponding to the `transformer`. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process. + save_function (`Callable`): + The function to use to save the state dictionary. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way. + transformer_lora_adapter_metadata: + LoRA adapter metadata associated with the transformer. + """ + lora_layers = {} + lora_metadata = {} + + if transformer_lora_layers: + lora_layers[cls.transformer_name] = transformer_lora_layers + lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata + + if not lora_layers: + raise ValueError( + "You must pass at least one of `transformer_lora_layers`" + ) + + cls._save_lora_weights( + save_directory=save_directory, + lora_layers=lora_layers, + lora_metadata=lora_metadata, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def fuse_lora( + self, + components: List[str] = ["transformer"], + lora_scale: float = 1.0, + safe_fusing: bool = False, + adapter_names: Optional[List[str]] = None, + **kwargs, + ): + r""" + Fuses the LoRA parameters into the original parameters of the corresponding blocks. + + Args: + components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into. + lora_scale (`float`, defaults to 1.0): + Controls how much to influence the outputs with the LoRA parameters. + safe_fusing (`bool`, defaults to `False`): + Whether to check fused weights for NaN values before fusing. + adapter_names (`List[str]`, *optional*): + Adapter names to be used for fusing. + + Example: + ```py + from diffusers import Kandinsky5T2VPipeline + + pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + pipeline.load_lora_weights("path/to/lora.safetensors") + pipeline.fuse_lora(lora_scale=0.7) + ``` + """ + super().fuse_lora( + components=components, + lora_scale=lora_scale, + safe_fusing=safe_fusing, + adapter_names=adapter_names, + **kwargs, + ) + + def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): + r""" + Reverses the effect of [`pipe.fuse_lora()`]. + + Args: + components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. 
+ """ + super().unfuse_lora(components=components, **kwargs) + class WanLoraLoaderMixin(LoraBaseMixin): r""" @@ -4802,4 +5088,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 457f70448af3..89ca9d39774b 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -101,6 +101,7 @@ _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] + _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -200,6 +201,7 @@ TransformerTemporalModel, WanTransformer3DModel, WanVACETransformer3DModel, + Kandinsky5Transformer3DModel, ) from .unets import ( I2VGenXLUNet, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index b60f0636e6dc..4b9911f9cb5d 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -37,3 +37,4 @@ from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel + from .transformer_kandinsky import Kandinsky5Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py new file mode 100644 index 000000000000..a057cc13cc0f --- /dev/null +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -0,0 +1,630 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
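+
+# This module implements the Kandinsky 5.0 video DiT: a small text-refinement stack
+# (`TransformerEncoderBlock`) followed by a visual stack (`TransformerDecoderBlock`)
+# combining self-attention, text cross-attention and feed-forward sub-layers, each
+# modulated by zero-initialized, time-conditioned shift/scale/gate parameters
+# (`Modulation`). Rotary embeddings are built as explicit 2x2 rotation matrices and
+# applied to queries/keys in float32 (`apply_rotary`).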
+ +import math +from typing import Any, Dict, Optional, Tuple, Union, List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils.torch_utils import maybe_allow_in_graph +from .._modeling_parallel import ContextParallelInput, ContextParallelOutput +from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +from ..attention_dispatch import dispatch_attention_fn +from ..cache_utils import CacheMixin +from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import FP32LayerNorm + +logger = logging.get_logger(__name__) + + +# @torch.compile() +@torch.autocast(device_type="cuda", dtype=torch.float32) +def apply_scale_shift_norm(norm, x, scale, shift): + return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) + +# @torch.compile() +@torch.autocast(device_type="cuda", dtype=torch.float32) +def apply_gate_sum(x, out, gate): + return (x + gate * out).to(torch.bfloat16) + +# @torch.compile() +@torch.autocast(device_type="cuda", enabled=False) +def apply_rotary(x, rope): + x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) + x_out = (rope * x_).sum(dim=-1) + return x_out.reshape(*x.shape).to(torch.bfloat16) + + +@torch.autocast(device_type="cuda", enabled=False) +def get_freqs(dim, max_period=10000.0): + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=dim, dtype=torch.float32) + / dim + ) + return freqs + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer( + "freqs", get_freqs(model_dim // 2, max_period), persistent=False + ) + self.in_layer = nn.Linear(model_dim, time_dim, bias=True) + self.activation = nn.SiLU() + self.out_layer = nn.Linear(time_dim, time_dim, bias=True) + + def forward(self, time): + args = torch.outer(time, self.freqs.to(device=time.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim): + super().__init__() + self.in_layer = nn.Linear(text_dim, model_dim, bias=True) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=True) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) + + def forward(self, x): + batch_size, duration, height, width, dim = x.shape + x = ( + x.view( + batch_size, + duration // self.patch_size[0], + self.patch_size[0], + height // self.patch_size[1], + self.patch_size[1], + width // self.patch_size[2], + self.patch_size[2], + dim, + ) + .permute(0, 1, 3, 5, 2, 4, 6, 7) + .flatten(4, 7) + ) + return self.in_layer(x) + + +class RoPE1D(nn.Module): + """ + 1D Rotary Positional Embeddings for text sequences. 
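+    Rotation angles for positions up to `max_pos` are precomputed and cached in a
+    non-persistent buffer; `forward` turns them into 2x2 rotation matrices.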
+ + Args: + dim: Dimension of the rotary embeddings + max_pos: Maximum sequence length + max_period: Maximum period for sinusoidal embeddings + """ + + def __init__(self, dim, max_pos=1024, max_period=10000.0): + super().__init__() + self.max_period = max_period + self.dim = dim + self.max_pos = max_pos + freq = get_freqs(dim // 2, max_period) + pos = torch.arange(max_pos, dtype=freq.dtype) + self.register_buffer("args", torch.outer(pos, freq), persistent=False) + + def forward(self, pos): + """ + Args: + pos: Position indices of shape [seq_len] or [batch_size, seq_len] + + Returns: + Rotary embeddings of shape [seq_len, 1, 2, 2] + """ + args = self.args[pos] + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + + +class RoPE3D(nn.Module): + def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): + super().__init__() + self.axes_dims = axes_dims + self.max_pos = max_pos + self.max_period = max_period + + for i, (axes_dim, ax_max_pos) in enumerate(zip(axes_dims, max_pos)): + freq = get_freqs(axes_dim // 2, max_period) + pos = torch.arange(ax_max_pos, dtype=freq.dtype) + self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False) + + @torch.autocast(device_type="cuda", enabled=False) + def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): + batch_size, duration, height, width = shape + args_t = self.args_0[pos[0]] / scale_factor[0] + args_h = self.args_1[pos[1]] / scale_factor[1] + args_w = self.args_2[pos[2]] / scale_factor[2] + + # Replicate the original logic with batch dimension + args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) + args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) + args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) + + # Concatenate along the last dimension + args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F] + + cosine = torch.cos(args) + sine = torch.sin(args) + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4] + rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2] + return rope.unsqueeze(-4) # [B, D, H, 1, W, F, 2, 2] + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = nn.Linear(time_dim, num_params * model_dim) + self.out_layer.weight.data.zero_() + self.out_layer.bias.data.zero_() + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class MultiheadSelfAttentionEnc(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, rope): + query = self.to_query(x) + key = self.to_key(x) + value = self.to_value(x) + + shape = query.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*shape, self.num_heads, -1) + value = value.reshape(*shape, self.num_heads, -1) + + 
query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + query = apply_rotary(query, rope).type_as(query) + key = apply_rotary(key, rope).type_as(key) + + # Use torch's scaled_dot_product_attention + out = F.scaled_dot_product_attention( + query, + key, + value, + ).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class MultiheadSelfAttentionDec(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, rope, sparse_params=None): + query = self.to_query(x) + key = self.to_key(x) + value = self.to_value(x) + + shape = query.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*shape, self.num_heads, -1) + value = value.reshape(*shape, self.num_heads, -1) + + query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + query = apply_rotary(query, rope).type_as(query) + key = apply_rotary(key, rope).type_as(key) + + # Use standard attention (can be extended with sparse attention) + out = F.scaled_dot_product_attention( + query, + key, + value, + ).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class MultiheadCrossAttention(nn.Module): + def __init__(self, num_channels, head_dim): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + + self.to_query = nn.Linear(num_channels, num_channels, bias=True) + self.to_key = nn.Linear(num_channels, num_channels, bias=True) + self.to_value = nn.Linear(num_channels, num_channels, bias=True) + self.query_norm = nn.RMSNorm(head_dim) + self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + def forward(self, x, cond): + query = self.to_query(x) + key = self.to_key(cond) + value = self.to_value(cond) + + shape, cond_shape = query.shape[:-1], key.shape[:-1] + query = query.reshape(*shape, self.num_heads, -1) + key = key.reshape(*cond_shape, self.num_heads, -1) + value = value.reshape(*cond_shape, self.num_heads, -1) + + query = self.query_norm(query.float()).type_as(query) + key = self.key_norm(key.float()).type_as(key) + + out = F.scaled_dot_product_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + ).permute(0, 2, 1, 3).flatten(-2, -1) + + out = self.out_layer(out) + return out + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim): + super().__init__() + self.in_layer = nn.Linear(dim, ff_dim, bias=False) + self.activation = nn.GELU() + self.out_layer = nn.Linear(ff_dim, dim, bias=False) + + def forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim) + + self.feed_forward_norm = nn.LayerNorm(model_dim, 
elementwise_affine=False) + self.feed_forward = FeedForward(model_dim, ff_dim) + + def forward(self, x, time_embed, rope): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + + out = self.self_attention_norm(x) + out = out * (scale + 1.0) + shift + out = self.self_attention(out, rope) + x = x + gate * out + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + out = self.feed_forward_norm(x) + out = out * (scale + 1.0) + shift + out = self.feed_forward(out) + x = x + gate * out + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9) + + self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim) + + self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.cross_attention = MultiheadCrossAttention(model_dim, head_dim) + + self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.feed_forward = FeedForward(model_dim, ff_dim) + + def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): + self_attn_params, cross_attn_params, ff_params = torch.chunk( + self.visual_modulation(time_embed), 3, dim=-1 + ) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) + + visual_out = self.self_attention_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.self_attention(visual_out, rope, sparse_params) + visual_embed = visual_embed + gate * visual_out + + shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) + visual_out = self.cross_attention_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.cross_attention(visual_out, text_embed) + visual_embed = visual_embed + gate * visual_out + + shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) + visual_out = self.feed_forward_norm(visual_embed) + visual_out = visual_out * (scale + 1.0) + shift + visual_out = self.feed_forward(visual_out) + visual_embed = visual_embed + gate * visual_out + return visual_embed + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.out_layer = nn.Linear( + model_dim, math.prod(patch_size) * visual_dim, bias=True + ) + + def forward(self, visual_embed, text_embed, time_embed): + # Handle the new batch dimension: [batch, duration, height, width, model_dim] + batch_size, duration, height, width, _ = visual_embed.shape + + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + + # Apply modulation with proper broadcasting for the new shape + visual_embed = apply_scale_shift_norm( + self.norm, + visual_embed, + scale[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] + shift[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] + ).type_as(visual_embed) + + x = self.out_layer(visual_embed) + + # Now x has shape [batch, duration, height, width, patch_prod * visual_dim] + x = ( + x.view( + batch_size, + duration, + height, + width, + -1, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + .permute(0, 5, 1, 6, 2, 7, 3, 4) # [batch, patch_t, duration, 
patch_h, height, patch_w, width, features] + .flatten(1, 2) # [batch, patch_t * duration, height, patch_w, width, features] + .flatten(2, 3) # [batch, patch_t * duration, patch_h * height, width, features] + .flatten(3, 4) # [batch, patch_t * duration, patch_h * height, patch_w * width] + ) + return x + + +@maybe_allow_in_graph +class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): + r""" + A 3D Transformer model for video generation used in Kandinsky 5.0. + + This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods implemented for all models (such as downloading or saving). + + Args: + in_visual_dim (`int`, defaults to 16): + Number of channels in the input visual latent. + out_visual_dim (`int`, defaults to 16): + Number of channels in the output visual latent. + time_dim (`int`, defaults to 512): + Dimension of the time embeddings. + patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): + Patch size for the visual embeddings (temporal, height, width). + model_dim (`int`, defaults to 1792): + Hidden dimension of the transformer model. + ff_dim (`int`, defaults to 7168): + Intermediate dimension of the feed-forward networks. + num_text_blocks (`int`, defaults to 2): + Number of transformer blocks in the text encoder. + num_visual_blocks (`int`, defaults to 32): + Number of transformer blocks in the visual decoder. + axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`): + Dimensions for the rotary positional embeddings (temporal, height, width). + visual_cond (`bool`, defaults to `True`): + Whether to use visual conditioning (for image/video conditioning). + in_text_dim (`int`, defaults to 3584): + Dimension of the text embeddings from Qwen2.5-VL. + in_text_dim2 (`int`, defaults to 768): + Dimension of the pooled text embeddings from CLIP. + """ + + @register_to_config + def __init__( + self, + in_visual_dim: int = 16, + out_visual_dim: int = 16, + time_dim: int = 512, + patch_size: Tuple[int, int, int] = (1, 2, 2), + model_dim: int = 1792, + ff_dim: int = 7168, + num_text_blocks: int = 2, + num_visual_blocks: int = 32, + axes_dims: Tuple[int, int, int] = (16, 24, 24), + visual_cond: bool = True, + in_text_dim: int = 3584, + in_text_dim2: int = 768, + ): + super().__init__() + + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_cond = visual_cond + + # Calculate head dimension for attention + head_dim = sum(axes_dims) + + # Determine visual embedding dimension based on conditioning + visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim + + # 1. Embedding layers + self.time_embeddings = TimeEmbeddings(model_dim, time_dim) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + + # 2. Rotary positional embeddings + self.text_rope_embeddings = RoPE1D(head_dim) + self.visual_rope_embeddings = RoPE3D(axes_dims) + + # 3. Transformer blocks + self.text_transformer_blocks = nn.ModuleList([ + TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_text_blocks) + ]) + + self.visual_transformer_blocks = nn.ModuleList([ + TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_visual_blocks) + ]) + + # 4. 
Output layer + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + pooled_text_embed: torch.Tensor, + timestep: torch.Tensor, + visual_rope_pos: List[torch.Tensor], + text_rope_pos: torch.Tensor, + scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), + sparse_params: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + Forward pass of the Kandinsky 5.0 3D Transformer. + + Args: + hidden_states (`torch.Tensor`): + Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`. + encoder_hidden_states (`torch.Tensor`): + Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`. + pooled_text_embed (`torch.Tensor`): + Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`. + timestep (`torch.Tensor`): + Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`. + visual_rope_pos (`List[torch.Tensor]`): + List of tensors for visual rotary positional embeddings [temporal, height, width]. + text_rope_pos (`torch.Tensor`): + Tensor for text rotary positional embeddings. + scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`): + Scale factors for rotary positional embeddings. + sparse_params (`Dict[str, Any]`, *optional*): + Parameters for sparse attention. + return_dict (`bool`, defaults to `True`): + Whether to return a dictionary or a tensor. + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: + If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, + otherwise a `tuple` where the first element is the sample tensor. + """ + batch_size, num_frames, height, width, channels = hidden_states.shape + + # 1. Process text embeddings + text_embed = self.text_embeddings(encoder_hidden_states) + time_embed = self.time_embeddings(timestep) + + # Add pooled text embedding to time embedding + pooled_embed = self.pooled_text_embeddings(pooled_text_embed) + time_embed = time_embed + pooled_embed + + # visual_embed shape: [batch_size, seq_len, model_dim] + visual_embed = self.visual_embeddings(hidden_states) + + # 3. Text rotary embeddings + text_rope = self.text_rope_embeddings(text_rope_pos) + + # 4. Text transformer blocks + for text_block in self.text_transformer_blocks: + if self.gradient_checkpointing and self.training: + text_embed = torch.utils.checkpoint.checkpoint( + text_block, text_embed, time_embed, text_rope, use_reentrant=False + ) + else: + text_embed = text_block(text_embed, time_embed, text_rope) + + # 5. Prepare visual rope + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + + visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) + visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) + + # 6. Visual transformer blocks + for visual_block in self.visual_transformer_blocks: + if self.gradient_checkpointing and self.training: + visual_embed = torch.utils.checkpoint.checkpoint( + visual_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + # visual_rope_flat, + sparse_params, + use_reentrant=False, + ) + else: + visual_embed = visual_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + + # 7. 
Output projection + visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) + output = self.out_layer(visual_embed, text_embed, time_embed) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 190c7871d270..201d92afb07c 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -382,6 +382,7 @@ "WuerstchenPriorPipeline", ] _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline", "WanVACEPipeline"] + _import_structure["kandinsky5"] = ["Kandinsky5T2VPipeline"] _import_structure["skyreels_v2"] = [ "SkyReelsV2DiffusionForcingPipeline", "SkyReelsV2DiffusionForcingImageToVideoPipeline", @@ -787,6 +788,7 @@ ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline + from .kandinsky5 import Kandinsky5T2VPipeline from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py new file mode 100644 index 000000000000..af8e12421740 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_kandinsky import Kandinsky5T2VPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py new file mode 100644 index 000000000000..02eae1363303 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -0,0 +1,545 @@ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
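+
+# The pipeline below pairs a Qwen2.5-VL text encoder (per-token embeddings) with a
+# CLIP text encoder (pooled embeddings), uses the HunyuanVideo VAE to decode latents,
+# and integrates the flow-matching ODE with a shifted timestep schedule controlled by
+# `scheduler_scale`.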
+ +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +import torchvision +from torchvision.transforms import ToPILImage + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel + >>> from diffusers.utils import export_to_video + + >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=25, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=6) + ``` +""" + + +class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): + r""" + Pipeline for text-to-video generation using Kandinsky 5.0. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + transformer ([`Kandinsky5Transformer3DModel`]): + Conditional Transformer to denoise the encoded video latents. + vae ([`AutoencoderKLHunyuanVideo`]): + Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. + text_encoder ([`Qwen2_5_VLForConditionalGeneration`]): + Frozen text-encoder (Qwen2.5-VL). + tokenizer ([`AutoProcessor`]): + Tokenizer for Qwen2.5-VL. + text_encoder_2 ([`CLIPTextModel`]): + Frozen CLIP text encoder. + tokenizer_2 ([`CLIPTokenizer`]): + Tokenizer for CLIP. 
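+        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
+            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.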
+ """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + transformer: Kandinsky5Transformer3DModel, + vae: AutoencoderKLHunyuanVideo, + text_encoder: Qwen2_5_VLForConditionalGeneration, + tokenizer: Qwen2VLProcessor, + text_encoder_2: CLIPTextModel, + tokenizer_2: CLIPTokenizer, + scheduler: FlowMatchEulerDiscreteScheduler, + ): + super().__init__() + + self.register_modules( + transformer=transformer, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + text_encoder_2=text_encoder_2, + tokenizer_2=tokenizer_2, + scheduler=scheduler, + ) + + self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio + self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + + def _encode_prompt_qwen( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + max_sequence_length: int = 256, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + # Kandinsky specific prompt template + prompt_template = "\n".join([ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ]) + crop_start = 129 + + full_texts = [prompt_template.format(p) for p in prompt] + + inputs = self.tokenizer( + text=full_texts, + images=None, + videos=None, + max_length=max_sequence_length + crop_start, + truncation=True, + return_tensors="pt", + padding=True, + ).to(device) + + with torch.no_grad(): + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, crop_start:] + + attention_mask = inputs["attention_mask"][:, crop_start:] + embeds = embeds[attention_mask.bool()] + cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) + cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + + # duplicate for each generation per prompt + batch_size = len(prompt) + seq_len = embeds.shape[0] // batch_size + embeds = embeds.reshape(batch_size, seq_len, -1) + embeds = embeds.repeat(1, num_videos_per_prompt, 1) + embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + return embeds, cu_seqlens + + def _encode_prompt_clip( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + + inputs = self.tokenizer_2( + prompt, + max_length=77, + truncation=True, + add_special_tokens=True, + padding="max_length", + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + + # duplicate for each generation per prompt + batch_size = len(prompt) + pooled_embed = 
pooled_embed.repeat(1, num_videos_per_prompt, 1) + pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) + + return pooled_embed + + def encode_prompt( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + do_classifier_free_guidance: bool = True, + num_videos_per_prompt: int = 1, + device: Optional[torch.device] = None, + ): + device = device or self._execution_device + + # Encode with Qwen2.5-VL + prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt, device, num_videos_per_prompt + ) + pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) + + if do_classifier_free_guidance: + negative_prompt = negative_prompt or "" + negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen( + negative_prompt, device, num_videos_per_prompt + ) + negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) + else: + negative_prompt_embeds = None + negative_pooled_embed = None + negative_cu_seqlens = None + + text_embeds = { + "text_embeds": prompt_embeds, + "pooled_embed": pooled_embed, + } + negative_text_embeds = { + "text_embeds": negative_prompt_embeds, + "pooled_embed": negative_pooled_embed, + } if do_classifier_free_guidance else None + + return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens + + def prepare_latents( + self, + batch_size: int, + num_channels_latents: int = 16, + height: int = 480, + width: int = 832, + num_frames: int = 81, + visual_cond: bool = False, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if latents is not None: + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if visual_cond: + # For visual conditioning, concatenate with zeros and mask + visual_cond = torch.zeros_like(latents) + visual_cond_mask = torch.zeros( + [batch_size, num_latent_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, 1], + dtype=latents.dtype, + device=latents.device + ) + latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) + + return latents + + def get_velocity( + self, + latents: torch.Tensor, + timestep: torch.Tensor, + text_embeds: Dict[str, torch.Tensor], + negative_text_embeds: Optional[Dict[str, torch.Tensor]], + visual_rope_pos: List[torch.Tensor], + text_rope_pos: torch.Tensor, + negative_text_rope_pos: torch.Tensor, + guidance_scale: float, + sparse_params: Optional[Dict] = None, + ): + # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params) + + pred_velocity = self.transformer( + latents, + text_embeds["text_embeds"], + text_embeds["pooled_embed"], + timestep * 1000, # Scale to match training + visual_rope_pos, + text_rope_pos, + scale_factor=(1, 2, 2), # From Kandinsky config + sparse_params=sparse_params, + return_dict=False + )[0] + + if guidance_scale > 1.0 and negative_text_embeds is not None: + uncond_pred_velocity = self.transformer( + latents, + negative_text_embeds["text_embeds"], + negative_text_embeds["pooled_embed"], + timestep * 1000, + visual_rope_pos, + negative_text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=sparse_params, + return_dict=False + )[0] + + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) + + return pred_velocity + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]], + negative_prompt: Optional[Union[str, List[str]]] = None, + height: int = 512, + width: int = 768, + num_frames: int = 25, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + scheduler_scale: float = 10.0, + num_videos_per_prompt: int = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the video generation. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to avoid during video generation. + height (`int`, defaults to `512`): + The height in pixels of the generated video. + width (`int`, defaults to `768`): + The width in pixels of the generated video. + num_frames (`int`, defaults to `25`): + The number of frames in the generated video. + num_inference_steps (`int`, defaults to `50`): + The number of denoising steps. + guidance_scale (`float`, defaults to `5.0`): + Guidance scale as defined in classifier-free guidance. + scheduler_scale (`float`, defaults to `10.0`): + Scale factor for the custom flow matching scheduler. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + generator (`torch.Generator`, *optional*): + A torch generator to make generation deterministic. 
+ latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated video. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`KandinskyPipelineOutput`]. + callback_on_step_end (`Callable`, *optional*): + A function that is called at the end of each denoising step. + + Examples: + + Returns: + [`~KandinskyPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where + the first element is a list with the generated images and the second element is a list of `bool`s + indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + """ + + # 1. Check inputs + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + # 2. Define call parameters + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + + device = self._execution_device + do_classifier_free_guidance = guidance_scale > 1.0 + + + if num_frames % self.vae_scale_factor_temporal != 1: + logger.warning( + f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." + ) + num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 + num_frames = max(num_frames, 1) + + + # 3. Encode input prompt + text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + num_videos_per_prompt=num_videos_per_prompt, + device=device, + ) + + # 4. Prepare timesteps (Kandinsky uses custom flow matching) + timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device) + timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps) + + # 5. Prepare latent variables + num_channels_latents = 16 + latents = self.prepare_latents( + batch_size=batch_size * num_videos_per_prompt, + num_channels_latents=16, + height=height, + width=width, + num_frames=num_frames, + visual_cond=self.transformer.visual_cond, + dtype=self.transformer.dtype, + device=device, + generator=generator, + latents=latents, + ) + + # 6. Prepare rope positions + visual_rope_pos = [ + torch.arange(num_frames // 4 + 1, device=device), + torch.arange(height // 8 // 2, device=device), # patch size 2 + torch.arange(width // 8 // 2, device=device), + ] + + text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device) + + negative_text_rope_pos = ( + torch.arange(negative_cu_seqlens[-1].item(), device=device) + if negative_cu_seqlens is not None + else None + ) + + # 7. Prepare sparse attention params if needed + sparse_params = None # Can be extended based on Kandinsky attention config + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))): + # Expand timestep to match batch size + time = timestep.unsqueeze(0) + + pred_velocity = self.get_velocity( + latents, + time, + text_embeds, + negative_text_embeds, + visual_rope_pos, + text_rope_pos, + negative_text_rope_pos, + guidance_scale, + sparse_params, + ) + + # Update latents using flow matching + latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + latents = callback_outputs.pop("latents", latents) + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0): + progress_bar.update() + + latents = latents[:, :, :, :, :16] + + # 9. Decode latents to video + if output_type != "latent": + latents = latents.to(self.vae.dtype) + # Reshape and normalize latents + video = latents.reshape( + batch_size, + num_videos_per_prompt, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // 8, + width // 8, + 16, + ) + video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] + video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8) + + # Normalize and decode + video = video / self.vae.config.scaling_factor + video = self.vae.decode(video).sample + video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) + + # Convert to output format + if output_type == "pil": + if num_frames == 1: + # Single image + video = [ToPILImage()(frame.squeeze(1)) for frame in video] + else: + # Video frames + video = [video[i] for i in range(video.shape[0])] + else: + video = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video,) + + return KandinskyPipelineOutput(frames=video) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_output.py b/src/diffusers/pipelines/kandinsky5/pipeline_output.py new file mode 100644 index 000000000000..ed77d42a9a83 --- /dev/null +++ b/src/diffusers/pipelines/kandinsky5/pipeline_output.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +import torch + +from diffusers.utils import BaseOutput + + +@dataclass +class KandinskyPipelineOutput(BaseOutput): + r""" + Output class for Wan pipelines. + + Args: + frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]): + List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing + denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape + `(batch_size, num_frames, channels, height, width)`. 
+ """ + + frames: torch.Tensor From 7db6093c539b84450bbc683193b75c91cfc599e3 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 6 Oct 2025 12:43:04 +0000 Subject: [PATCH 02/70] updates --- .../transformers/transformer_kandinsky.py | 125 ++++++++----- .../kandinsky5/pipeline_kandinsky.py | 171 +++++++----------- 2 files changed, 144 insertions(+), 152 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index a057cc13cc0f..cca83988a762 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -35,6 +35,23 @@ logger = logging.get_logger(__name__) +if torch.cuda.get_device_capability()[0] >= 9: + try: + from flash_attn_interface import flash_attn_func as FA + except: + FA = None + + try: + from flash_attn import flash_attn_func as FA + except: + FA = None +else: + try: + from flash_attn import flash_attn_func as FA + except: + FA = None + + # @torch.compile() @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_scale_shift_norm(norm, x, scale, shift): @@ -99,7 +116,7 @@ def __init__(self, visual_dim, model_dim, patch_size): super().__init__() self.patch_size = patch_size self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) - + def forward(self, x): batch_size, duration, height, width, dim = x.shape x = ( @@ -107,7 +124,7 @@ def forward(self, x): batch_size, duration // self.patch_size[0], self.patch_size[0], - height // self.patch_size[1], + height // self.patch_size[1], self.patch_size[1], width // self.patch_size[2], self.patch_size[2], @@ -169,24 +186,23 @@ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): @torch.autocast(device_type="cuda", enabled=False) def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): batch_size, duration, height, width = shape + args_t = self.args_0[pos[0]] / scale_factor[0] args_h = self.args_1[pos[1]] / scale_factor[1] args_w = self.args_2[pos[2]] / scale_factor[2] - # Replicate the original logic with batch dimension args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) - # Concatenate along the last dimension - args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) # [B, D, H, W, F] + args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) cosine = torch.cos(args) sine = torch.sin(args) - rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) # [B, D, H, W, F, 4] - rope = rope.view(*rope.shape[:-1], 2, 2) # [B, D, H, W, F, 2, 2] - return rope.unsqueeze(-4) # [B, D, H, 1, W, F, 2, 2] - + rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) + rope = rope.view(*rope.shape[:-1], 2, 2) + return rope.unsqueeze(-4) + class Modulation(nn.Module): def __init__(self, time_dim, model_dim, num_params): @@ -230,11 +246,14 @@ def forward(self, x, rope): key = apply_rotary(key, rope).type_as(key) # Use torch's scaled_dot_product_attention - out = F.scaled_dot_product_attention( - query, - key, - value, - ).flatten(-2, -1) + # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE") + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 
3).flatten(-2, -1) + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -270,11 +289,15 @@ def forward(self, x, rope, sparse_params=None): key = apply_rotary(key, rope).type_as(key) # Use standard attention (can be extended with sparse attention) - out = F.scaled_dot_product_attention( - query, - key, - value, - ).flatten(-2, -1) + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 3).flatten(-2, -1) + + # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE") + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -306,11 +329,15 @@ def forward(self, x, cond): query = self.query_norm(query.float()).type_as(query) key = self.key_norm(key.float()).type_as(key) - out = F.scaled_dot_product_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - ).permute(0, 2, 1, 3).flatten(-2, -1) + # out = F.scaled_dot_product_attention( + # query.permute(0, 2, 1, 3), + # key.permute(0, 2, 1, 3), + # value.permute(0, 2, 1, 3), + # ).permute(0, 2, 1, 3).flatten(-2, -1) + + # print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE") + + out = FA(q=query, k=key, v=value).flatten(-2, -1) out = self.out_layer(out) return out @@ -339,19 +366,18 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + self_attn_params, ff_params = torch.chunk( + self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - - out = self.self_attention_norm(x) - out = out * (scale + 1.0) + shift + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) - x = x + gate * out + x = apply_gate_sum(x, out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - out = self.feed_forward_norm(x) - out = out * (scale + 1.0) + shift + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) out = self.feed_forward(out) - x = x + gate * out + x = apply_gate_sum(x, out, gate) return x @@ -371,26 +397,22 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self_attn_params, cross_attn_params, ff_params = torch.chunk( - self.visual_modulation(time_embed), 3, dim=-1 + self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - - visual_out = self.self_attention_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) visual_out = self.self_attention(visual_out, rope, sparse_params) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = self.cross_attention_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) visual_out = self.cross_attention(visual_out, text_embed) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, 
visual_out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = self.feed_forward_norm(visual_embed) - visual_out = visual_out * (scale + 1.0) + shift + visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) visual_out = self.feed_forward(visual_out) - visual_embed = visual_embed + gate * visual_out + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) return visual_embed @@ -575,7 +597,7 @@ def forward( # 1. Process text embeddings text_embed = self.text_embeddings(encoder_hidden_states) time_embed = self.time_embeddings(timestep) - + # Add pooled text embedding to time embedding pooled_embed = self.pooled_text_embeddings(pooled_text_embed) time_embed = time_embed + pooled_embed @@ -587,22 +609,29 @@ def forward( text_rope = self.text_rope_embeddings(text_rope_pos) # 4. Text transformer blocks + i = 0 for text_block in self.text_transformer_blocks: if self.gradient_checkpointing and self.training: text_embed = torch.utils.checkpoint.checkpoint( text_block, text_embed, time_embed, text_rope, use_reentrant=False ) + else: text_embed = text_block(text_embed, time_embed, text_rope) + i += 1 + # 5. Prepare visual rope visual_shape = visual_embed.shape[:-1] visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + + # visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) + # visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) + visual_embed = visual_embed.flatten(1, 3) + visual_rope = visual_rope.flatten(1, 3) - visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) - visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) - # 6. Visual transformer blocks + i = 0 for visual_block in self.visual_transformer_blocks: if self.gradient_checkpointing and self.training: visual_embed = torch.utils.checkpoint.checkpoint( @@ -619,6 +648,8 @@ def forward( visual_embed = visual_block( visual_embed, text_embed, time_embed, visual_rope, sparse_params ) + + i += 1 # 7. 
Output projection visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 02eae1363303..9dbf31fea960 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -220,19 +220,14 @@ def encode_prompt( ): device = device or self._execution_device - # Encode with Qwen2.5-VL - prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen( - prompt, device, num_videos_per_prompt - ) + prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt) pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) if do_classifier_free_guidance: negative_prompt = negative_prompt or "" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen( - negative_prompt, device, num_videos_per_prompt - ) + negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt) negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) else: negative_prompt_embeds = None @@ -264,23 +259,25 @@ def prepare_latents( latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: - return latents.to(device=device, dtype=dtype) - - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - num_channels_latents, - ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." + num_latent_frames = latents.shape[1] + latents = latents.to(device=device, dtype=dtype) + + else: + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) if visual_cond: # For visual conditioning, concatenate with zeros and mask @@ -294,50 +291,6 @@ def prepare_latents( return latents - def get_velocity( - self, - latents: torch.Tensor, - timestep: torch.Tensor, - text_embeds: Dict[str, torch.Tensor], - negative_text_embeds: Optional[Dict[str, torch.Tensor]], - visual_rope_pos: List[torch.Tensor], - text_rope_pos: torch.Tensor, - negative_text_rope_pos: torch.Tensor, - guidance_scale: float, - sparse_params: Optional[Dict] = None, - ): - # print(latents.shape, text_embeds["text_embeds"].shape, text_embeds["pooled_embed"].shape, timestep.shape, [el.shape for el in visual_rope_pos], text_rope_pos, sparse_params) - - pred_velocity = self.transformer( - latents, - text_embeds["text_embeds"], - text_embeds["pooled_embed"], - timestep * 1000, # Scale to match training - visual_rope_pos, - text_rope_pos, - scale_factor=(1, 2, 2), # From Kandinsky config - sparse_params=sparse_params, - return_dict=False - )[0] - - if guidance_scale > 1.0 and negative_text_embeds is not None: - uncond_pred_velocity = self.transformer( - latents, - negative_text_embeds["text_embeds"], - negative_text_embeds["pooled_embed"], - timestep * 1000, - visual_rope_pos, - negative_text_rope_pos, - scale_factor=(1, 2, 2), - sparse_params=sparse_params, - return_dict=False - )[0] - - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) - - return pred_velocity @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) @@ -402,11 +355,9 @@ def __call__( indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. """ - # 1. Check inputs if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - # 2. Define call parameters if isinstance(prompt, str): batch_size = 1 else: @@ -415,16 +366,18 @@ def __call__( device = self._execution_device do_classifier_free_guidance = guidance_scale > 1.0 - if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) + + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - - # 3. Encode input prompt text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, @@ -433,11 +386,6 @@ def __call__( device=device, ) - # 4. Prepare timesteps (Kandinsky uses custom flow matching) - timesteps = torch.linspace(1, 0, num_inference_steps + 1, device=device) - timesteps = scheduler_scale * timesteps / (1 + (scheduler_scale - 1) * timesteps) - - # 5. Prepare latent variables num_channels_latents = 16 latents = self.prepare_latents( batch_size=batch_size * num_videos_per_prompt, @@ -451,11 +399,12 @@ def __call__( generator=generator, latents=latents, ) + + visual_cond = latents[:, :, :, :, 16:] - # 6. 
Prepare rope positions visual_rope_pos = [ torch.arange(num_frames // 4 + 1, device=device), - torch.arange(height // 8 // 2, device=device), # patch size 2 + torch.arange(height // 8 // 2, device=device), torch.arange(width // 8 // 2, device=device), ] @@ -467,31 +416,43 @@ def __call__( else None ) - # 7. Prepare sparse attention params if needed - sparse_params = None # Can be extended based on Kandinsky attention config - - # 8. Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, (timestep, timestep_diff) in enumerate(zip(timesteps[:-1], torch.diff(timesteps))): - # Expand timestep to match batch size - time = timestep.unsqueeze(0) - - pred_velocity = self.get_velocity( - latents, - time, - text_embeds, - negative_text_embeds, - visual_rope_pos, - text_rope_pos, - negative_text_rope_pos, - guidance_scale, - sparse_params, - ) - - # Update latents using flow matching - latents[:, :, :, :, :16] = latents[:, :, :, :, :16] + timestep_diff * pred_velocity + for i, t in enumerate(timesteps): + timestep = t.unsqueeze(0) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + # print(latents.shape) + pred_velocity = self.transformer( + latents, + text_embeds["text_embeds"], + text_embeds["pooled_embed"], + timestep, + visual_rope_pos, + text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=False + )[0] + + if guidance_scale > 1.0 and negative_text_embeds is not None: + uncond_pred_velocity = self.transformer( + latents, + negative_text_embeds["text_embeds"], + negative_text_embeds["pooled_embed"], + timestep, + visual_rope_pos, + negative_text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=False + )[0] + + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) + + latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] + latents = torch.cat([latents, visual_cond], dim=-1) if callback_on_step_end is not None: callback_kwargs = {} @@ -499,8 +460,8 @@ def __call__( callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) latents = callback_outputs.pop("latents", latents) - - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % 1 == 0): + + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() latents = latents[:, :, :, :, :16] @@ -524,7 +485,6 @@ def __call__( video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) - # Convert to output format if output_type == "pil": if num_frames == 1: @@ -533,6 +493,7 @@ def __call__( else: # Video frames video = [video[i] for i in range(video.shape[0])] + else: video = latents From a0cf07f7e086b73a49b46e2e87d0ebb10056dcd4 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 9 Oct 2025 15:09:50 +0000 Subject: [PATCH 03/70] fix 5sec generation --- .../transformers/transformer_kandinsky.py | 660 +++++++++--------- .../kandinsky5/pipeline_kandinsky.py | 51 +- 2 files changed, 368 insertions(+), 343 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index cca83988a762..3bbb9421f7ce 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ 
b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -13,21 +13,27 @@ # limitations under the License. import math -from typing import Any, Dict, Optional, Tuple, Union, List +from typing import Any, Dict, List, Optional, Tuple, Union +from einops import rearrange import torch import torch.nn as nn import torch.nn.functional as F +from torch import BoolTensor, IntTensor, Tensor, nn +from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, + flex_attention) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, + unscale_lora_layers) from ...utils.torch_utils import maybe_allow_in_graph from .._modeling_parallel import ContextParallelInput, ContextParallelOutput from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward from ..attention_dispatch import dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed +from ..embeddings import (PixArtAlphaTextProjection, TimestepEmbedding, + Timesteps, get_1d_rotary_pos_embed) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -35,34 +41,129 @@ logger = logging.get_logger(__name__) -if torch.cuda.get_device_capability()[0] >= 9: - try: - from flash_attn_interface import flash_attn_func as FA - except: - FA = None - - try: - from flash_attn import flash_attn_func as FA - except: - FA = None -else: - try: - from flash_attn import flash_attn_func as FA - except: - FA = None - - -# @torch.compile() +def exist(item): + return item is not None + + +def freeze(model): + for p in model.parameters(): + p.requires_grad = False + return model + + +@torch.autocast(device_type="cuda", enabled=False) +def get_freqs(dim, max_period=10000.0): + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=dim, dtype=torch.float32) + / dim + ) + return freqs + + +def fractal_flatten(x, rope, shape, block_mask=False): + if block_mask: + pixel_size = 8 + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) + x = x.flatten(1, 2) + rope = rope.flatten(1, 2) + else: + x = x.flatten(1, 3) + rope = rope.flatten(1, 3) + return x, rope + + +def fractal_unflatten(x, shape, block_mask=False): + if block_mask: + pixel_size = 8 + x = x.reshape(-1, pixel_size**2, *x.shape[1:]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + else: + x = x.reshape(*shape, *x.shape[2:]) + return x + + +def local_patching(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + g1, + height // g2, + g2, + width // g3, + g3, + *x.shape[dim + 3 :] + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 2, + dim + 4, + dim + 1, + dim + 3, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) + return x + + +def local_merge(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + height // g2, + width // g3, + g1, + g2, + g3, + *x.shape[dim + 2 :] + ) + x = x.permute( + 
*range(len(x.shape[:dim])), + dim, + dim + 3, + dim + 1, + dim + 4, + dim + 2, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) + return x + + +def sdpa(q, k, v): + query = q.transpose(1, 2).contiguous() + key = k.transpose(1, 2).contiguous() + value = v.transpose(1, 2).contiguous() + out = ( + F.scaled_dot_product_attention( + query, + key, + value + ) + .transpose(1, 2) + .contiguous() + ) + return out + + @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_scale_shift_norm(norm, x, scale, shift): return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) -# @torch.compile() + @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_gate_sum(x, out, gate): return (x + gate * out).to(torch.bfloat16) -# @torch.compile() + @torch.autocast(device_type="cuda", enabled=False) def apply_rotary(x, rope): x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) @@ -70,16 +171,6 @@ def apply_rotary(x, rope): return x_out.reshape(*x.shape).to(torch.bfloat16) -@torch.autocast(device_type="cuda", enabled=False) -def get_freqs(dim, max_period=10000.0): - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=dim, dtype=torch.float32) - / dim - ) - return freqs - - class TimeEmbeddings(nn.Module): def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() @@ -93,12 +184,16 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, time_dim, bias=True) + @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, time): args = torch.outer(time, self.freqs.to(device=time.device)) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) return time_embed + def reset_dtype(self): + self.freqs = get_freqs(self.model_dim // 2, self.max_period) + class TextEmbeddings(nn.Module): def __init__(self, text_dim, model_dim): @@ -116,7 +211,7 @@ def __init__(self, visual_dim, model_dim, patch_size): super().__init__() self.patch_size = patch_size self.in_layer = nn.Linear(math.prod(patch_size) * visual_dim, model_dim) - + def forward(self, x): batch_size, duration, height, width, dim = x.shape x = ( @@ -124,7 +219,7 @@ def forward(self, x): batch_size, duration // self.patch_size[0], self.patch_size[0], - height // self.patch_size[1], + height // self.patch_size[1], self.patch_size[1], width // self.patch_size[2], self.patch_size[2], @@ -137,15 +232,6 @@ def forward(self, x): class RoPE1D(nn.Module): - """ - 1D Rotary Positional Embeddings for text sequences. 
- - Args: - dim: Dimension of the rotary embeddings - max_pos: Maximum sequence length - max_period: Maximum period for sinusoidal embeddings - """ - def __init__(self, dim, max_pos=1024, max_period=10000.0): super().__init__() self.max_period = max_period @@ -153,22 +239,21 @@ def __init__(self, dim, max_pos=1024, max_period=10000.0): self.max_pos = max_pos freq = get_freqs(dim // 2, max_period) pos = torch.arange(max_pos, dtype=freq.dtype) - self.register_buffer("args", torch.outer(pos, freq), persistent=False) + self.register_buffer(f"args", torch.outer(pos, freq), persistent=False) + @torch.autocast(device_type="cuda", enabled=False) def forward(self, pos): - """ - Args: - pos: Position indices of shape [seq_len] or [batch_size, seq_len] - - Returns: - Rotary embeddings of shape [seq_len, 1, 2, 2] - """ args = self.args[pos] cosine = torch.cos(args) sine = torch.sin(args) rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) + + def reset_dtype(self): + freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device) + pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device) + self.args = torch.outer(pos, freq) class RoPE3D(nn.Module): @@ -186,22 +271,29 @@ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): @torch.autocast(device_type="cuda", enabled=False) def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): batch_size, duration, height, width = shape - args_t = self.args_0[pos[0]] / scale_factor[0] args_h = self.args_1[pos[1]] / scale_factor[1] args_w = self.args_2[pos[2]] / scale_factor[2] - args_t_expanded = args_t.view(1, duration, 1, 1, -1).expand(batch_size, -1, height, width, -1) - args_h_expanded = args_h.view(1, 1, height, 1, -1).expand(batch_size, duration, -1, width, -1) - args_w_expanded = args_w.view(1, 1, 1, width, -1).expand(batch_size, duration, height, -1, -1) - - args = torch.cat([args_t_expanded, args_h_expanded, args_w_expanded], dim=-1) - + args = torch.cat( + [ + args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), + args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), + args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), + ], + dim=-1, + ) cosine = torch.cos(args) sine = torch.sin(args) rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) + + def reset_dtype(self): + for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)): + freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device) + pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device) + setattr(self, f'args_{i}', torch.outer(pos, freq)) class Modulation(nn.Module): @@ -212,10 +304,11 @@ def __init__(self, time_dim, model_dim, num_params): self.out_layer.weight.data.zero_() self.out_layer.bias.data.zero_() + @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x): return self.out_layer(self.activation(x)) - + class MultiheadSelfAttentionEnc(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() @@ -227,9 +320,10 @@ def __init__(self, num_channels, head_dim): self.to_value = nn.Linear(num_channels, num_channels, bias=True) self.query_norm = nn.RMSNorm(head_dim) self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - def forward(self, x, rope): + def get_qkv(self, x): query = self.to_query(x) key = 
self.to_key(x) value = self.to_value(x) @@ -239,26 +333,31 @@ def forward(self, x, rope): key = key.reshape(*shape, self.num_heads, -1) value = value.reshape(*shape, self.num_heads, -1) - query = self.query_norm(query.float()).type_as(query) - key = self.key_norm(key.float()).type_as(key) + return query, key, value + + def norm_qk(self, q, k): + q = self.query_norm(q.float()).type_as(q) + k = self.key_norm(k.float()).type_as(k) + return q, k + + def scaled_dot_product_attention(self, query, key, value): + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + def out_l(self, x): + return self.out_layer(x) + def forward(self, x, rope): + query, key, value = self.get_qkv(x) + query, key = self.norm_qk(query, key) query = apply_rotary(query, rope).type_as(query) key = apply_rotary(key, rope).type_as(key) - # Use torch's scaled_dot_product_attention - # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionEnc SHAPE") - # out = F.scaled_dot_product_attention( - # query.permute(0, 2, 1, 3), - # key.permute(0, 2, 1, 3), - # value.permute(0, 2, 1, 3), - # ).permute(0, 2, 1, 3).flatten(-2, -1) - - out = FA(q=query, k=key, v=value).flatten(-2, -1) + out = self.scaled_dot_product_attention(query, key, value) - out = self.out_layer(out) + out = self.out_l(out) return out - class MultiheadSelfAttentionDec(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() @@ -270,9 +369,10 @@ def __init__(self, num_channels, head_dim): self.to_value = nn.Linear(num_channels, num_channels, bias=True) self.query_norm = nn.RMSNorm(head_dim) self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - def forward(self, x, rope, sparse_params=None): + def get_qkv(self, x): query = self.to_query(x) key = self.to_key(x) value = self.to_value(x) @@ -282,24 +382,29 @@ def forward(self, x, rope, sparse_params=None): key = key.reshape(*shape, self.num_heads, -1) value = value.reshape(*shape, self.num_heads, -1) - query = self.query_norm(query.float()).type_as(query) - key = self.key_norm(key.float()).type_as(key) + return query, key, value + + def norm_qk(self, q, k): + q = self.query_norm(q.float()).type_as(q) + k = self.key_norm(k.float()).type_as(k) + return q, k + + def attention(self, query, key, value): + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + def out_l(self, x): + return self.out_layer(x) + def forward(self, x, rope, sparse_params=None): + query, key, value = self.get_qkv(x) + query, key = self.norm_qk(query, key) query = apply_rotary(query, rope).type_as(query) key = apply_rotary(key, rope).type_as(key) - # Use standard attention (can be extended with sparse attention) - # out = F.scaled_dot_product_attention( - # query.permute(0, 2, 1, 3), - # key.permute(0, 2, 1, 3), - # value.permute(0, 2, 1, 3), - # ).permute(0, 2, 1, 3).flatten(-2, -1) - - # print(query.shape, key.shape, value.shape, "QKV MultiheadSelfAttentionDec SHAPE") - - out = FA(q=query, k=key, v=value).flatten(-2, -1) + out = self.attention(query, key, value) - out = self.out_layer(out) + out = self.out_l(out) return out @@ -314,32 +419,39 @@ def __init__(self, num_channels, head_dim): self.to_value = nn.Linear(num_channels, num_channels, bias=True) self.query_norm = nn.RMSNorm(head_dim) self.key_norm = nn.RMSNorm(head_dim) + self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - def forward(self, x, cond): + def get_qkv(self, x, cond): query = self.to_query(x) key = self.to_key(cond) value = self.to_value(cond) - + 
shape, cond_shape = query.shape[:-1], key.shape[:-1] query = query.reshape(*shape, self.num_heads, -1) key = key.reshape(*cond_shape, self.num_heads, -1) value = value.reshape(*cond_shape, self.num_heads, -1) - - query = self.query_norm(query.float()).type_as(query) - key = self.key_norm(key.float()).type_as(key) - - # out = F.scaled_dot_product_attention( - # query.permute(0, 2, 1, 3), - # key.permute(0, 2, 1, 3), - # value.permute(0, 2, 1, 3), - # ).permute(0, 2, 1, 3).flatten(-2, -1) - - # print(query.shape, key.shape, value.shape, "QKV MultiheadCrossAttention SHAPE") - out = FA(q=query, k=key, v=value).flatten(-2, -1) + return query, key, value + + def norm_qk(self, q, k): + q = self.query_norm(q.float()).type_as(q) + k = self.key_norm(k.float()).type_as(k) + return q, k - out = self.out_layer(out) + def attention(self, query, key, value): + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + def out_l(self, x): + return self.out_layer(x) + + def forward(self, x, cond): + query, key, value = self.get_qkv(x, cond) + query, key = self.norm_qk(query, key) + + out = self.attention(query, key, value) + out = self.out_l(out) return out @@ -354,6 +466,48 @@ def forward(self, x): return self.out_layer(self.activation(self.in_layer(x))) +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2) + self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) + self.out_layer = nn.Linear( + model_dim, math.prod(patch_size) * visual_dim, bias=True + ) + + def forward(self, visual_embed, text_embed, time_embed): + shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + visual_embed = apply_scale_shift_norm( + self.norm, + visual_embed, + scale[:, None, None], + shift[:, None, None], + ).type_as(visual_embed) + x = self.out_layer(visual_embed) + + batch_size, duration, height, width, _ = x.shape + x = ( + x.view( + batch_size, + duration, + height, + width, + -1, + self.patch_size[0], + self.patch_size[1], + self.patch_size[2], + ) + .permute(0, 1, 5, 2, 6, 3, 7, 4) + .flatten(1, 2) + .flatten(2, 3) + .flatten(3, 4) + ) + return x + + + + class TransformerEncoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() @@ -366,9 +520,7 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk( - self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 - ) + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) @@ -416,246 +568,116 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): return visual_embed -class OutLayer(nn.Module): - def __init__(self, model_dim, time_dim, visual_dim, patch_size): - super().__init__() - self.patch_size = patch_size - self.modulation = Modulation(time_dim, model_dim, 2) - self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.out_layer = nn.Linear( - model_dim, math.prod(patch_size) * visual_dim, bias=True - ) - - def forward(self, visual_embed, text_embed, time_embed): - # Handle the new batch dimension: [batch, duration, height, width, 
model_dim] - batch_size, duration, height, width, _ = visual_embed.shape - - shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) - - # Apply modulation with proper broadcasting for the new shape - visual_embed = apply_scale_shift_norm( - self.norm, - visual_embed, - scale[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] - shift[:, None, None, None], # [batch, 1, 1, 1, model_dim] -> [batch, 1, 1, 1] - ).type_as(visual_embed) - - x = self.out_layer(visual_embed) - - # Now x has shape [batch, duration, height, width, patch_prod * visual_dim] - x = ( - x.view( - batch_size, - duration, - height, - width, - -1, - self.patch_size[0], - self.patch_size[1], - self.patch_size[2], - ) - .permute(0, 5, 1, 6, 2, 7, 3, 4) # [batch, patch_t, duration, patch_h, height, patch_w, width, features] - .flatten(1, 2) # [batch, patch_t * duration, height, patch_w, width, features] - .flatten(2, 3) # [batch, patch_t * duration, patch_h * height, width, features] - .flatten(3, 4) # [batch, patch_t * duration, patch_h * height, patch_w * width] - ) - return x - - -@maybe_allow_in_graph class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): - r""" - A 3D Transformer model for video generation used in Kandinsky 5.0. - - This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods implemented for all models (such as downloading or saving). - - Args: - in_visual_dim (`int`, defaults to 16): - Number of channels in the input visual latent. - out_visual_dim (`int`, defaults to 16): - Number of channels in the output visual latent. - time_dim (`int`, defaults to 512): - Dimension of the time embeddings. - patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`): - Patch size for the visual embeddings (temporal, height, width). - model_dim (`int`, defaults to 1792): - Hidden dimension of the transformer model. - ff_dim (`int`, defaults to 7168): - Intermediate dimension of the feed-forward networks. - num_text_blocks (`int`, defaults to 2): - Number of transformer blocks in the text encoder. - num_visual_blocks (`int`, defaults to 32): - Number of transformer blocks in the visual decoder. - axes_dims (`Tuple[int]`, defaults to `(16, 24, 24)`): - Dimensions for the rotary positional embeddings (temporal, height, width). - visual_cond (`bool`, defaults to `True`): - Whether to use visual conditioning (for image/video conditioning). - in_text_dim (`int`, defaults to 3584): - Dimension of the text embeddings from Qwen2.5-VL. - in_text_dim2 (`int`, defaults to 768): - Dimension of the pooled text embeddings from CLIP. """ - + A 3D Diffusion Transformer model for video-like data. 
+ """ + @register_to_config def __init__( self, - in_visual_dim: int = 16, - out_visual_dim: int = 16, - time_dim: int = 512, - patch_size: Tuple[int, int, int] = (1, 2, 2), - model_dim: int = 1792, - ff_dim: int = 7168, - num_text_blocks: int = 2, - num_visual_blocks: int = 32, - axes_dims: Tuple[int, int, int] = (16, 24, 24), - visual_cond: bool = True, - in_text_dim: int = 3584, - in_text_dim2: int = 768, + in_visual_dim=4, + in_text_dim=3584, + in_text_dim2=768, + time_dim=512, + out_visual_dim=4, + patch_size=(1, 2, 2), + model_dim=2048, + ff_dim=5120, + num_text_blocks=2, + num_visual_blocks=32, + axes_dims=(16, 24, 24), + visual_cond=False, ): super().__init__() - + + head_dim = sum(axes_dims) self.in_visual_dim = in_visual_dim self.model_dim = model_dim self.patch_size = patch_size self.visual_cond = visual_cond - # Calculate head dimension for attention - head_dim = sum(axes_dims) - - # Determine visual embedding dimension based on conditioning visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim - - # 1. Embedding layers self.time_embeddings = TimeEmbeddings(model_dim, time_dim) self.text_embeddings = TextEmbeddings(in_text_dim, model_dim) self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim) self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size) - # 2. Rotary positional embeddings self.text_rope_embeddings = RoPE1D(head_dim) - self.visual_rope_embeddings = RoPE3D(axes_dims) - - # 3. Transformer blocks - self.text_transformer_blocks = nn.ModuleList([ - TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) - for _ in range(num_text_blocks) - ]) + self.text_transformer_blocks = nn.ModuleList( + [ + TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_text_blocks) + ] + ) - self.visual_transformer_blocks = nn.ModuleList([ - TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) - for _ in range(num_visual_blocks) - ]) + self.visual_rope_embeddings = RoPE3D(axes_dims) + self.visual_transformer_blocks = nn.ModuleList( + [ + TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) + for _ in range(num_visual_blocks) + ] + ) - # 4. 
Output layer self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size) - self.gradient_checkpointing = False + def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x, + text_rope_pos): + text_embed = self.text_embeddings(text_embed) + time_embed = self.time_embeddings(time) + time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) + visual_embed = self.visual_embeddings(x) + text_rope = self.text_rope_embeddings(text_rope_pos) + text_rope = text_rope.unsqueeze(dim=0) + return text_embed, time_embed, text_rope, visual_embed + + def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_factor, + sparse_params): + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False + visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, + block_mask=to_fractal) + return visual_embed, visual_shape, to_fractal, visual_rope + + def after_blocks(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) + x = self.out_layer(visual_embed, text_embed, time_embed) + return x def forward( self, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - pooled_text_embed: torch.Tensor, - timestep: torch.Tensor, - visual_rope_pos: List[torch.Tensor], - text_rope_pos: torch.Tensor, - scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), - sparse_params: Optional[Dict[str, Any]] = None, - return_dict: bool = True, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: - """ - Forward pass of the Kandinsky 5.0 3D Transformer. - - Args: - hidden_states (`torch.Tensor`): - Input visual latent tensor of shape `(batch_size, num_frames, height, width, channels)`. - encoder_hidden_states (`torch.Tensor`): - Text embeddings from Qwen2.5-VL of shape `(batch_size, sequence_length, text_dim)`. - pooled_text_embed (`torch.Tensor`): - Pooled text embeddings from CLIP of shape `(batch_size, pooled_text_dim)`. - timestep (`torch.Tensor`): - Timestep tensor of shape `(batch_size,)` or `(batch_size * num_frames,)`. - visual_rope_pos (`List[torch.Tensor]`): - List of tensors for visual rotary positional embeddings [temporal, height, width]. - text_rope_pos (`torch.Tensor`): - Tensor for text rotary positional embeddings. - scale_factor (`Tuple[float, float, float]`, defaults to `(1.0, 1.0, 1.0)`): - Scale factors for rotary positional embeddings. - sparse_params (`Dict[str, Any]`, *optional*): - Parameters for sparse attention. - return_dict (`bool`, defaults to `True`): - Whether to return a dictionary or a tensor. - - Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: - If `return_dict` is `True`, a [`~models.transformer_2d.Transformer2DModelOutput`] is returned, - otherwise a `tuple` where the first element is the sample tensor. - """ - batch_size, num_frames, height, width, channels = hidden_states.shape - - # 1. Process text embeddings - text_embed = self.text_embeddings(encoder_hidden_states) - time_embed = self.time_embeddings(timestep) - - # Add pooled text embedding to time embedding - pooled_embed = self.pooled_text_embeddings(pooled_text_embed) - time_embed = time_embed + pooled_embed - - # visual_embed shape: [batch_size, seq_len, model_dim] - visual_embed = self.visual_embeddings(hidden_states) - - # 3. 
Text rotary embeddings - text_rope = self.text_rope_embeddings(text_rope_pos) + hidden_states, # x + encoder_hidden_states, #text_embed + timestep, # time + pooled_projections, #pooled_text_embed, + visual_rope_pos, + text_rope_pos, + scale_factor=(1.0, 1.0, 1.0), + sparse_params=None, + return_dict=True, + ): + x = hidden_states + text_embed = encoder_hidden_states + time = timestep + pooled_text_embed = pooled_projections + + text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks( + text_embed, time, pooled_text_embed, x, text_rope_pos) - # 4. Text transformer blocks - i = 0 - for text_block in self.text_transformer_blocks: - if self.gradient_checkpointing and self.training: - text_embed = torch.utils.checkpoint.checkpoint( - text_block, text_embed, time_embed, text_rope, use_reentrant=False - ) - - else: - text_embed = text_block(text_embed, time_embed, text_rope) + for text_transformer_block in self.text_transformer_blocks: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) - i += 1 + visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks( + visual_embed, visual_rope_pos, scale_factor, sparse_params) - # 5. Prepare visual rope - visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) - - # visual_embed = visual_embed.reshape(visual_embed.shape[0], -1, visual_embed.shape[-1]) - # visual_rope = visual_rope.view(visual_rope.shape[0], -1, *list(visual_rope.shape[-4:])) - visual_embed = visual_embed.flatten(1, 3) - visual_rope = visual_rope.flatten(1, 3) + for visual_transformer_block in self.visual_transformer_blocks: + visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, + visual_rope, sparse_params) + + x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed) - # 6. Visual transformer blocks - i = 0 - for visual_block in self.visual_transformer_blocks: - if self.gradient_checkpointing and self.training: - visual_embed = torch.utils.checkpoint.checkpoint( - visual_block, - visual_embed, - text_embed, - time_embed, - visual_rope, - # visual_rope_flat, - sparse_params, - use_reentrant=False, - ) - else: - visual_embed = visual_block( - visual_embed, text_embed, time_embed, visual_rope, sparse_params - ) - - i += 1 - - # 7. Output projection - visual_embed = visual_embed.reshape(batch_size, num_frames, height // 2, width // 2, -1) - output = self.out_layer(visual_embed, text_embed, time_embed) - - if not return_dict: - return (output,) - - return Transformer2DModelOutput(sample=output) + if return_dict: + return Transformer2DModelOutput(sample=x) + + return x diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 9dbf31fea960..214b2b953c1c 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -300,7 +300,7 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, - num_frames: int = 25, + num_frames: int = 121, num_inference_steps: int = 50, guidance_scale: float = 5.0, scheduler_scale: float = 10.0, @@ -354,6 +354,11 @@ def __call__( the first element is a list with the generated images and the second element is a list of `bool`s indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. 
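            Note that `num_frames - 1` is expected to be divisible by the temporal compression ratio of the video
            VAE (e.g. the default of 121 when the VAE compresses time 4x); other values are adjusted as in this
            small sketch of the logic used below:

            >>> num_frames = 50
            >>> num_frames // 4 * 4 + 1  # adjusted so that (num_frames - 1) is divisible by 4
            49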
""" + self.transformer.time_embeddings.reset_dtype() + self.transformer.text_rope_embeddings.reset_dtype() + self.transformer.visual_rope_embeddings.reset_dtype() + + dtype = self.transformer.dtype if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -394,7 +399,7 @@ def __call__( width=width, num_frames=num_frames, visual_cond=self.transformer.visual_cond, - dtype=self.transformer.dtype, + dtype=dtype, device=device, generator=generator, latents=latents, @@ -418,41 +423,39 @@ def __call__( with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): - timestep = t.unsqueeze(0) + timestep = t.unsqueeze(0).flatten() - with torch.autocast(device_type="cuda", dtype=torch.bfloat16): - # print(latents.shape) + with torch.autocast(device_type="cuda", dtype=dtype): pred_velocity = self.transformer( - latents, - text_embeds["text_embeds"], - text_embeds["pooled_embed"], - timestep, - visual_rope_pos, - text_rope_pos, + hidden_states=latents, + encoder_hidden_states=text_embeds["text_embeds"], + pooled_projections=text_embeds["pooled_embed"], + timestep=timestep, + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, scale_factor=(1, 2, 2), sparse_params=None, - return_dict=False - )[0] - + return_dict=True + ).sample + if guidance_scale > 1.0 and negative_text_embeds is not None: uncond_pred_velocity = self.transformer( - latents, - negative_text_embeds["text_embeds"], - negative_text_embeds["pooled_embed"], - timestep, - visual_rope_pos, - negative_text_rope_pos, + hidden_states=latents, + encoder_hidden_states=negative_text_embeds["text_embeds"], + pooled_projections=negative_text_embeds["pooled_embed"], + timestep=timestep, + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, scale_factor=(1, 2, 2), sparse_params=None, - return_dict=False - )[0] + return_dict=True + ).sample pred_velocity = uncond_pred_velocity + guidance_scale * ( pred_velocity - uncond_pred_velocity ) - latents = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] - latents = torch.cat([latents, visual_cond], dim=-1) + latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} From c8f3a36fba49799c21161858872f03ffde7bef57 Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 10 Oct 2025 14:39:59 +0000 Subject: [PATCH 04/70] rewrite Kandinsky5T2VPipeline to diffusers style --- .../kandinsky5/pipeline_kandinsky.py | 531 ++++++++++++++---- 1 file changed, 407 insertions(+), 124 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 214b2b953c1c..cea079251bc3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -75,6 +75,101 @@ ``` """ +# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import html +from typing import Any, Callable, Dict, List, Optional, Union + +import regex as re +import torch +from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +import torchvision +from torchvision.transforms import ToPILImage + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...loaders import KandinskyLoraLoaderMixin +from ...models import AutoencoderKLHunyuanVideo +from ...models.transformers import Kandinsky5Transformer3DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor +from ..pipeline_utils import DiffusionPipeline +from .pipeline_output import KandinskyPipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_ftfy_available(): + import ftfy + + +logger = logging.get_logger(__name__) + +EXAMPLE_DOC_STRING = """ + Examples: + + ```python + >>> import torch + >>> from diffusers import Kandinsky5T2VPipeline + >>> from diffusers.utils import export_to_video + + >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + >>> pipe = pipe.to("cuda") + + >>> prompt = "A cat and a dog baking a cake together in a kitchen." + >>> negative_prompt = "Bright tones, overexposed, static, blurred details" + + >>> output = pipe( + ... prompt=prompt, + ... negative_prompt=negative_prompt, + ... height=512, + ... width=768, + ... num_frames=25, + ... num_inference_steps=50, + ... guidance_scale=5.0, + ... ).frames[0] + >>> export_to_video(output, "output.mp4", fps=6) + ``` +""" + + +def basic_clean(text): + if is_ftfy_available(): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + text = text.strip() + return text + + +def prompt_clean(text): + text = whitespace_clean(basic_clean(text)) + return text + class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): r""" @@ -96,9 +191,11 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): Frozen CLIP text encoder. tokenizer_2 ([`CLIPTokenizer`]): Tokenizer for CLIP. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded video latents. 
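+
+    The two text encoders play different roles: the Qwen2.5-VL hidden states condition the transformer through
+    cross-attention, while the pooled CLIP embedding is added to the timestep embedding (roughly
+    `time_embed = time_embeddings(t) + pooled_text_embeddings(pooled_text_embed)` inside
+    [`Kandinsky5Transformer3DModel`]).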
""" - model_cpu_offload_seq = "text_encoder->transformer->vae" + model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] def __init__( @@ -125,6 +222,7 @@ def __init__( self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) def _encode_prompt_qwen( self, @@ -132,9 +230,12 @@ def _encode_prompt_qwen( device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, max_sequence_length: int = 256, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(p) for p in prompt] # Kandinsky specific prompt template prompt_template = "\n".join([ @@ -180,16 +281,19 @@ def _encode_prompt_qwen( embeds = embeds.repeat(1, num_videos_per_prompt, 1) embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - return embeds, cu_seqlens + return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, num_videos_per_prompt: int = 1, + dtype: Optional[torch.dtype] = None, ): device = device or self._execution_device + dtype = dtype or self.text_encoder_2.dtype prompt = [prompt] if isinstance(prompt, str) else prompt + prompt = [prompt_clean(p) for p in prompt] inputs = self.tokenizer_2( prompt, @@ -208,7 +312,7 @@ def _encode_prompt_clip( pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1) pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) - return pooled_embed + return pooled_embed.to(dtype) def encode_prompt( self, @@ -216,34 +320,151 @@ def encode_prompt( negative_prompt: Optional[Union[str, List[str]]] = None, do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + Whether to use classifier free guidance or not. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + Number of videos that should be generated per prompt. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + max_sequence_length (`int`, *optional*, defaults to 512): + Maximum sequence length for text encoding. 
+ device: (`torch.device`, *optional*): + torch device + dtype: (`torch.dtype`, *optional*): + torch dtype + """ device = device or self._execution_device - prompt_embeds, prompt_cu_seqlens = self._encode_prompt_qwen(prompt, device, num_videos_per_prompt) - pooled_embed = self._encode_prompt_clip(prompt, device, num_videos_per_prompt) + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) + else: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds - if do_classifier_free_guidance: + if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - negative_prompt_embeds, negative_cu_seqlens = self._encode_prompt_qwen(negative_prompt, device, num_videos_per_prompt) - negative_pooled_embed = self._encode_prompt_clip(negative_prompt, device, num_videos_per_prompt) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + + negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen( + prompt=negative_prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) + negative_prompt_embeds_clip = self._encode_prompt_clip( + prompt=negative_prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) else: - negative_prompt_embeds = None - negative_pooled_embed = None + negative_prompt_embeds_qwen = None + negative_prompt_embeds_clip = None negative_cu_seqlens = None - text_embeds = { - "text_embeds": prompt_embeds, - "pooled_embed": pooled_embed, + prompt_embeds_dict = { + "text_embeds": prompt_embeds_qwen, + "pooled_embed": prompt_embeds_clip, } - negative_text_embeds = { - "text_embeds": negative_prompt_embeds, - "pooled_embed": negative_pooled_embed, + negative_prompt_embeds_dict = { + "text_embeds": negative_prompt_embeds_qwen, + "pooled_embed": negative_prompt_embeds_clip, } if do_classifier_free_guidance else None - return text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens + return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens + + def check_inputs( + self, + prompt, + negative_prompt, + height, + width, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 16 != 0 or width % 16 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif negative_prompt is not None and ( + not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) + ): + raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, @@ -252,34 +473,31 @@ def prepare_latents( height: int = 480, width: int = 832, num_frames: int = 81, - visual_cond: bool = False, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: if latents is not None: - num_latent_frames = latents.shape[1] - latents = latents.to(device=device, dtype=dtype) - - else: - num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_latent_frames, - int(height) // self.vae_scale_factor_spatial, - int(width) // self.vae_scale_factor_spatial, - num_channels_latents, + return latents.to(device=device, dtype=dtype) + + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 + shape = ( + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + num_channels_latents, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
- ) - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - - if visual_cond: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + if self.transformer.visual_cond: # For visual conditioning, concatenate with zeros and mask visual_cond = torch.zeros_like(latents) visual_cond_mask = torch.zeros( @@ -291,26 +509,46 @@ def prepare_latents( return latents + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], + prompt: Union[str, List[str]] = None, negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, - num_frames: int = 121, + num_frames: int = 25, num_inference_steps: int = 50, guidance_scale: float = 5.0, scheduler_scale: float = 10.0, - num_videos_per_prompt: int = 1, - generator: Optional[torch.Generator] = None, + num_videos_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end: Optional[ + Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] + ] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, **kwargs, ): r""" @@ -318,9 +556,10 @@ def __call__( Args: prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide the video generation. + The prompt or prompts to guide the video generation. If not defined, pass `prompt_embeds` instead. negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to avoid during video generation. + The prompt or prompts to avoid during video generation. If not defined, pass `negative_prompt_embeds` + instead. Ignored when not using guidance (`guidance_scale` < `1`). height (`int`, defaults to `512`): The height in pixels of the generated video. width (`int`, defaults to `768`): @@ -335,82 +574,109 @@ def __call__( Scale factor for the custom flow matching scheduler. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. - generator (`torch.Generator`, *optional*): + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): A torch generator to make generation deterministic. latents (`torch.Tensor`, *optional*): Pre-generated noisy latents. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`KandinskyPipelineOutput`]. - callback_on_step_end (`Callable`, *optional*): + callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function that is called at the end of each denoising step. 
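A minimal sketch of a step-end callback matching the call signature used in the denoising loop below, `callback_on_step_end(self, i, t, callback_kwargs)`; it only logs progress and returns the latents unchanged (keys in the returned dict overwrite the corresponding pipeline tensors):

```python
def log_step(pipe, step, timestep, callback_kwargs):
    # callback_kwargs contains the tensors named in callback_on_step_end_tensor_inputs
    latents = callback_kwargs["latents"]
    print(f"step {step}: timestep {timestep}, latents {tuple(latents.shape)}")
    return {"latents": latents}


# output = pipe(prompt=prompt, callback_on_step_end=log_step,
#               callback_on_step_end_tensor_inputs=["latents"]).frames[0]
```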
+ callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. + max_sequence_length (`int`, defaults to `512`): + The maximum sequence length for text encoding. Examples: Returns: [`~KandinskyPipelineOutput`] or `tuple`: If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images and the second element is a list of `bool`s - indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. + the first element is a list with the generated images. """ + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 0. Reset embeddings dtype self.transformer.time_embeddings.reset_dtype() self.transformer.text_rope_embeddings.reset_dtype() self.transformer.visual_rope_embeddings.reset_dtype() - - dtype = self.transformer.dtype - - if height % 16 != 0 or width % 16 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") - if isinstance(prompt, str): - batch_size = 1 - else: - batch_size = len(prompt) + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + negative_prompt, + height, + width, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) - device = self._execution_device - do_classifier_free_guidance = guidance_scale > 1.0 - if num_frames % self.vae_scale_factor_temporal != 1: logger.warning( f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." ) num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 num_frames = max(num_frames, 1) - - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - text_embeds, negative_text_embeds, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( + self._guidance_scale = guidance_scale + self._interrupt = False + + device = self._execution_device + dtype = self.transformer.dtype + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] + + # 3. Encode input prompt + prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, negative_prompt=negative_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, + do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + max_sequence_length=max_sequence_length, device=device, + dtype=dtype, ) + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. 
Prepare latent variables num_channels_latents = 16 latents = self.prepare_latents( - batch_size=batch_size * num_videos_per_prompt, - num_channels_latents=16, - height=height, - width=width, - num_frames=num_frames, - visual_cond=self.transformer.visual_cond, - dtype=dtype, - device=device, - generator=generator, - latents=latents, + batch_size * num_videos_per_prompt, + num_channels_latents, + height, + width, + num_frames, + dtype, + device, + generator, + latents, ) - - visual_cond = latents[:, :, :, :, 16:] + # 6. Prepare rope positions + num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 visual_rope_pos = [ - torch.arange(num_frames // 4 + 1, device=device), - torch.arange(height // 8 // 2, device=device), - torch.arange(width // 8 // 2, device=device), + torch.arange(num_latent_frames, device=device), + torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), + torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device) @@ -421,52 +687,72 @@ def __call__( else None ) + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + if self.interrupt: + continue + timestep = t.unsqueeze(0).flatten() - with torch.autocast(device_type="cuda", dtype=dtype): - pred_velocity = self.transformer( - hidden_states=latents, - encoder_hidden_states=text_embeds["text_embeds"], - pooled_projections=text_embeds["pooled_embed"], - timestep=timestep, + + + # Predict noise residual + # with torch.autocast(device_type="cuda", dtype=dtype): + pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), + pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=text_rope_pos, + scale_factor=(1, 2, 2), + sparse_params=None, + return_dict=True + ).sample + + if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None: + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype), + pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype), + timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, - text_rope_pos=text_rope_pos, - scale_factor=(1, 2, 2), + text_rope_pos=negative_text_rope_pos, + scale_factor=(1, 2, 2), sparse_params=None, return_dict=True ).sample - if guidance_scale > 1.0 and negative_text_embeds is not None: - uncond_pred_velocity = self.transformer( - hidden_states=latents, - encoder_hidden_states=negative_text_embeds["text_embeds"], - pooled_projections=negative_text_embeds["pooled_embed"], - timestep=timestep, - visual_rope_pos=visual_rope_pos, - text_rope_pos=negative_text_rope_pos, - scale_factor=(1, 2, 2), - sparse_params=None, - return_dict=True - ).sample - - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) + pred_velocity = uncond_pred_velocity + guidance_scale * ( + pred_velocity - uncond_pred_velocity + ) - latents[:, :, :, :, :16] = self.scheduler.step(pred_velocity, t, latents[:, :, :, :, :16], return_dict=False)[0] + # Compute previous sample + latents[:, :, :, :, :16] = self.scheduler.step( + pred_velocity, t, latents[:, 
:, :, :, :16], return_dict=False + )[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, timestep, callback_kwargs) + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + latents = callback_outputs.pop("latents", latents) - + prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict) + negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict) + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() - + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing latents = latents[:, :, :, :, :16] # 9. Decode latents to video @@ -477,26 +763,23 @@ def __call__( batch_size, num_videos_per_prompt, (num_frames - 1) // self.vae_scale_factor_temporal + 1, - height // 8, - width // 8, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, 16, ) video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] - video = video.reshape(batch_size * num_videos_per_prompt, 16, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // 8, width // 8) + video = video.reshape( + batch_size * num_videos_per_prompt, + 16, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial + ) # Normalize and decode video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample - video = ((video.clamp(-1.0, 1.0) + 1.0) * 127.5).to(torch.uint8) - # Convert to output format - if output_type == "pil": - if num_frames == 1: - # Single image - video = [ToPILImage()(frame.squeeze(1)) for frame in video] - else: - # Video frames - video = [video[i] for i in range(video.shape[0])] - + video = self.video_processor.postprocess_video(video, output_type=output_type) else: video = latents From 723d149dc1dad0db009abcb210e671a775b23db6 Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 10 Oct 2025 17:00:23 +0000 Subject: [PATCH 05/70] add multiprompt support --- .../kandinsky5/pipeline_kandinsky.py | 40 ++++++++++++------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index cea079251bc3..a417d9967548 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -269,18 +269,21 @@ def _encode_prompt_qwen( output_hidden_states=True, )["hidden_states"][-1][:, crop_start:] + batch_size = len(prompt) + attention_mask = inputs["attention_mask"][:, crop_start:] - embeds = embeds[attention_mask.bool()] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) - # duplicate for each generation per prompt - batch_size = len(prompt) - seq_len = embeds.shape[0] // batch_size - embeds = embeds.reshape(batch_size, seq_len, -1) - embeds = embeds.repeat(1, num_videos_per_prompt, 1) - embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) +# # duplicate for each generation per prompt +# seq_len = embeds.shape[0] // batch_size +# embeds = embeds.reshape(batch_size, seq_len, -1) +# embeds = embeds.repeat(1, num_videos_per_prompt, 1) +# embeds = embeds.view(batch_size * 
num_videos_per_prompt, seq_len, -1) +# print(embeds.shape, cu_seqlens, "ENCODE PROMPT") + embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) + return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( @@ -679,10 +682,10 @@ def __call__( torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] - text_rope_pos = torch.arange(prompt_cu_seqlens[-1].item(), device=device) + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) negative_text_rope_pos = ( - torch.arange(negative_cu_seqlens[-1].item(), device=device) + torch.arange(negative_cu_seqlens.diff().max().item(), device=device) if negative_cu_seqlens is not None else None ) @@ -696,12 +699,19 @@ def __call__( if self.interrupt: continue - timestep = t.unsqueeze(0).flatten() - - - - # Predict noise residual - # with torch.autocast(device_type="cuda", dtype=dtype): + timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) + + # Predict noise residual + # print( + # latents.shape, + # prompt_embeds_dict["text_embeds"].shape, + # prompt_embeds_dict["pooled_embed"].shape, + # timestep.shape, + # [el.shape for el in visual_rope_pos], + # text_rope_pos.shape, + # prompt_cu_seqlens, + # ) + pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), From 22e14bdac82fd5c100c4b1f34f5726c9c4aa4705 Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 10 Oct 2025 17:03:09 +0000 Subject: [PATCH 06/70] remove prints in pipeline --- .../kandinsky5/pipeline_kandinsky.py | 20 +------------------ 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a417d9967548..5d1eb7d60507 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -274,14 +274,6 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) - -# # duplicate for each generation per prompt -# seq_len = embeds.shape[0] // batch_size -# embeds = embeds.reshape(batch_size, seq_len, -1) -# embeds = embeds.repeat(1, num_videos_per_prompt, 1) -# embeds = embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - -# print(embeds.shape, cu_seqlens, "ENCODE PROMPT") embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) return embeds.to(dtype), cu_seqlens @@ -642,7 +634,7 @@ def __call__( batch_size = len(prompt) else: batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] - + # 3. 
Encode input prompt prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( prompt=prompt, @@ -702,16 +694,6 @@ def __call__( timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) # Predict noise residual - # print( - # latents.shape, - # prompt_embeds_dict["text_embeds"].shape, - # prompt_embeds_dict["pooled_embed"].shape, - # timestep.shape, - # [el.shape for el in visual_rope_pos], - # text_rope_pos.shape, - # prompt_cu_seqlens, - # ) - pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), From 70fa62baeaa019e7a47abb5e3a2662ba509d5bb8 Mon Sep 17 00:00:00 2001 From: leffff Date: Sun, 12 Oct 2025 21:59:23 +0000 Subject: [PATCH 07/70] add nabla attention --- .../transformers/transformer_kandinsky.py | 84 +++++++++++++++++-- .../kandinsky5/pipeline_kandinsky.py | 69 ++++++++++++++- 2 files changed, 142 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 3bbb9421f7ce..45d4ccdf9af3 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -64,8 +64,8 @@ def get_freqs(dim, max_period=10000.0): def fractal_flatten(x, rope, shape, block_mask=False): if block_mask: pixel_size = 8 - x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) - rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1) x = x.flatten(1, 2) rope = rope.flatten(1, 2) else: @@ -77,15 +77,15 @@ def fractal_flatten(x, rope, shape, block_mask=False): def fractal_unflatten(x, shape, block_mask=False): if block_mask: pixel_size = 8 - x = x.reshape(-1, pixel_size**2, *x.shape[1:]) - x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + x = x.reshape(x.shape[0], -1, pixel_size**2, *x.shape[2:]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1) else: x = x.reshape(*shape, *x.shape[2:]) return x def local_patching(x, shape, group_size, dim=0): - duration, height, width = shape + batch_size, duration, height, width = shape g1, g2, g3 = group_size x = x.reshape( *x.shape[:dim], @@ -112,7 +112,7 @@ def local_patching(x, shape, group_size, dim=0): def local_merge(x, shape, group_size, dim=0): - duration, height, width = shape + batch_size, duration, height, width = shape g1, g2, g3 = group_size x = x.reshape( *x.shape[:dim], @@ -138,6 +138,36 @@ def local_merge(x, shape, group_size, dim=0): return x +def nablaT_v2( + q: Tensor, + k: Tensor, + sta: Tensor, + thr: float = 0.9, +) -> BlockMask: + # Map estimation + B, h, S, D = q.shape + s1 = S // 64 + qa = q.reshape(B, h, s1, 64, D).mean(-2) + ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1) + map = qa @ ka + + map = torch.softmax(map / math.sqrt(D), dim=-1) + # Map binarization + vals, inds = map.sort(-1) + cvals = vals.cumsum_(-1) + mask = (cvals >= 1 - thr).int() + mask = mask.gather(-1, inds.argsort(-1)) + + mask = torch.logical_or(mask, sta) + + # BlockMask creation + kv_nb = mask.sum(-1).to(torch.int32) + kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32) + return BlockMask.from_kv_blocks( + torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None + ) + + def sdpa(q, k, v): query = q.transpose(1, 2).contiguous() 
key = k.transpose(1, 2).contiguous() @@ -392,6 +422,29 @@ def norm_qk(self, q, k): def attention(self, query, key, value): out = sdpa(q=query, k=key, v=value).flatten(-2, -1) return out + + def nabla(self, query, key, value, sparse_params=None): + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + block_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) + out = ( + flex_attention( + query, + key, + value, + block_mask=block_mask + ) + .transpose(1, 2) + .contiguous() + ) + out = out.flatten(-2, -1) + return out def out_l(self, x): return self.out_layer(x) @@ -402,7 +455,10 @@ def forward(self, x, rope, sparse_params=None): query = apply_rotary(query, rope).type_as(query) key = apply_rotary(key, rope).type_as(key) - out = self.attention(query, key, value) + if sparse_params is not None: + out = self.nabla(query, key, value, sparse_params=sparse_params) + else: + out = self.attention(query, key, value) out = self.out_l(out) return out @@ -587,7 +643,18 @@ def __init__( num_text_blocks=2, num_visual_blocks=32, axes_dims=(16, 24, 24), - visual_cond=False, + visual_cond=False, + attention_type: str = "regular", + attention_causal: bool = None, #Deffault for Nabla: false, + attention_local: bool = None, #Deffault for Nabla: false, + attention_glob:bool = None, #Deffault for Nabla: false, + attention_window: int = None, #Deffault for Nabla: 3 + attention_P: float = None, #Deffault for Nabla: 0.9 + attention_wT: int = None, #Deffault for Nabla: 11 + attention_wW:int = None, #Deffault for Nabla: 3 + attention_wH:int = None, #Deffault for Nabla: 3 + attention_add_sta: bool = None, #Deffault for Nabla: true + attention_method: str = None, #Deffault for Nabla: "topcdf" ): super().__init__() @@ -596,6 +663,7 @@ def __init__( self.model_dim = model_dim self.patch_size = patch_size self.visual_cond = visual_cond + self.attention_type = attention_type visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim self.time_embeddings = TimeEmbeddings(model_dim, time_dim) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 5d1eb7d60507..05230a604fa4 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -223,6 +223,66 @@ def __init__( self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) + + @staticmethod + def fast_sta_nabla( + T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda" + ) -> torch.Tensor: + l = torch.Tensor([T, H, W]).amax() + r = torch.arange(0, l, 1, dtype=torch.int16, device=device) + mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() + sta_t, sta_h, sta_w = ( + mat[:T, :T].flatten(), + mat[:H, :H].flatten(), + mat[:W, :W].flatten(), + ) + sta_t = sta_t <= wT // 2 + sta_h = sta_h <= wH // 2 + sta_w = sta_w <= wW // 2 + sta_hw = ( + (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)) + .reshape(H, H, W, W) + .transpose(1, 2) + .flatten() + ) + sta = ( + (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)) + .reshape(T, T, H * W, H * W) + .transpose(1, 2) + ) + return sta.reshape(T * H * W, T * H * W) + + def get_sparse_params(self, sample, device): + assert self.transformer.config.patch_size[0] == 1 + B, T, H, W, _ = sample.shape + 
T, H, W = ( + T // self.transformer.config.patch_size[0], + H // self.transformer.config.patch_size[1], + W // self.transformer.config.patch_size[2], + ) + if self.transformer.config.attention_type == "nabla": + sta_mask = self.fast_sta_nabla( + T, H // 8, W // 8, + self.transformer.config.attention_wT, self.transformer.config.attention_wH, self.transformer.config.attention_wW, + device=device + ) + + sparse_params = { + "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0), + "attention_type": self.transformer.config.attention_type, + "to_fractal": True, + "P": self.transformer.config.attention_P, + "wT": self.transformer.config.attention_wT, + "wW": self.transformer.config.attention_wW, + "wH": self.transformer.config.attention_wH, + "add_sta": self.transformer.config.attention_add_sta, + "visual_shape": (T, H, W), + "method": self.transformer.config.attention_method, + } + else: + sparse_params = None + + return sparse_params def _encode_prompt_qwen( self, @@ -681,8 +741,11 @@ def __call__( if negative_cu_seqlens is not None else None ) + + # 7. Sparse Params + sparse_params = self.get_sparse_params(latents, device) - # 7. Denoising loop + # 8. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) @@ -702,7 +765,7 @@ def __call__( visual_rope_pos=visual_rope_pos, text_rope_pos=text_rope_pos, scale_factor=(1, 2, 2), - sparse_params=None, + sparse_params=sparse_params, return_dict=True ).sample @@ -715,7 +778,7 @@ def __call__( visual_rope_pos=visual_rope_pos, text_rope_pos=negative_text_rope_pos, scale_factor=(1, 2, 2), - sparse_params=None, + sparse_params=sparse_params, return_dict=True ).sample From 45240a7317d12228d16c3fad31920dbb939cc538 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 13 Oct 2025 12:27:03 +0000 Subject: [PATCH 08/70] Wrap Transformer in Diffusers style --- .../transformers/transformer_kandinsky.py | 301 ++++++++++++------ .../kandinsky5/pipeline_kandinsky.py | 4 +- 2 files changed, 209 insertions(+), 96 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45d4ccdf9af3..4ba7e144030f 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -201,7 +201,7 @@ def apply_rotary(x, rope): return x_out.reshape(*x.shape).to(torch.bfloat16) -class TimeEmbeddings(nn.Module): +class Kandinsky5TimeEmbeddings(nn.Module): def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() assert model_dim % 2 == 0 @@ -225,7 +225,7 @@ def reset_dtype(self): self.freqs = get_freqs(self.model_dim // 2, self.max_period) -class TextEmbeddings(nn.Module): +class Kandinsky5TextEmbeddings(nn.Module): def __init__(self, text_dim, model_dim): super().__init__() self.in_layer = nn.Linear(text_dim, model_dim, bias=True) @@ -236,7 +236,7 @@ def forward(self, text_embed): return self.norm(text_embed).type_as(text_embed) -class VisualEmbeddings(nn.Module): +class Kandinsky5VisualEmbeddings(nn.Module): def __init__(self, visual_dim, model_dim, patch_size): super().__init__() self.patch_size = patch_size @@ -261,7 +261,7 @@ def forward(self, x): return self.in_layer(x) -class RoPE1D(nn.Module): +class Kandinsky5RoPE1D(nn.Module): def __init__(self, dim, max_pos=1024, max_period=10000.0): super().__init__() self.max_period = max_period @@ -286,7 +286,7 @@ def reset_dtype(self): self.args = torch.outer(pos, freq) -class RoPE3D(nn.Module): 
+class Kandinsky5RoPE3D(nn.Module): def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): super().__init__() self.axes_dims = axes_dims @@ -326,7 +326,7 @@ def reset_dtype(self): setattr(self, f'args_{i}', torch.outer(pos, freq)) -class Modulation(nn.Module): +class Kandinsky5Modulation(nn.Module): def __init__(self, time_dim, model_dim, num_params): super().__init__() self.activation = nn.SiLU() @@ -338,8 +338,63 @@ def __init__(self, time_dim, model_dim, num_params): def forward(self, x): return self.out_layer(self.activation(x)) + +class Kandinsky5SDPAAttentionProcessor(nn.Module): + """Custom attention processor for standard SDPA attention""" + + def __call__( + self, + attn, + query, + key, + value, + **kwargs, + ): + # Process attention with the given query, key, value tensors + out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + return out + + +class Kandinsky5NablaAttentionProcessor(nn.Module): + """Custom attention processor for Nabla attention""" -class MultiheadSelfAttentionEnc(nn.Module): + def __call__( + self, + attn, + query, + key, + value, + sparse_params=None, + **kwargs, + ): + if sparse_params is None: + raise ValueError("sparse_params is required for Nabla attention") + + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + + block_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) + out = ( + flex_attention( + query, + key, + value, + block_mask=block_mask + ) + .transpose(1, 2) + .contiguous() + ) + out = out.flatten(-2, -1) + return out + + +class Kandinsky5MultiheadSelfAttentionEnc(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() assert num_channels % head_dim == 0 @@ -352,6 +407,9 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + # Initialize attention processor + self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() def get_qkv(self, x): query = self.to_query(x) @@ -371,8 +429,14 @@ def norm_qk(self, q, k): return q, k def scaled_dot_product_attention(self, query, key, value): - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) - return out + # Use the processor + return self.sdpa_processor( + attn=self, + query=query, + key=key, + value=value, + **{} + ) def out_l(self, x): return self.out_layer(x) @@ -388,7 +452,8 @@ def forward(self, x, rope): out = self.out_l(out) return out -class MultiheadSelfAttentionDec(nn.Module): + +class Kandinsky5MultiheadSelfAttentionDec(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() assert num_channels % head_dim == 0 @@ -401,6 +466,10 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + # Initialize attention processors + self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() + self.nabla_processor = Kandinsky5NablaAttentionProcessor() def get_qkv(self, x): query = self.to_query(x) @@ -420,31 +489,25 @@ def norm_qk(self, q, k): return q, k def attention(self, query, key, value): - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) - return out + # Use the processor + return self.sdpa_processor( + attn=self, + query=query, + key=key, + value=value, + **{} + ) def nabla(self, query, key, value, sparse_params=None): - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 
2).contiguous() - block_mask = nablaT_v2( - query, - key, - sparse_params["sta_mask"], - thr=sparse_params["P"], + # Use the processor + return self.nabla_processor( + attn=self, + query=query, + key=key, + value=value, + sparse_params=sparse_params, + **{} ) - out = ( - flex_attention( - query, - key, - value, - block_mask=block_mask - ) - .transpose(1, 2) - .contiguous() - ) - out = out.flatten(-2, -1) - return out def out_l(self, x): return self.out_layer(x) @@ -464,7 +527,7 @@ def forward(self, x, rope, sparse_params=None): return out -class MultiheadCrossAttention(nn.Module): +class Kandinsky5MultiheadCrossAttention(nn.Module): def __init__(self, num_channels, head_dim): super().__init__() assert num_channels % head_dim == 0 @@ -477,6 +540,9 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + + # Initialize attention processor + self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() def get_qkv(self, x, cond): query = self.to_query(x) @@ -496,8 +562,14 @@ def norm_qk(self, q, k): return q, k def attention(self, query, key, value): - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) - return out + # Use the processor + return self.sdpa_processor( + attn=self, + query=query, + key=key, + value=value, + **{} + ) def out_l(self, x): return self.out_layer(x) @@ -511,7 +583,7 @@ def forward(self, x, cond): return out -class FeedForward(nn.Module): +class Kandinsky5FeedForward(nn.Module): def __init__(self, dim, ff_dim): super().__init__() self.in_layer = nn.Linear(dim, ff_dim, bias=False) @@ -522,11 +594,11 @@ def forward(self, x): return self.out_layer(self.activation(self.in_layer(x))) -class OutLayer(nn.Module): +class Kandinsky5OutLayer(nn.Module): def __init__(self, model_dim, time_dim, visual_dim, patch_size): super().__init__() self.patch_size = patch_size - self.modulation = Modulation(time_dim, model_dim, 2) + self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2) self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.out_layer = nn.Linear( model_dim, math.prod(patch_size) * visual_dim, bias=True @@ -561,19 +633,17 @@ def forward(self, visual_embed, text_embed, time_embed): ) return x - - -class TransformerEncoderBlock(nn.Module): +class Kandinsky5TransformerEncoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() - self.text_modulation = Modulation(time_dim, model_dim, 6) + self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = MultiheadSelfAttentionEnc(model_dim, head_dim) + self.self_attention = Kandinsky5MultiheadSelfAttentionEnc(model_dim, head_dim) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.feed_forward = FeedForward(model_dim, ff_dim) + self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) @@ -589,19 +659,19 @@ def forward(self, x, time_embed, rope): return x -class TransformerDecoderBlock(nn.Module): +class Kandinsky5TransformerDecoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() - self.visual_modulation = Modulation(time_dim, model_dim, 9) + self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) self.self_attention_norm = 
nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = MultiheadSelfAttentionDec(model_dim, head_dim) + self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.cross_attention = MultiheadCrossAttention(model_dim, head_dim) + self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.feed_forward = FeedForward(model_dim, ff_dim) + self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self_attn_params, cross_attn_params, ff_params = torch.chunk( @@ -645,16 +715,16 @@ def __init__( axes_dims=(16, 24, 24), visual_cond=False, attention_type: str = "regular", - attention_causal: bool = None, #Deffault for Nabla: false, - attention_local: bool = None, #Deffault for Nabla: false, - attention_glob:bool = None, #Deffault for Nabla: false, - attention_window: int = None, #Deffault for Nabla: 3 - attention_P: float = None, #Deffault for Nabla: 0.9 - attention_wT: int = None, #Deffault for Nabla: 11 - attention_wW:int = None, #Deffault for Nabla: 3 - attention_wH:int = None, #Deffault for Nabla: 3 - attention_add_sta: bool = None, #Deffault for Nabla: true - attention_method: str = None, #Deffault for Nabla: "topcdf" + attention_causal: bool = None, # Default for Nabla: false + attention_local: bool = None, # Default for Nabla: false + attention_glob: bool = None, # Default for Nabla: false + attention_window: int = None, # Default for Nabla: 3 + attention_P: float = None, # Default for Nabla: 0.9 + attention_wT: int = None, # Default for Nabla: 11 + attention_wW: int = None, # Default for Nabla: 3 + attention_wH: int = None, # Default for Nabla: 3 + attention_add_sta: bool = None, # Default for Nabla: true + attention_method: str = None, # Default for Nabla: "topcdf" ): super().__init__() @@ -666,31 +736,37 @@ def __init__( self.attention_type = attention_type visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim - self.time_embeddings = TimeEmbeddings(model_dim, time_dim) - self.text_embeddings = TextEmbeddings(in_text_dim, model_dim) - self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim) - self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + + # Initialize embeddings + self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) + self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) + self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) + self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) - self.text_rope_embeddings = RoPE1D(head_dim) + # Initialize positional embeddings + self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) + self.visual_rope_embeddings = Kandinsky5RoPE3D(axes_dims) + + # Initialize transformer blocks self.text_transformer_blocks = nn.ModuleList( [ - TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) + Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks) ] ) - self.visual_rope_embeddings = RoPE3D(axes_dims) self.visual_transformer_blocks = nn.ModuleList( [ - TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) + Kandinsky5TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_visual_blocks) ] ) - self.out_layer = 
OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + # Initialize output layer + self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) - def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x, - text_rope_pos): + def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): + """Prepare text embeddings and related components""" text_embed = self.text_embeddings(text_embed) time_embed = self.time_embeddings(time) time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) @@ -699,8 +775,8 @@ def before_text_transformer_blocks(self, text_embed, time, pooled_text_embed, x, text_rope = text_rope.unsqueeze(dim=0) return text_embed, time_embed, text_rope, visual_embed - def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_factor, - sparse_params): + def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, sparse_params): + """Prepare visual embeddings and related components""" visual_shape = visual_embed.shape[:-1] visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False @@ -708,44 +784,79 @@ def before_visual_transformer_blocks(self, visual_embed, visual_rope_pos, scale_ block_mask=to_fractal) return visual_embed, visual_shape, to_fractal, visual_rope - def after_blocks(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): + """Process text through transformer blocks""" + for text_transformer_block in self.text_transformer_blocks: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) + return text_embed + + def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): + """Process visual through transformer blocks""" + for visual_transformer_block in self.visual_transformer_blocks: + visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, + visual_rope, sparse_params) + return visual_embed + + def prepare_output(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + """Prepare the final output""" visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) x = self.out_layer(visual_embed, text_embed, time_embed) return x def forward( self, - hidden_states, # x - encoder_hidden_states, #text_embed - timestep, # time - pooled_projections, #pooled_text_embed, - visual_rope_pos, - text_rope_pos, - scale_factor=(1.0, 1.0, 1.0), - sparse_params=None, - return_dict=True, - ): + hidden_states: torch.FloatTensor, # x + encoder_hidden_states: torch.FloatTensor, # text_embed + timestep: Union[torch.Tensor, float, int], # time + pooled_projections: torch.FloatTensor, # pooled_text_embed + visual_rope_pos: Tuple[int, int, int], + text_rope_pos: torch.LongTensor, + scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), + sparse_params: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Transformer2DModelOutput, torch.FloatTensor]: + """ + Forward pass of the Kandinsky5 3D Transformer. 
+ + Args: + hidden_states (`torch.FloatTensor`): Input visual states + encoder_hidden_states (`torch.FloatTensor`): Text embeddings + timestep (`torch.Tensor` or `float` or `int`): Current timestep + pooled_projections (`torch.FloatTensor`): Pooled text embeddings + visual_rope_pos (`Tuple[int, int, int]`): Position for visual RoPE + text_rope_pos (`torch.LongTensor`): Position for text RoPE + scale_factor (`Tuple[float, float, float]`, optional): Scale factor for RoPE + sparse_params (`Dict[str, Any]`, optional): Parameters for sparse attention + return_dict (`bool`, optional): Whether to return a dictionary + + Returns: + [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: + The output of the transformer + """ x = hidden_states text_embed = encoder_hidden_states time = timestep pooled_text_embed = pooled_projections - text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks( + # Prepare text embeddings and related components + text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings( text_embed, time, pooled_text_embed, x, text_rope_pos) - for text_transformer_block in self.text_transformer_blocks: - text_embed = text_transformer_block(text_embed, time_embed, text_rope) + # Process text through transformer blocks + text_embed = self.process_text_transformer_blocks(text_embed, time_embed, text_rope) - visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks( + # Prepare visual embeddings and related components + visual_embed, visual_shape, to_fractal, visual_rope = self.prepare_visual_embeddings( visual_embed, visual_rope_pos, scale_factor, sparse_params) - for visual_transformer_block in self.visual_transformer_blocks: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, - visual_rope, sparse_params) - - x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed) + # Process visual through transformer blocks + visual_embed = self.process_visual_transformer_blocks( + visual_embed, text_embed, time_embed, visual_rope, sparse_params) - if return_dict: - return Transformer2DModelOutput(sample=x) + # Prepare final output + x = self.prepare_output(visual_embed, visual_shape, to_fractal, text_embed, time_embed) - return x + if not return_dict: + return x + + return Transformer2DModelOutput(sample=x) \ No newline at end of file diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 05230a604fa4..12bc12cca205 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -263,7 +263,9 @@ def get_sparse_params(self, sample, device): if self.transformer.config.attention_type == "nabla": sta_mask = self.fast_sta_nabla( T, H // 8, W // 8, - self.transformer.config.attention_wT, self.transformer.config.attention_wH, self.transformer.config.attention_wW, + self.transformer.config.attention_wT, + self.transformer.config.attention_wH, + self.transformer.config.attention_wW, device=device ) From 43bd1e81d2b0aba750477af04f0c3927c84e0761 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 13 Oct 2025 14:41:50 +0000 Subject: [PATCH 09/70] fix license --- src/diffusers/models/transformers/transformer_kandinsky.py | 4 ++-- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 4ba7e144030f..01c9b258b7c3 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -859,4 +859,4 @@ def forward( if not return_dict: return x - return Transformer2DModelOutput(sample=x) \ No newline at end of file + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 12bc12cca205..a30484c701b0 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -1,4 +1,4 @@ -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. +# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From 149fd53df84c42100062def55d25ca02dc023979 Mon Sep 17 00:00:00 2001 From: leffff Date: Mon, 13 Oct 2025 22:38:03 +0000 Subject: [PATCH 10/70] fix prompt type --- .../kandinsky5/pipeline_kandinsky.py | 227 ++++++++++-------- 1 file changed, 130 insertions(+), 97 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a30484c701b0..407dc127fda8 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -33,83 +33,6 @@ from .pipeline_output import KandinskyPipelineOutput -if is_torch_xla_available(): - import torch_xla.core.xla_model as xm - - XLA_AVAILABLE = True -else: - XLA_AVAILABLE = False - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -if is_ftfy_available(): - import ftfy - - -logger = logging.get_logger(__name__) - -EXAMPLE_DOC_STRING = """ - Examples: - - ```python - >>> import torch - >>> from diffusers import Kandinsky5T2VPipeline, Kandinsky5Transformer3DModel - >>> from diffusers.utils import export_to_video - - >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") - >>> pipe = pipe.to("cuda") - - >>> prompt = "A cat and a dog baking a cake together in a kitchen." - >>> negative_prompt = "Bright tones, overexposed, static, blurred details" - - >>> output = pipe( - ... prompt=prompt, - ... negative_prompt=negative_prompt, - ... height=512, - ... width=768, - ... num_frames=25, - ... num_inference_steps=50, - ... guidance_scale=5.0, - ... ).frames[0] - >>> export_to_video(output, "output.mp4", fps=6) - ``` -""" - -# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import html -from typing import Any, Callable, Dict, List, Optional, Union - -import regex as re -import torch -from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer -import torchvision -from torchvision.transforms import ToPILImage - -from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...loaders import KandinskyLoraLoaderMixin -from ...models import AutoencoderKLHunyuanVideo -from ...models.transformers import Kandinsky5Transformer3DModel -from ...schedulers import FlowMatchEulerDiscreteScheduler -from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring -from ...utils.torch_utils import randn_tensor -from ...video_processor import VideoProcessor -from ..pipeline_utils import DiffusionPipeline -from .pipeline_output import KandinskyPipelineOutput - - if is_torch_xla_available(): import torch_xla.core.xla_model as xm @@ -137,23 +60,23 @@ >>> pipe = pipe.to("cuda") >>> prompt = "A cat and a dog baking a cake together in a kitchen." - >>> negative_prompt = "Bright tones, overexposed, static, blurred details" - + >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" >>> output = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, ... height=512, ... width=768, - ... num_frames=25, + ... num_frames=121, ... num_inference_steps=50, ... guidance_scale=5.0, ... ).frames[0] - >>> export_to_video(output, "output.mp4", fps=6) + >>> export_to_video(output, "output.mp4", fps=24) ``` """ def basic_clean(text): + """Clean text using ftfy if available and unescape HTML entities.""" if is_ftfy_available(): text = ftfy.fix_text(text) text = html.unescape(html.unescape(text)) @@ -161,12 +84,14 @@ def basic_clean(text): def whitespace_clean(text): + """Normalize whitespace in text by replacing multiple spaces with single space.""" text = re.sub(r"\s+", " ", text) text = text.strip() return text def prompt_clean(text): + """Apply both basic cleaning and whitespace normalization to prompts.""" text = whitespace_clean(basic_clean(text)) return text @@ -228,6 +153,24 @@ def __init__( def fast_sta_nabla( T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda" ) -> torch.Tensor: + """ + Create a sparse temporal attention (STA) mask for efficient video generation. + + This method generates a mask that limits attention to nearby frames and spatial positions, + reducing computational complexity for video generation. 
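To make the sliding-window semantics described above concrete, here is a compact re-derivation of the same pattern on a tiny grid, written purely for illustration: it evaluates the window predicate directly rather than taking the reshape/transpose route of `fast_sta_nabla`, but yields the same `(T*H*W, T*H*W)` mask. Note that the pipeline builds this mask over 8x8 latent blocks (it passes `H // 8`, `W // 8`), matching the fractal flattening used in the transformer.

```python
import torch

def sta_mask(T, H, W, wT=3, wH=3, wW=3):
    # token i may attend to token j iff their (t, h, w) indices differ by at most w // 2 per axis
    grid = torch.stack(
        torch.meshgrid(torch.arange(T), torch.arange(H), torch.arange(W), indexing="ij"), dim=-1
    ).reshape(-1, 3)
    diff = (grid[:, None, :] - grid[None, :, :]).abs()
    return (diff[..., 0] <= wT // 2) & (diff[..., 1] <= wH // 2) & (diff[..., 2] <= wW // 2)

print(sta_mask(2, 2, 2).int())  # 8x8 allowed-attention pattern for a 2x2x2 grid
```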
+ + Args: + T (int): Number of temporal frames + H (int): Height in latent space + W (int): Width in latent space + wT (int): Temporal attention window size + wH (int): Height attention window size + wW (int): Width attention window size + device (str): Device to create tensor on + + Returns: + torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W) + """ l = torch.Tensor([T, H, W]).amax() r = torch.arange(0, l, 1, dtype=torch.int16, device=device) mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() @@ -253,6 +196,19 @@ def fast_sta_nabla( return sta.reshape(T * H * W, T * H * W) def get_sparse_params(self, sample, device): + """ + Generate sparse attention parameters for the transformer based on sample dimensions. + + This method computes the sparse attention configuration needed for efficient + video processing in the transformer model. + + Args: + sample (torch.Tensor): Input sample tensor + device (torch.device): Device to place tensors on + + Returns: + Dict: Dictionary containing sparse attention parameters + """ assert self.transformer.config.patch_size[0] == 1 B, T, H, W, _ = sample.shape T, H, W = ( @@ -294,12 +250,28 @@ def _encode_prompt_qwen( max_sequence_length: int = 256, dtype: Optional[torch.dtype] = None, ): + """ + Encode prompt using Qwen2.5-VL text encoder. + + This method processes the input prompt through the Qwen2.5-VL model to generate + text embeddings suitable for video generation. + + Args: + prompt (Union[str, List[str]]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + max_sequence_length (int): Maximum sequence length for tokenization + dtype (torch.dtype): Data type for embeddings + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths + """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype prompt = [prompt] if isinstance(prompt, str) else prompt prompt = [prompt_clean(p) for p in prompt] - # Kandinsky specific prompt template + # Kandinsky specific prompt template for detailed video description prompt_template = "\n".join([ "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", @@ -310,7 +282,7 @@ def _encode_prompt_qwen( "Pay attention to the order of key actions shown in the scene.<|im_end|>", "<|im_start|>user\n{}<|im_end|>", ]) - crop_start = 129 + crop_start = 129 # Position to start cropping from (system prompt tokens) full_texts = [prompt_template.format(p) for p in prompt] @@ -347,6 +319,21 @@ def _encode_prompt_clip( num_videos_per_prompt: int = 1, dtype: Optional[torch.dtype] = None, ): + """ + Encode prompt using CLIP text encoder. + + This method processes the input prompt through the CLIP model to generate + pooled embeddings that capture semantic information. + + Args: + prompt (Union[str, List[str]]): Input prompt or list of prompts + device (torch.device): Device to run encoding on + num_videos_per_prompt (int): Number of videos to generate per prompt + dtype (torch.dtype): Data type for embeddings + + Returns: + torch.Tensor: Pooled text embeddings from CLIP + """ device = device or self._execution_device dtype = dtype or self.text_encoder_2.dtype prompt = [prompt] if isinstance(prompt, str) else prompt @@ -386,6 +373,9 @@ def encode_prompt( r""" Encodes the prompt into text encoder hidden states. 
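[Editor's note] A minimal sketch of how the cumulative sequence lengths returned alongside the Qwen embeddings are derived from the tokenizer attention mask; the toy values are made up, and the `F.pad` form mirrors the one later patches in this series settle on.

```python
# Cumulative-sequence-length bookkeeping for the Qwen2.5-VL embeddings.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])  # two prompts, padded to length 5
cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)
print(cu_seqlens)  # tensor([0, 3, 8], dtype=torch.int32): per-prompt token boundaries
```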
+ This method combines embeddings from both Qwen2.5-VL and CLIP text encoders + to create comprehensive text representations for video generation. + Args: prompt (`str` or `List[str]`, *optional*): prompt to be encoded @@ -410,11 +400,15 @@ def encode_prompt( torch device dtype: (`torch.dtype`, *optional*): torch dtype + + Returns: + Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information """ device = device or self._execution_device if prompt is not None and isinstance(prompt, str): batch_size = 1 + prompt = [prompt] elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: @@ -438,7 +432,7 @@ def encode_prompt( prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" + negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt if prompt is not None and type(prompt) is not type(negative_prompt): @@ -492,6 +486,21 @@ def check_inputs( negative_prompt_embeds=None, callback_on_step_end_tensor_inputs=None, ): + """ + Validate input parameters for the pipeline. + + Args: + prompt: Input prompt + negative_prompt: Negative prompt for guidance + height: Video height + width: Video width + prompt_embeds: Pre-computed prompt embeddings + negative_prompt_embeds: Pre-computed negative prompt embeddings + callback_on_step_end_tensor_inputs: Callback tensor inputs + + Raises: + ValueError: If inputs are invalid + """ if height % 16 != 0 or width % 16 != 0: raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.") @@ -535,6 +544,26 @@ def prepare_latents( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """ + Prepare initial latent variables for video generation. + + This method creates random noise latents or uses provided latents as starting point + for the denoising process. 
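[Editor's note] As a worked example of the latent geometry this method and the defaults below imply, a rough sketch; the VAE compression ratios and channel count are assumptions (typical HunyuanVideo VAE values), not stated in this hunk.

```python
# Back-of-the-envelope latent shape for num_frames=121 at 512x768,
# assuming 4x temporal / 8x spatial VAE compression and 16 latent channels.
num_frames, height, width = 121, 512, 768
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8
num_channels_latents = 16  # read from transformer.config.in_visual_dim in the patch

num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 31
latent_height = height // vae_scale_factor_spatial                     # 64
latent_width = width // vae_scale_factor_spatial                       # 96
print(num_latent_frames, latent_height, latent_width, num_channels_latents)
```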
+ + Args: + batch_size (int): Number of videos to generate + num_channels_latents (int): Number of channels in latent space + height (int): Height of generated video + width (int): Width of generated video + num_frames (int): Number of frames in video + dtype (torch.dtype): Data type for latents + device (torch.device): Device to create latents on + generator (torch.Generator): Random number generator + latents (torch.Tensor): Pre-existing latents to use + + Returns: + torch.Tensor: Prepared latent tensor + """ if latents is not None: return latents.to(device=device, dtype=dtype) @@ -568,18 +597,22 @@ def prepare_latents( @property def guidance_scale(self): + """Get the current guidance scale value.""" return self._guidance_scale @property def do_classifier_free_guidance(self): + """Check if classifier-free guidance is enabled.""" return self._guidance_scale > 1.0 @property def num_timesteps(self): + """Get the number of denoising timesteps.""" return self._num_timesteps @property def interrupt(self): + """Check if generation has been interrupted.""" return self._interrupt @torch.no_grad() @@ -590,10 +623,10 @@ def __call__( negative_prompt: Optional[Union[str, List[str]]] = None, height: int = 512, width: int = 768, - num_frames: int = 25, + num_frames: int = 121, num_inference_steps: int = 50, guidance_scale: float = 5.0, - scheduler_scale: float = 10.0, + scheduler_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -715,7 +748,7 @@ def __call__( timesteps = self.scheduler.timesteps # 5. Prepare latent variables - num_channels_latents = 16 + num_channels_latents = self.transformer.config.in_visual_dim latents = self.prepare_latents( batch_size * num_videos_per_prompt, num_channels_latents, @@ -728,7 +761,7 @@ def __call__( latents, ) - # 6. Prepare rope positions + # 6. Prepare rope positions for positional encoding num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 visual_rope_pos = [ torch.arange(num_latent_frames, device=device), @@ -744,7 +777,7 @@ def __call__( else None ) - # 7. Sparse Params + # 7. Sparse Params for efficient attention sparse_params = self.get_sparse_params(latents, device) # 8. Denoising loop @@ -788,9 +821,9 @@ def __call__( pred_velocity - uncond_pred_velocity ) - # Compute previous sample - latents[:, :, :, :, :16] = self.scheduler.step( - pred_velocity, t, latents[:, :, :, :, :16], return_dict=False + # Compute previous sample using the scheduler + latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( + pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False )[0] if callback_on_step_end is not None: @@ -809,8 +842,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - # 8. Post-processing - latents = latents[:, :, :, :, :16] + # 8. Post-processing - extract main latents + latents = latents[:, :, :, :, :num_channels_latents] # 9. 
Decode latents to video if output_type != "latent": @@ -822,18 +855,18 @@ def __call__( (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, - 16, + num_channels_latents, ) video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] video = video.reshape( batch_size * num_videos_per_prompt, - 16, + num_channels_latents, (num_frames - 1) // self.vae_scale_factor_temporal + 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial ) - # Normalize and decode + # Normalize and decode through VAE video = video / self.vae.config.scaling_factor video = self.vae.decode(video).sample video = self.video_processor.postprocess_video(video, output_type=output_type) From 7af80e9ffcf4daef408d0f1c99b115c70ae73756 Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 11:24:24 +0000 Subject: [PATCH 11/70] add gradient checkpointing and peft support --- .../transformers/transformer_kandinsky.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 01c9b258b7c3..6dec8d93ac9e 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -22,6 +22,7 @@ from torch import BoolTensor, IntTensor, Tensor, nn from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, flex_attention) +from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -694,11 +695,12 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): return visual_embed -class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin): +class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): """ A 3D Diffusion Transformer model for video-like data. 
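[Editor's note] The change that follows marks the transformer as supporting gradient checkpointing and mixes in PEFT adapter support. A hedged sketch of how that might be used for fine-tuning; the repo id and the "transformer" subfolder layout are assumptions based on the pipeline example added later in this series.

```python
# Hypothetical training-side usage enabled by this patch: gradient checkpointing
# via ModelMixin, LoRA/PEFT adapters via PeftAdapterMixin.
import torch
from diffusers import Kandinsky5Transformer3DModel

transformer = Kandinsky5Transformer3DModel.from_pretrained(
    "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers",  # assumed repo id
    subfolder="transformer",                               # assumed layout
    torch_dtype=torch.bfloat16,
)
transformer.enable_gradient_checkpointing()  # recompute activations to save memory
```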
""" - + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -764,6 +766,7 @@ def __init__( # Initialize output layer self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + self.gradient_checkpointing = False def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): """Prepare text embeddings and related components""" @@ -787,13 +790,20 @@ def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): """Process text through transformer blocks""" for text_transformer_block in self.text_transformer_blocks: - text_embed = text_transformer_block(text_embed, time_embed, text_rope) + if torch.is_grad_enabled() and self.gradient_checkpointing: + text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope) + else: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) return text_embed def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): """Process visual through transformer blocks""" for visual_transformer_block in self.visual_transformer_blocks: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, + if torch.is_grad_enabled() and self.gradient_checkpointing: + visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed, + visual_rope, sparse_params) + else: + visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, visual_rope, sparse_params) return visual_embed From 04efb19b1aeba3b41b7b1bd6d0353a1715c0f839 Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 12:14:37 +0000 Subject: [PATCH 12/70] add usage example --- .../pipelines/kandinsky5/pipeline_kandinsky.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 407dc127fda8..38d94ded42ad 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -55,12 +55,20 @@ >>> import torch >>> from diffusers import Kandinsky5T2VPipeline >>> from diffusers.utils import export_to_video - - >>> pipe = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V") + + >>> # Available models: + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers + >>> # ai-forever/Kandinsky-5.0-T2V-Lite-pretrain-5s-Diffusers + + >>> model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers" + >>> pipe = Kandinsky5T2VPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) >>> pipe = pipe.to("cuda") >>> prompt = "A cat and a dog baking a cake together in a kitchen." >>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + >>> output = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, @@ -70,7 +78,8 @@ ... num_inference_steps=50, ... guidance_scale=5.0, ... 
).frames[0] - >>> export_to_video(output, "output.mp4", fps=24) + + >>> export_to_video(output, "output.mp4", fps=24, quality=9) ``` """ From 235f0d5df8a7d9842c63d458044ea823e921c8a8 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:53:32 +0300 Subject: [PATCH 13/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 38d94ded42ad..73868c972c32 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -13,7 +13,7 @@ # limitations under the License. import html -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import regex as re import torch From 88a8eea0962a3d209039e01c30d7601d14343ce0 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:53:47 +0300 Subject: [PATCH 14/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 73868c972c32..3840ad11dd5f 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -17,7 +17,7 @@ import regex as re import torch -from transformers import Qwen2TokenizerFast, Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, AutoProcessor, CLIPTextModel, CLIPTokenizer +from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer import torchvision from torchvision.transforms import ToPILImage From f52f3b45b75e461cbd9a28f280cdbad015059420 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:54:10 +0300 Subject: [PATCH 15/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3840ad11dd5f..39306cb9e812 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -18,7 +18,6 @@ import regex as re import torch from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer -import torchvision from torchvision.transforms import ToPILImage from ...callbacks import MultiPipelineCallbacks, PipelineCallback From 0190e55641e70ab65f656b2499ee325ce2149f83 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 21:54:21 +0300 Subject: [PATCH 16/70] 
Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 39306cb9e812..3a8628a1b339 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -18,7 +18,6 @@ import regex as re import torch from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer -from torchvision.transforms import ToPILImage from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import KandinskyLoraLoaderMixin From d62dffcb212ea6f6281615f23230d77de3efc988 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Tue, 14 Oct 2025 23:25:14 +0300 Subject: [PATCH 17/70] Update src/diffusers/models/transformers/transformer_kandinsky.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 6dec8d93ac9e..24b2c4ae99b6 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -15,7 +15,6 @@ import math from typing import Any, Dict, List, Optional, Tuple, Union -from einops import rearrange import torch import torch.nn as nn import torch.nn.functional as F From 7084106eaaa9b998efd520e72b4a69a6e2dd90cf Mon Sep 17 00:00:00 2001 From: leffff Date: Tue, 14 Oct 2025 20:38:40 +0000 Subject: [PATCH 18/70] remove unused imports --- .../transformers/transformer_kandinsky.py | 250 ++++++++++-------- 1 file changed, 142 insertions(+), 108 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 24b2c4ae99b6..ac2fe58d60b4 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -19,21 +19,27 @@ import torch.nn as nn import torch.nn.functional as F from torch import BoolTensor, IntTensor, Tensor, nn -from torch.nn.attention.flex_attention import (BlockMask, _mask_mod_signature, - flex_attention) -from ...loaders import FromOriginalModelMixin, PeftAdapterMixin +from torch.nn.attention.flex_attention import ( + BlockMask, + flex_attention, +) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin -from ...utils import (USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, - unscale_lora_layers) +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.torch_utils import maybe_allow_in_graph -from .._modeling_parallel import ContextParallelInput, ContextParallelOutput -from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward -from ..attention_dispatch import dispatch_attention_fn +from ..attention import AttentionMixin, FeedForward from ..cache_utils import CacheMixin -from ..embeddings 
import (PixArtAlphaTextProjection, TimestepEmbedding, - Timesteps, get_1d_rotary_pos_embed) +from ..embeddings import ( + TimestepEmbedding, + get_1d_rotary_pos_embed, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm @@ -95,7 +101,7 @@ def local_patching(x, shape, group_size, dim=0): g2, width // g3, g3, - *x.shape[dim + 3 :] + *x.shape[dim + 3 :], ) x = x.permute( *range(len(x.shape[:dim])), @@ -105,7 +111,7 @@ def local_patching(x, shape, group_size, dim=0): dim + 1, dim + 3, dim + 5, - *range(dim + 6, len(x.shape)) + *range(dim + 6, len(x.shape)), ) x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) return x @@ -122,7 +128,7 @@ def local_merge(x, shape, group_size, dim=0): g1, g2, g3, - *x.shape[dim + 2 :] + *x.shape[dim + 2 :], ) x = x.permute( *range(len(x.shape[:dim])), @@ -132,7 +138,7 @@ def local_merge(x, shape, group_size, dim=0): dim + 4, dim + 2, dim + 5, - *range(dim + 6, len(x.shape)) + *range(dim + 6, len(x.shape)), ) x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) return x @@ -172,15 +178,7 @@ def sdpa(q, k, v): query = q.transpose(1, 2).contiguous() key = k.transpose(1, 2).contiguous() value = v.transpose(1, 2).contiguous() - out = ( - F.scaled_dot_product_attention( - query, - key, - value - ) - .transpose(1, 2) - .contiguous() - ) + out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous() return out @@ -279,7 +277,7 @@ def forward(self, pos): rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - + def reset_dtype(self): freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device) pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device) @@ -307,9 +305,15 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): args = torch.cat( [ - args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), - args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), - args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), + args_t.view(1, duration, 1, 1, -1).repeat( + batch_size, 1, height, width, 1 + ), + args_h.view(1, 1, height, 1, -1).repeat( + batch_size, duration, 1, width, 1 + ), + args_w.view(1, 1, 1, width, -1).repeat( + batch_size, duration, height, 1, 1 + ), ], dim=-1, ) @@ -318,12 +322,12 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): rope = torch.stack([cosine, -sine, sine, cosine], dim=-1) rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - + def reset_dtype(self): for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)): freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device) pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device) - setattr(self, f'args_{i}', torch.outer(pos, freq)) + setattr(self, f"args_{i}", torch.outer(pos, freq)) class Kandinsky5Modulation(nn.Module): @@ -341,7 +345,7 @@ def forward(self, x): class Kandinsky5SDPAAttentionProcessor(nn.Module): """Custom attention processor for standard SDPA attention""" - + def __call__( self, attn, @@ -357,7 +361,7 @@ def __call__( class Kandinsky5NablaAttentionProcessor(nn.Module): """Custom attention processor for Nabla attention""" - + def __call__( self, attn, @@ -369,11 +373,11 @@ def __call__( ): if sparse_params is None: raise ValueError("sparse_params is required for Nabla attention") - + query = query.transpose(1, 
2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() - + block_mask = nablaT_v2( query, key, @@ -381,12 +385,7 @@ def __call__( thr=sparse_params["P"], ) out = ( - flex_attention( - query, - key, - value, - block_mask=block_mask - ) + flex_attention(query, key, value, block_mask=block_mask) .transpose(1, 2) .contiguous() ) @@ -407,7 +406,7 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processor self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() @@ -430,13 +429,7 @@ def norm_qk(self, q, k): def scaled_dot_product_attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) def out_l(self, x): return self.out_layer(x) @@ -466,7 +459,7 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processors self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() self.nabla_processor = Kandinsky5NablaAttentionProcessor() @@ -490,14 +483,8 @@ def norm_qk(self, q, k): def attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) - + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) + def nabla(self, query, key, value, sparse_params=None): # Use the processor return self.nabla_processor( @@ -506,7 +493,7 @@ def nabla(self, query, key, value, sparse_params=None): key=key, value=value, sparse_params=sparse_params, - **{} + **{}, ) def out_l(self, x): @@ -540,7 +527,7 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - + # Initialize attention processor self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() @@ -563,13 +550,7 @@ def norm_qk(self, q, k): def attention(self, query, key, value): # Use the processor - return self.sdpa_processor( - attn=self, - query=query, - key=key, - value=value, - **{} - ) + return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) def out_l(self, x): return self.out_layer(x) @@ -605,7 +586,9 @@ def __init__(self, model_dim, time_dim, visual_dim, patch_size): ) def forward(self, visual_embed, text_embed, time_embed): - shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + shift, scale = torch.chunk( + self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) visual_embed = apply_scale_shift_norm( self.norm, visual_embed, @@ -646,7 +629,9 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + self_attn_params, ff_params = torch.chunk( + self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 + ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) out = self.self_attention(out, rope) @@ -678,26 +663,40 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self.visual_modulation(time_embed).unsqueeze(dim=1), 3, 
dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.self_attention_norm, visual_embed, scale, shift + ) visual_out = self.self_attention(visual_out, rope, sparse_params) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.cross_attention_norm, visual_embed, scale, shift + ) visual_out = self.cross_attention(visual_out, text_embed) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = apply_scale_shift_norm( + self.feed_forward_norm, visual_embed, scale, shift + ) visual_out = self.feed_forward(visual_out) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) return visual_embed -class Kandinsky5Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin): +class Kandinsky5Transformer3DModel( + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + CacheMixin, + AttentionMixin, +): """ A 3D Diffusion Transformer model for video-like data. """ + _supports_gradient_checkpointing = True @register_to_config @@ -714,21 +713,21 @@ def __init__( num_text_blocks=2, num_visual_blocks=32, axes_dims=(16, 24, 24), - visual_cond=False, + visual_cond=False, attention_type: str = "regular", - attention_causal: bool = None, # Default for Nabla: false - attention_local: bool = None, # Default for Nabla: false - attention_glob: bool = None, # Default for Nabla: false - attention_window: int = None, # Default for Nabla: 3 - attention_P: float = None, # Default for Nabla: 0.9 - attention_wT: int = None, # Default for Nabla: 11 - attention_wW: int = None, # Default for Nabla: 3 - attention_wH: int = None, # Default for Nabla: 3 - attention_add_sta: bool = None, # Default for Nabla: true - attention_method: str = None, # Default for Nabla: "topcdf" + attention_causal: bool = None, # Default for Nabla: false + attention_local: bool = None, # Default for Nabla: false + attention_glob: bool = None, # Default for Nabla: false + attention_window: int = None, # Default for Nabla: 3 + attention_P: float = None, # Default for Nabla: 0.9 + attention_wT: int = None, # Default for Nabla: 11 + attention_wW: int = None, # Default for Nabla: 3 + attention_wH: int = None, # Default for Nabla: 3 + attention_add_sta: bool = None, # Default for Nabla: true + attention_method: str = None, # Default for Nabla: "topcdf" ): super().__init__() - + head_dim = sum(axes_dims) self.in_visual_dim = in_visual_dim self.model_dim = model_dim @@ -737,12 +736,14 @@ def __init__( self.attention_type = attention_type visual_embed_dim = 2 * in_visual_dim + 1 if visual_cond else in_visual_dim - + # Initialize embeddings self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) - self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) + self.visual_embeddings = Kandinsky5VisualEmbeddings( + visual_embed_dim, model_dim, patch_size + ) # Initialize 
positional embeddings self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) @@ -764,10 +765,14 @@ def __init__( ) # Initialize output layer - self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) + self.out_layer = Kandinsky5OutLayer( + model_dim, time_dim, out_visual_dim, patch_size + ) self.gradient_checkpointing = False - def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_rope_pos): + def prepare_text_embeddings( + self, text_embed, time, pooled_text_embed, x, text_rope_pos + ): """Prepare text embeddings and related components""" text_embed = self.text_embeddings(text_embed) time_embed = self.time_embeddings(time) @@ -777,38 +782,58 @@ def prepare_text_embeddings(self, text_embed, time, pooled_text_embed, x, text_r text_rope = text_rope.unsqueeze(dim=0) return text_embed, time_embed, text_rope, visual_embed - def prepare_visual_embeddings(self, visual_embed, visual_rope_pos, scale_factor, sparse_params): + def prepare_visual_embeddings( + self, visual_embed, visual_rope_pos, scale_factor, sparse_params + ): """Prepare visual embeddings and related components""" visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) + visual_rope = self.visual_rope_embeddings( + visual_shape, visual_rope_pos, scale_factor + ) to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False - visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, - block_mask=to_fractal) + visual_embed, visual_rope = fractal_flatten( + visual_embed, visual_rope, visual_shape, block_mask=to_fractal + ) return visual_embed, visual_shape, to_fractal, visual_rope def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): """Process text through transformer blocks""" for text_transformer_block in self.text_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - text_embed = self._gradient_checkpointing_func(text_transformer_block, text_embed, time_embed, text_rope) + text_embed = self._gradient_checkpointing_func( + text_transformer_block, text_embed, time_embed, text_rope + ) else: text_embed = text_transformer_block(text_embed, time_embed, text_rope) return text_embed - def process_visual_transformer_blocks(self, visual_embed, text_embed, time_embed, visual_rope, sparse_params): + def process_visual_transformer_blocks( + self, visual_embed, text_embed, time_embed, visual_rope, sparse_params + ): """Process visual through transformer blocks""" for visual_transformer_block in self.visual_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: - visual_embed = self._gradient_checkpointing_func(visual_transformer_block, visual_embed, text_embed, time_embed, - visual_rope, sparse_params) + visual_embed = self._gradient_checkpointing_func( + visual_transformer_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + sparse_params, + ) else: - visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed, - visual_rope, sparse_params) + visual_embed = visual_transformer_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) return visual_embed - def prepare_output(self, visual_embed, visual_shape, to_fractal, text_embed, time_embed): + def prepare_output( + self, visual_embed, visual_shape, to_fractal, text_embed, time_embed + ): """Prepare the final output""" - visual_embed = fractal_unflatten(visual_embed, visual_shape, 
block_mask=to_fractal) + visual_embed = fractal_unflatten( + visual_embed, visual_shape, block_mask=to_fractal + ) x = self.out_layer(visual_embed, text_embed, time_embed) return x @@ -846,25 +871,34 @@ def forward( text_embed = encoder_hidden_states time = timestep pooled_text_embed = pooled_projections - + # Prepare text embeddings and related components text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings( - text_embed, time, pooled_text_embed, x, text_rope_pos) + text_embed, time, pooled_text_embed, x, text_rope_pos + ) # Process text through transformer blocks - text_embed = self.process_text_transformer_blocks(text_embed, time_embed, text_rope) + text_embed = self.process_text_transformer_blocks( + text_embed, time_embed, text_rope + ) # Prepare visual embeddings and related components - visual_embed, visual_shape, to_fractal, visual_rope = self.prepare_visual_embeddings( - visual_embed, visual_rope_pos, scale_factor, sparse_params) + visual_embed, visual_shape, to_fractal, visual_rope = ( + self.prepare_visual_embeddings( + visual_embed, visual_rope_pos, scale_factor, sparse_params + ) + ) # Process visual through transformer blocks visual_embed = self.process_visual_transformer_blocks( - visual_embed, text_embed, time_embed, visual_rope, sparse_params) - + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + # Prepare final output - x = self.prepare_output(visual_embed, visual_shape, to_fractal, text_embed, time_embed) - + x = self.prepare_output( + visual_embed, visual_shape, to_fractal, text_embed, time_embed + ) + if not return_dict: return x From b615d5cb131243e20cd40453fd6ceb874a092b25 Mon Sep 17 00:00:00 2001 From: leffff Date: Wed, 15 Oct 2025 18:09:23 +0000 Subject: [PATCH 19/70] add 10 second models support --- src/diffusers/models/transformers/transformer_kandinsky.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index ac2fe58d60b4..8d2bae11cbfa 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -361,7 +361,8 @@ def __call__( class Kandinsky5NablaAttentionProcessor(nn.Module): """Custom attention processor for Nabla attention""" - + + @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) def __call__( self, attn, From 588c12ab98d67be2c4dd8234877b3c4b16cac965 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:38:02 +0300 Subject: [PATCH 20/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3a8628a1b339..3d0d68cbe93b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -303,7 +303,6 @@ def _encode_prompt_qwen( padding=True, ).to(device) - with torch.no_grad(): embeds = self.text_encoder( input_ids=inputs["input_ids"], return_dict=True, From 327ab84d1923518ecc5314831254cfd70faf99e1 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 06:50:57 +0000 Subject: [PATCH 21/70] remove no_grad and simplified prompt paddings --- .../kandinsky5/pipeline_kandinsky.py | 20 
++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3d0d68cbe93b..d4470a43d578 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -17,6 +17,7 @@ import regex as re import torch +from torch.nn import functional as F from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer from ...callbacks import MultiPipelineCallbacks, PipelineCallback @@ -303,17 +304,19 @@ def _encode_prompt_qwen( padding=True, ).to(device) - embeds = self.text_encoder( - input_ids=inputs["input_ids"], - return_dict=True, - output_hidden_states=True, - )["hidden_states"][-1][:, crop_start:] - + embeds = self.text_encoder( + input_ids=inputs["input_ids"], + return_dict=True, + output_hidden_states=True, + )["hidden_states"][-1][:, crop_start:] + batch_size = len(prompt) attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) - cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + # cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) + embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) return embeds.to(dtype), cu_seqlens @@ -354,8 +357,7 @@ def _encode_prompt_clip( return_tensors="pt", ).to(device) - with torch.no_grad(): - pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] + pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] # duplicate for each generation per prompt batch_size = len(prompt) From 9b06afba6b446352b9249a7f632af388174dd6ba Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:54:00 +0300 Subject: [PATCH 22/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3d0d68cbe93b..58ba3270a5f3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -314,7 +314,7 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) - embeds = torch.cat([embeds[i].unsqueeze(dim=0).repeat(num_videos_per_prompt, 1, 1) for i in range(batch_size)], dim=0) + embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) return embeds.to(dtype), cu_seqlens From 28458d0caf929b90bc36df7f7004dd00fa607517 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 09:57:56 +0300 Subject: [PATCH 23/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py 
b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 58ba3270a5f3..850795ada162 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -313,7 +313,7 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, crop_start:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) - cu_seqlens = torch.cat([torch.zeros_like(cu_seqlens)[:1], cu_seqlens]).to(dtype=torch.int32) + cu_seqlens =F.pad(cu_seqlens, (1, 0), value=0)).to(dtype=torch.int32) embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) return embeds.to(dtype), cu_seqlens From cd3cc6156ea949e0a620b893660ad96933691f77 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 07:14:47 +0000 Subject: [PATCH 24/70] moved template to __init__ --- .../kandinsky5/pipeline_kandinsky.py | 40 +++++++++---------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 6ebedd04e830..bdf7e41df919 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -152,6 +152,16 @@ def __init__( tokenizer_2=tokenizer_2, scheduler=scheduler, ) + + self.prompt_template = "\n".join(["<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>"]) + self.prompt_template_encode_start_idx = 129 self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio @@ -276,29 +286,14 @@ def _encode_prompt_qwen( """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(p) for p in prompt] - # Kandinsky specific prompt template for detailed video description - prompt_template = "\n".join([ - "<|im_start|>system\nYou are a promt engineer. 
Describe the video in detail.", - "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", - "Describe the location of the video, main characters or objects and their action.", - "Describe the dynamism of the video and presented actions.", - "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", - "Describe the visual effects, postprocessing and transitions if they are presented in the video.", - "Pay attention to the order of key actions shown in the scene.<|im_end|>", - "<|im_start|>user\n{}<|im_end|>", - ]) - crop_start = 129 # Position to start cropping from (system prompt tokens) - - full_texts = [prompt_template.format(p) for p in prompt] + full_texts = [self.prompt_template.format(p) for p in prompt] inputs = self.tokenizer( text=full_texts, images=None, videos=None, - max_length=max_sequence_length + crop_start, + max_length=max_sequence_length + self.prompt_template_encode_start_idx, truncation=True, return_tensors="pt", padding=True, @@ -308,11 +303,11 @@ def _encode_prompt_qwen( input_ids=inputs["input_ids"], return_dict=True, output_hidden_states=True, - )["hidden_states"][-1][:, crop_start:] + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:] batch_size = len(prompt) - attention_mask = inputs["attention_mask"][:, crop_start:] + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) @@ -343,8 +338,6 @@ def _encode_prompt_clip( """ device = device or self._execution_device dtype = dtype or self.text_encoder_2.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(p) for p in prompt] inputs = self.tokenizer_2( prompt, @@ -357,7 +350,6 @@ def _encode_prompt_clip( pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] - # duplicate for each generation per prompt batch_size = len(prompt) pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1) pooled_embed = pooled_embed.view(batch_size * num_videos_per_prompt, -1) @@ -421,6 +413,8 @@ def encode_prompt( batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] if prompt_embeds is None: + prompt = [prompt_clean(p) for p in prompt] + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( prompt=prompt, device=device, @@ -452,6 +446,8 @@ def encode_prompt( f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." 
) + + negative_prompt = [prompt_clean(p) for p in negative_prompt] negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen( prompt=negative_prompt, From 4450265bf76ee29ae2cbd7371d1237b1b4db24cf Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:19:26 +0300 Subject: [PATCH 25/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index bdf7e41df919..ff674b10ec1b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -260,7 +260,7 @@ def get_sparse_params(self, sample, device): return sparse_params - def _encode_prompt_qwen( + def _get_qwen_prompt_embeds( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, From b9a3be2a152e0135ef0f0739e9aa62938a7d8dec Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:19:45 +0300 Subject: [PATCH 26/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index ff674b10ec1b..3e61ae0bf2c6 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -314,7 +314,7 @@ def _get_qwen_prompt_embeds( return embeds.to(dtype), cu_seqlens - def _encode_prompt_clip( + def _get_clip_prompt_embeds( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, From 78a23b9ddefa4199c1218b0ee0330785b6d5f43e Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Thu, 16 Oct 2025 10:34:59 +0300 Subject: [PATCH 27/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 8d2bae11cbfa..b8723bfe86ea 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -335,8 +335,6 @@ def __init__(self, time_dim, model_dim, num_params): super().__init__() self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, num_params * model_dim) - self.out_layer.weight.data.zero_() - self.out_layer.bias.data.zero_() @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x): From 56b90b10ef1fe17d7aae3cdbb65025084177fc27 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 07:35:17 +0000 Subject: [PATCH 28/70] moved sdps inside processor --- .../models/transformers/transformer_kandinsky.py | 15 ++++++--------- .../pipelines/kandinsky5/pipeline_kandinsky.py | 4 ++-- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 8d2bae11cbfa..680b456df3f7 100644 --- 
a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -174,14 +174,6 @@ def nablaT_v2( ) -def sdpa(q, k, v): - query = q.transpose(1, 2).contiguous() - key = k.transpose(1, 2).contiguous() - value = v.transpose(1, 2).contiguous() - out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous() - return out - - @torch.autocast(device_type="cuda", dtype=torch.float32) def apply_scale_shift_norm(norm, x, scale, shift): return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) @@ -355,7 +347,12 @@ def __call__( **kwargs, ): # Process attention with the given query, key, value tensors - out = sdpa(q=query, k=key, v=value).flatten(-2, -1) + + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1) + return out diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3e61ae0bf2c6..bdf7e41df919 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -260,7 +260,7 @@ def get_sparse_params(self, sample, device): return sparse_params - def _get_qwen_prompt_embeds( + def _encode_prompt_qwen( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, @@ -314,7 +314,7 @@ def _get_qwen_prompt_embeds( return embeds.to(dtype), cu_seqlens - def _get_clip_prompt_embeds( + def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, From 31a1474378a0ae3fe22bc626f7fe274c99ed30fd Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 08:46:34 +0000 Subject: [PATCH 29/70] remove oneline function --- .../transformers/transformer_kandinsky.py | 91 ++++++++++++------- 1 file changed, 59 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index febe6cff7ae7..bed1938ae34d 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -174,16 +174,6 @@ def nablaT_v2( ) -@torch.autocast(device_type="cuda", dtype=torch.float32) -def apply_scale_shift_norm(norm, x, scale, shift): - return (norm(x) * (scale + 1.0) + shift).to(torch.bfloat16) - - -@torch.autocast(device_type="cuda", dtype=torch.float32) -def apply_gate_sum(x, out, gate): - return (x + gate * out).to(torch.bfloat16) - - @torch.autocast(device_type="cuda", enabled=False) def apply_rotary(x, rope): x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) @@ -327,6 +317,8 @@ def __init__(self, time_dim, model_dim, num_params): super().__init__() self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, num_params * model_dim) + self.out_layer.weight.data.zero_() + self.out_layer.bias.data.zero_() @torch.autocast(device_type="cuda", dtype=torch.float32) def forward(self, x): @@ -585,12 +577,9 @@ def forward(self, visual_embed, text_embed, time_embed): shift, scale = torch.chunk( self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 ) - visual_embed = apply_scale_shift_norm( - self.norm, - visual_embed, - scale[:, None, None], - shift[:, None, None], - ).type_as(visual_embed) + + visual_embed = (self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + 
shift.float()[:, None, None]).type_as(visual_embed) + x = self.out_layer(visual_embed) batch_size, duration, height, width, _ = x.shape @@ -629,17 +618,59 @@ def forward(self, x, time_embed, rope): self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) + out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) out = self.self_attention(out, rope) - x = apply_gate_sum(x, out, gate) + x = (x.float() + gate.float() * out.float()).type_as(x) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) + out = (self.feed_forward_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) out = self.feed_forward(out) - x = apply_gate_sum(x, out, gate) + x = (x.float() + gate.float() * out.float()).type_as(x) + return x +# class Kandinsky5TransformerDecoderBlock(nn.Module): +# def __init__(self, model_dim, time_dim, ff_dim, head_dim): +# super().__init__() +# self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) + +# self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) +# self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) + +# self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) +# self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) + +# self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) +# self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) + +# def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): +# self_attn_params, cross_attn_params, ff_params = torch.chunk( +# self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 +# ) +# shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) +# visual_out = apply_scale_shift_norm( +# self.self_attention_norm, visual_embed, scale, shift +# ) +# visual_out = self.self_attention(visual_out, rope, sparse_params) +# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + +# shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) +# visual_out = apply_scale_shift_norm( +# self.cross_attention_norm, visual_embed, scale, shift +# ) +# visual_out = self.cross_attention(visual_out, text_embed) +# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + +# shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) +# visual_out = apply_scale_shift_norm( +# self.feed_forward_norm, visual_embed, scale, shift +# ) +# visual_out = self.feed_forward(visual_out) +# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) +# return visual_embed + + class Kandinsky5TransformerDecoderBlock(nn.Module): def __init__(self, model_dim, time_dim, ff_dim, head_dim): super().__init__() @@ -658,26 +689,22 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): self_attn_params, cross_attn_params, ff_params = torch.chunk( self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 ) + shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm( - self.self_attention_norm, visual_embed, scale, shift - ) + visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) visual_out = self.self_attention(visual_out, rope, sparse_params) - visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + 
visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = apply_scale_shift_norm( - self.cross_attention_norm, visual_embed, scale, shift - ) + visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) visual_out = self.cross_attention(visual_out, text_embed) - visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = apply_scale_shift_norm( - self.feed_forward_norm, visual_embed, scale, shift - ) + visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) visual_out = self.feed_forward(visual_out) - visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) + return visual_embed From 894aa98a2753dfc448f4398cf9a4fd256f763a61 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 09:17:39 +0000 Subject: [PATCH 30/70] remove reset_dtype methods --- .../transformers/transformer_kandinsky.py | 20 +++---------------- .../kandinsky5/pipeline_kandinsky.py | 5 ----- 2 files changed, 3 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index bed1938ae34d..8d3b4fac513e 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -189,7 +189,8 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): self.max_period = max_period self.register_buffer( "freqs", get_freqs(model_dim // 2, max_period), persistent=False - ) + ) + self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() self.out_layer = nn.Linear(time_dim, time_dim, bias=True) @@ -199,10 +200,7 @@ def forward(self, time): args = torch.outer(time, self.freqs.to(device=time.device)) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) - return time_embed - - def reset_dtype(self): - self.freqs = get_freqs(self.model_dim // 2, self.max_period) + return time_embed class Kandinsky5TextEmbeddings(nn.Module): @@ -260,11 +258,6 @@ def forward(self, pos): rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - def reset_dtype(self): - freq = get_freqs(self.dim // 2, self.max_period).to(self.args.device) - pos = torch.arange(self.max_pos, dtype=freq.dtype, device=freq.device) - self.args = torch.outer(pos, freq) - class Kandinsky5RoPE3D(nn.Module): def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): @@ -305,12 +298,6 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): rope = rope.view(*rope.shape[:-1], 2, 2) return rope.unsqueeze(-4) - def reset_dtype(self): - for i, (axes_dim, ax_max_pos) in enumerate(zip(self.axes_dims, self.max_pos)): - freq = get_freqs(axes_dim // 2, self.max_period).to(self.args_0.device) - pos = torch.arange(ax_max_pos, dtype=freq.dtype, device=freq.device) - setattr(self, f"args_{i}", torch.outer(pos, freq)) - class Kandinsky5Modulation(nn.Module): def __init__(self, time_dim, model_dim, 
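# Illustrative sketch, not part of the patch: Kandinsky5TimeEmbeddings builds a classic
# sinusoidal timestep embedding -- log-spaced frequencies, an outer product with the
# timestep, then cos/sin halves -- before a small SiLU MLP maps it to time_dim. The
# frequency schedule below is the usual exp(-log(max_period) * i / dim) form and is an
# assumption; the exact schedule is whatever get_freqs in the file computes.
import math
import torch

def freqs_sketch(dim: int, max_period: float = 10000.0) -> torch.Tensor:
    return torch.exp(-math.log(max_period) * torch.arange(dim, dtype=torch.float32) / dim)

def sinusoidal_time_embed(time: torch.Tensor, model_dim: int) -> torch.Tensor:
    freqs = freqs_sketch(model_dim // 2).to(time.device)
    args = torch.outer(time.float(), freqs)                        # (batch, model_dim // 2)
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)   # (batch, model_dim)

print(sinusoidal_time_embed(torch.tensor([0.0, 500.0, 999.0]), 64).shape)  # torch.Size([3, 64])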
num_params): @@ -337,7 +324,6 @@ def __call__( **kwargs, ): # Process attention with the given query, key, value tensors - query = query.transpose(1, 2).contiguous() key = key.transpose(1, 2).contiguous() value = value.transpose(1, 2).contiguous() diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index bdf7e41df919..b1f7924e9b9f 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -695,11 +695,6 @@ def __call__( if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs - # 0. Reset embeddings dtype - self.transformer.time_embeddings.reset_dtype() - self.transformer.text_rope_embeddings.reset_dtype() - self.transformer.visual_rope_embeddings.reset_dtype() - # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, From c8be08149e80ae22e7a7d3b4a1f2413a9f149690 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 09:31:12 +0000 Subject: [PATCH 31/70] Transformer: move all methods to forward --- .../transformers/transformer_kandinsky.py | 185 +++++------------- 1 file changed, 47 insertions(+), 138 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 8d3b4fac513e..45e4238cfb51 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -616,47 +616,6 @@ def forward(self, x, time_embed, rope): return x -# class Kandinsky5TransformerDecoderBlock(nn.Module): -# def __init__(self, model_dim, time_dim, ff_dim, head_dim): -# super().__init__() -# self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) - -# self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) -# self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) - -# self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) -# self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) - -# self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) -# self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) - -# def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): -# self_attn_params, cross_attn_params, ff_params = torch.chunk( -# self.visual_modulation(time_embed).unsqueeze(dim=1), 3, dim=-1 -# ) -# shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) -# visual_out = apply_scale_shift_norm( -# self.self_attention_norm, visual_embed, scale, shift -# ) -# visual_out = self.self_attention(visual_out, rope, sparse_params) -# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) - -# shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) -# visual_out = apply_scale_shift_norm( -# self.cross_attention_norm, visual_embed, scale, shift -# ) -# visual_out = self.cross_attention(visual_out, text_embed) -# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) - -# shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) -# visual_out = apply_scale_shift_norm( -# self.feed_forward_norm, visual_embed, scale, shift -# ) -# visual_out = self.feed_forward(visual_out) -# visual_embed = apply_gate_sum(visual_embed, visual_out, gate) -# return visual_embed - - class Kandinsky5TransformerDecoderBlock(nn.Module): def __init__(self, model_dim, 
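# Illustrative sketch, not part of the patch: the RoPE modules return stacks of 2x2
# rotation matrices (rope.view(..., 2, 2) above), and apply_rotary groups the last
# feature dimension of q/k into pairs and rotates each pair in float32.
import torch

def apply_rotary_sketch(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
    # x: (..., heads, head_dim), rope broadcastable to (..., 1, head_dim // 2, 2, 2)
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32)
    x_out = (rope * x_).sum(dim=-1)        # per-pair matrix-vector product
    return x_out.reshape(*x.shape).type_as(x)

theta = torch.tensor([0.5, 1.0])           # one angle per channel pair (head_dim=4)
cos, sin = torch.cos(theta), torch.sin(theta)
rope = torch.stack([torch.stack([cos, -sin], -1), torch.stack([sin, cos], -1)], -2)  # (2, 2, 2)
rope = rope.unsqueeze(0)                   # broadcast over the head axis
x = torch.randn(1, 4)                      # (heads=1, head_dim=4)
print(apply_rotary_sketch(x, rope).shape)  # torch.Size([1, 4])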
time_dim, ff_dim, head_dim): super().__init__() @@ -724,16 +683,16 @@ def __init__( axes_dims=(16, 24, 24), visual_cond=False, attention_type: str = "regular", - attention_causal: bool = None, # Default for Nabla: false - attention_local: bool = None, # Default for Nabla: false - attention_glob: bool = None, # Default for Nabla: false - attention_window: int = None, # Default for Nabla: 3 - attention_P: float = None, # Default for Nabla: 0.9 - attention_wT: int = None, # Default for Nabla: 11 - attention_wW: int = None, # Default for Nabla: 3 - attention_wH: int = None, # Default for Nabla: 3 - attention_add_sta: bool = None, # Default for Nabla: true - attention_method: str = None, # Default for Nabla: "topcdf" + attention_causal: bool = None, + attention_local: bool = None, + attention_glob: bool = None, + attention_window: int = None, + attention_P: float = None, + attention_wT: int = None, + attention_wW: int = None, + attention_wH: int = None, + attention_add_sta: bool = None, + attention_method: str = None, ): super().__init__() @@ -779,73 +738,6 @@ def __init__( ) self.gradient_checkpointing = False - def prepare_text_embeddings( - self, text_embed, time, pooled_text_embed, x, text_rope_pos - ): - """Prepare text embeddings and related components""" - text_embed = self.text_embeddings(text_embed) - time_embed = self.time_embeddings(time) - time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) - visual_embed = self.visual_embeddings(x) - text_rope = self.text_rope_embeddings(text_rope_pos) - text_rope = text_rope.unsqueeze(dim=0) - return text_embed, time_embed, text_rope, visual_embed - - def prepare_visual_embeddings( - self, visual_embed, visual_rope_pos, scale_factor, sparse_params - ): - """Prepare visual embeddings and related components""" - visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings( - visual_shape, visual_rope_pos, scale_factor - ) - to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False - visual_embed, visual_rope = fractal_flatten( - visual_embed, visual_rope, visual_shape, block_mask=to_fractal - ) - return visual_embed, visual_shape, to_fractal, visual_rope - - def process_text_transformer_blocks(self, text_embed, time_embed, text_rope): - """Process text through transformer blocks""" - for text_transformer_block in self.text_transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - text_embed = self._gradient_checkpointing_func( - text_transformer_block, text_embed, time_embed, text_rope - ) - else: - text_embed = text_transformer_block(text_embed, time_embed, text_rope) - return text_embed - - def process_visual_transformer_blocks( - self, visual_embed, text_embed, time_embed, visual_rope, sparse_params - ): - """Process visual through transformer blocks""" - for visual_transformer_block in self.visual_transformer_blocks: - if torch.is_grad_enabled() and self.gradient_checkpointing: - visual_embed = self._gradient_checkpointing_func( - visual_transformer_block, - visual_embed, - text_embed, - time_embed, - visual_rope, - sparse_params, - ) - else: - visual_embed = visual_transformer_block( - visual_embed, text_embed, time_embed, visual_rope, sparse_params - ) - return visual_embed - - def prepare_output( - self, visual_embed, visual_shape, to_fractal, text_embed, time_embed - ): - """Prepare the final output""" - visual_embed = fractal_unflatten( - visual_embed, visual_shape, block_mask=to_fractal - ) - x = self.out_layer(visual_embed, text_embed, time_embed) 
- return x - def forward( self, hidden_states: torch.FloatTensor, # x @@ -881,32 +773,49 @@ def forward( time = timestep pooled_text_embed = pooled_projections - # Prepare text embeddings and related components - text_embed, time_embed, text_rope, visual_embed = self.prepare_text_embeddings( - text_embed, time, pooled_text_embed, x, text_rope_pos - ) + text_embed = self.text_embeddings(text_embed) + time_embed = self.time_embeddings(time) + time_embed = time_embed + self.pooled_text_embeddings(pooled_text_embed) + visual_embed = self.visual_embeddings(x) + text_rope = self.text_rope_embeddings(text_rope_pos) + text_rope = text_rope.unsqueeze(dim=0) - # Process text through transformer blocks - text_embed = self.process_text_transformer_blocks( - text_embed, time_embed, text_rope - ) + for text_transformer_block in self.text_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + text_embed = self._gradient_checkpointing_func( + text_transformer_block, text_embed, time_embed, text_rope + ) + else: + text_embed = text_transformer_block(text_embed, time_embed, text_rope) - # Prepare visual embeddings and related components - visual_embed, visual_shape, to_fractal, visual_rope = ( - self.prepare_visual_embeddings( - visual_embed, visual_rope_pos, scale_factor, sparse_params - ) + visual_shape = visual_embed.shape[:-1] + visual_rope = self.visual_rope_embeddings( + visual_shape, visual_rope_pos, scale_factor ) - - # Process visual through transformer blocks - visual_embed = self.process_visual_transformer_blocks( - visual_embed, text_embed, time_embed, visual_rope, sparse_params + to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False + visual_embed, visual_rope = fractal_flatten( + visual_embed, visual_rope, visual_shape, block_mask=to_fractal ) - # Prepare final output - x = self.prepare_output( - visual_embed, visual_shape, to_fractal, text_embed, time_embed + for visual_transformer_block in self.visual_transformer_blocks: + if torch.is_grad_enabled() and self.gradient_checkpointing: + visual_embed = self._gradient_checkpointing_func( + visual_transformer_block, + visual_embed, + text_embed, + time_embed, + visual_rope, + sparse_params, + ) + else: + visual_embed = visual_transformer_block( + visual_embed, text_embed, time_embed, visual_rope, sparse_params + ) + + visual_embed = fractal_unflatten( + visual_embed, visual_shape, block_mask=to_fractal ) + x = self.out_layer(visual_embed, text_embed, time_embed) if not return_dict: return x From 3ffdf7f113e442c68d65da5033e31a195f7a1be7 Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 10:32:47 +0000 Subject: [PATCH 32/70] separated prompt encoding --- .../kandinsky5/pipeline_kandinsky.py | 153 +++++++----------- 1 file changed, 56 insertions(+), 97 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index b1f7924e9b9f..2ff0c1d45d81 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -359,124 +359,64 @@ def _encode_prompt_clip( def encode_prompt( self, prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - do_classifier_free_guidance: bool = True, num_videos_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: 
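# Illustrative sketch, not part of the patch: the forward pass above now iterates the
# text and visual blocks inline, using the standard diffusers pattern of routing each
# block through gradient checkpointing when it is enabled and gradients are being
# tracked. Minimal stand-alone version of that dispatch:
import torch
import torch.nn as nn
import torch.utils.checkpoint

class TinyBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.proj(x)

class TinyStack(nn.Module):
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList(TinyBlock(dim) for _ in range(depth))
        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # recompute activations during backward to save memory
                x = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

model = TinyStack(32, 4)
model.gradient_checkpointing = True
model(torch.randn(2, 32, requires_grad=True)).sum().backward()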
Optional[torch.dtype] = None, ): r""" - Encodes the prompt into text encoder hidden states. + Encodes a single prompt (positive or negative) into text encoder hidden states. This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text representations for video generation. - + Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): - Whether to use classifier free guidance or not. + prompt (`str` or `List[str]`): + Prompt to be encoded. num_videos_per_prompt (`int`, *optional*, defaults to 1): - Number of videos that should be generated per prompt. - prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.Tensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. + Number of videos to generate per prompt. max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length for text encoding. - device: (`torch.device`, *optional*): - torch device - dtype: (`torch.dtype`, *optional*): - torch dtype - + device (`torch.device`, *optional*): + Torch device. + dtype (`torch.dtype`, *optional*): + Torch dtype. 
+ Returns: - Tuple: Contains prompt embeddings, negative prompt embeddings, and sequence length information + Tuple[Dict[str, torch.Tensor], torch.Tensor]: + - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings """ device = device or self._execution_device + dtype = dtype or self.text_encoder.dtype - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - prompt = [prompt] - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt = [prompt_clean(p) for p in prompt] - - prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( - prompt=prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - dtype=dtype, - ) - prompt_embeds_clip = self._encode_prompt_clip( - prompt=prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - dtype=dtype, - ) - else: - prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = prompt_embeds + batch_size = len(prompt) - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" - negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + prompt = [prompt_clean(p) for p in prompt] - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." 
- ) - - negative_prompt = [prompt_clean(p) for p in negative_prompt] + # Encode with Qwen2.5-VL + prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + dtype=dtype, + ) - negative_prompt_embeds_qwen, negative_cu_seqlens = self._encode_prompt_qwen( - prompt=negative_prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - dtype=dtype, - ) - negative_prompt_embeds_clip = self._encode_prompt_clip( - prompt=negative_prompt, - device=device, - num_videos_per_prompt=num_videos_per_prompt, - dtype=dtype, - ) - else: - negative_prompt_embeds_qwen = None - negative_prompt_embeds_clip = None - negative_cu_seqlens = None + # Encode with CLIP + prompt_embeds_clip = self._encode_prompt_clip( + prompt=prompt, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + dtype=dtype, + ) prompt_embeds_dict = { "text_embeds": prompt_embeds_qwen, "pooled_embed": prompt_embeds_clip, } - negative_prompt_embeds_dict = { - "text_embeds": negative_prompt_embeds_qwen, - "pooled_embed": negative_prompt_embeds_clip, - } if do_classifier_free_guidance else None - return prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens + return prompt_embeds_dict, prompt_cu_seqlens def check_inputs( self, @@ -722,24 +662,43 @@ def __call__( # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 + prompt = [prompt] elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] # 3. Encode input prompt - prompt_embeds_dict, negative_prompt_embeds_dict, prompt_cu_seqlens, negative_cu_seqlens = self.encode_prompt( + prompt_embeds_dict, prompt_cu_seqlens = self.encode_prompt( prompt=prompt, - negative_prompt=negative_prompt, - do_classifier_free_guidance=self.do_classifier_free_guidance, num_videos_per_prompt=num_videos_per_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, max_sequence_length=max_sequence_length, device=device, dtype=dtype, ) + negative_prompt_embeds_dict = None + negative_cu_seqlens = None + + if self.do_classifier_free_guidance: + if negative_prompt is None: + negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt] + elif len(negative_prompt) != len(prompt): + raise ValueError( + f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." + ) + + negative_prompt_embeds_dict, negative_cu_seqlens = self.encode_prompt( + prompt=negative_prompt, + num_videos_per_prompt=num_videos_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) + # 4. 
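# Illustrative sketch, not part of the patch: the negative-prompt handling added above
# in __call__ -- fall back to a default negative prompt, broadcast a single string to
# the batch, and require matching lengths before encoding it for classifier-free
# guidance. The default string below is a placeholder, not the pipeline's.
def normalize_negative_prompt(prompt, negative_prompt, default_negative=""):
    if negative_prompt is None:
        negative_prompt = default_negative
    if isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt] * len(prompt)
    elif len(negative_prompt) != len(prompt):
        raise ValueError(
            f"`negative_prompt` must have same length as `prompt`. "
            f"Got {len(negative_prompt)} vs {len(prompt)}."
        )
    return negative_prompt

print(normalize_negative_prompt(["a cat", "a dog"], "blurry"))  # ['blurry', 'blurry']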
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps From 9f52335290e0e2076166dcc35180557527a7d5eb Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 01:47:38 +0300 Subject: [PATCH 33/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45e4238cfb51..38cc5156bc49 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -57,7 +57,6 @@ def freeze(model): return model -@torch.autocast(device_type="cuda", enabled=False) def get_freqs(dim, max_period=10000.0): freqs = torch.exp( -math.log(max_period) From cc46e2d2defbb922b7e0ef8e1f014e9361850b5c Mon Sep 17 00:00:00 2001 From: leffff Date: Thu, 16 Oct 2025 22:48:09 +0000 Subject: [PATCH 34/70] refactoring --- src/diffusers/models/transformers/transformer_kandinsky.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45e4238cfb51..d08f2a968e15 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -186,10 +186,7 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() assert model_dim % 2 == 0 self.model_dim = model_dim - self.max_period = max_period - self.register_buffer( - "freqs", get_freqs(model_dim // 2, max_period), persistent=False - ) + self.max_period = max_period self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() From 9672c6bd6f70a28cca896025fc57e89b72117838 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 01:49:19 +0300 Subject: [PATCH 35/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 38cc5156bc49..488c44189202 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -173,7 +173,6 @@ def nablaT_v2( ) -@torch.autocast(device_type="cuda", enabled=False) def apply_rotary(x, rope): x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) x_out = (rope * x_).sum(dim=-1) From 900feba4fe196b911344c779cc9c951dfbc067ca Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 17 Oct 2025 14:38:42 +0000 Subject: [PATCH 36/70] refactoring acording to https://github.com/huggingface/diffusers/commit/acabbc0033d4b4933fc651766a4aa026db2e6dc1 --- .../transformers/transformer_kandinsky.py | 318 ++++++------------ 1 file changed, 104 insertions(+), 214 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index f88429fa1714..7a4f85c744ec 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ 
b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -19,10 +19,6 @@ import torch.nn as nn import torch.nn.functional as F from torch import BoolTensor, IntTensor, Tensor, nn -from torch.nn.attention.flex_attention import ( - BlockMask, - flex_attention, -) from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin @@ -34,7 +30,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, FeedForward +from ..attention import AttentionMixin, FeedForward, AttentionModuleMixin from ..cache_utils import CacheMixin from ..embeddings import ( TimestepEmbedding, @@ -43,6 +39,7 @@ from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import FP32LayerNorm +from ..attention_dispatch import dispatch_attention_fn, _CAN_USE_FLEX_ATTN logger = logging.get_logger(__name__) @@ -148,7 +145,15 @@ def nablaT_v2( k: Tensor, sta: Tensor, thr: float = 0.9, -) -> BlockMask: +): + if _CAN_USE_FLEX_ATTN: + from torch.nn.attention.flex_attention import BlockMask + else: + raise ValueError("Nabla attention is not supported with this version of PyTorch") + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + # Map estimation B, h, S, D = q.shape s1 = S // 64 @@ -173,18 +178,15 @@ def nablaT_v2( ) -def apply_rotary(x, rope): - x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) - x_out = (rope * x_).sum(dim=-1) - return x_out.reshape(*x.shape).to(torch.bfloat16) - - class Kandinsky5TimeEmbeddings(nn.Module): def __init__(self, model_dim, time_dim, max_period=10000.0): super().__init__() assert model_dim % 2 == 0 self.model_dim = model_dim - self.max_period = max_period + self.max_period = max_period + self.register_buffer( + "freqs", get_freqs(model_dim // 2, max_period), persistent=False + ) self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() @@ -307,184 +309,82 @@ def forward(self, x): return self.out_layer(self.activation(x)) -class Kandinsky5SDPAAttentionProcessor(nn.Module): - """Custom attention processor for standard SDPA attention""" - - def __call__( - self, - attn, - query, - key, - value, - **kwargs, - ): - # Process attention with the given query, key, value tensors - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 2).contiguous() - out = F.scaled_dot_product_attention(query, key, value).transpose(1, 2).contiguous().flatten(-2, -1) - - return out - - -class Kandinsky5NablaAttentionProcessor(nn.Module): - """Custom attention processor for Nabla attention""" - - @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) - def __call__( - self, - attn, - query, - key, - value, - sparse_params=None, - **kwargs, - ): - if sparse_params is None: - raise ValueError("sparse_params is required for Nabla attention") - - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 2).contiguous() - - block_mask = nablaT_v2( - query, - key, - sparse_params["sta_mask"], - thr=sparse_params["P"], - ) - out = ( - flex_attention(query, key, value, block_mask=block_mask) - .transpose(1, 2) - .contiguous() - ) - out = out.flatten(-2, -1) - return out - - -class Kandinsky5MultiheadSelfAttentionEnc(nn.Module): - def __init__(self, num_channels, head_dim): - super().__init__() - 
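# Illustrative sketch, not part of the patch: a simplified, dense stand-in for the
# block mask that nablaT_v2 builds for flex_attention -- pool q/k into 64-token blocks,
# score block pairs, keep the smallest set of blocks whose softmax mass reaches `thr`
# (a "top-CDF" selection), and OR in a static mask of always-kept blocks. The real
# implementation returns a flex-attention BlockMask; this returns a boolean map.
import torch

def block_mask_dense(q, k, sta, thr=0.9, block=64):
    # q, k: (B, H, S, D); sta: (S // block, S // block) bool of always-kept block pairs
    B, H, S, D = q.shape
    qb = q.reshape(B, H, S // block, block, D).mean(dim=3)
    kb = k.reshape(B, H, S // block, block, D).mean(dim=3)
    scores = torch.softmax(qb @ kb.transpose(-1, -2) / D**0.5, dim=-1)  # (B, H, s, s)
    sorted_scores, order = scores.sort(dim=-1, descending=True)
    keep_sorted = sorted_scores.cumsum(dim=-1) <= thr
    keep_sorted[..., 0] = True                        # always keep the strongest block
    keep = torch.zeros_like(keep_sorted).scatter(-1, order, keep_sorted)
    return keep | sta

q = k = torch.randn(1, 2, 256, 32)
sta = torch.eye(4, dtype=torch.bool)                  # e.g. keep the diagonal blocks
print(block_mask_dense(q, k, sta, thr=0.9).shape)     # torch.Size([1, 2, 4, 4])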
assert num_channels % head_dim == 0 - self.num_heads = num_channels // head_dim - - self.to_query = nn.Linear(num_channels, num_channels, bias=True) - self.to_key = nn.Linear(num_channels, num_channels, bias=True) - self.to_value = nn.Linear(num_channels, num_channels, bias=True) - self.query_norm = nn.RMSNorm(head_dim) - self.key_norm = nn.RMSNorm(head_dim) - - self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - - # Initialize attention processor - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - - def get_qkv(self, x): - query = self.to_query(x) - key = self.to_key(x) - value = self.to_value(x) +class Kandinsky5AttnProcessor: - shape = query.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*shape, self.num_heads, -1) - value = value.reshape(*shape, self.num_heads, -1) + _attention_backend = None + _parallel_config = None - return query, key, value + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - def scaled_dot_product_attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) + def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): + # query, key, value = self.get_qkv(x) + query = attn.to_query(hidden_states) - def out_l(self, x): - return self.out_layer(x) + if encoder_hidden_states is not None: + key = attn.to_key(encoder_hidden_states) + value = attn.to_value(encoder_hidden_states) - def forward(self, x, rope): - query, key, value = self.get_qkv(x) - query, key = self.norm_qk(query, key) - query = apply_rotary(query, rope).type_as(query) - key = apply_rotary(key, rope).type_as(key) + shape, cond_shape = query.shape[:-1], key.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*cond_shape, attn.num_heads, -1) + value = value.reshape(*cond_shape, attn.num_heads, -1) + + else: + key = attn.to_key(hidden_states) + value = attn.to_value(hidden_states) - out = self.scaled_dot_product_attention(query, key, value) + shape = query.shape[:-1] + query = query.reshape(*shape, attn.num_heads, -1) + key = key.reshape(*shape, attn.num_heads, -1) + value = value.reshape(*shape, attn.num_heads, -1) - out = self.out_l(out) - return out + # query, key = self.norm_qk(query, key) + query = attn.query_norm(query.float()).type_as(query) + key = attn.key_norm(key.float()).type_as(key) + def apply_rotary(x, rope): + x_ = x.reshape(*x.shape[:-1], -1, 1, 2).to(torch.float32) + x_out = (rope * x_).sum(dim=-1) + return x_out.reshape(*x.shape).to(torch.bfloat16) -class Kandinsky5MultiheadSelfAttentionDec(nn.Module): - def __init__(self, num_channels, head_dim): - super().__init__() - assert num_channels % head_dim == 0 - self.num_heads = num_channels // head_dim - - self.to_query = nn.Linear(num_channels, num_channels, bias=True) - self.to_key = nn.Linear(num_channels, num_channels, bias=True) - self.to_value = nn.Linear(num_channels, num_channels, bias=True) - self.query_norm = nn.RMSNorm(head_dim) - self.key_norm = nn.RMSNorm(head_dim) - - self.out_layer = nn.Linear(num_channels, num_channels, bias=True) - - # Initialize attention processors - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - self.nabla_processor = 
Kandinsky5NablaAttentionProcessor() - - def get_qkv(self, x): - query = self.to_query(x) - key = self.to_key(x) - value = self.to_value(x) - - shape = query.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*shape, self.num_heads, -1) - value = value.reshape(*shape, self.num_heads, -1) - - return query, key, value - - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - - def attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) - - def nabla(self, query, key, value, sparse_params=None): - # Use the processor - return self.nabla_processor( - attn=self, - query=query, - key=key, - value=value, - sparse_params=sparse_params, - **{}, - ) - - def out_l(self, x): - return self.out_layer(x) - - def forward(self, x, rope, sparse_params=None): - query, key, value = self.get_qkv(x) - query, key = self.norm_qk(query, key) - query = apply_rotary(query, rope).type_as(query) - key = apply_rotary(key, rope).type_as(key) + if rotary_emb is not None: + query = apply_rotary(query, rotary_emb).type_as(query) + key = apply_rotary(key, rotary_emb).type_as(key) if sparse_params is not None: - out = self.nabla(query, key, value, sparse_params=sparse_params) + attn_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) else: - out = self.attention(query, key, value) + attn_mask = None + + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attn_mask, + backend=self._attention_backend, + parallel_config=self._parallel_config, + ) + hidden_states = hidden_states.flatten(-2, -1) - out = self.out_l(out) - return out + attn_out = attn.out_layer(hidden_states) + return attn_out -class Kandinsky5MultiheadCrossAttention(nn.Module): - def __init__(self, num_channels, head_dim): +class Kandinsky5Attention(nn.Module, AttentionModuleMixin): + + _default_processor_cls = Kandinsky5AttnProcessor + _available_processors = [ + Kandinsky5AttnProcessor, + ] + def __init__(self, num_channels, head_dim, processor=None): super().__init__() assert num_channels % head_dim == 0 self.num_heads = num_channels // head_dim @@ -496,43 +396,33 @@ def __init__(self, num_channels, head_dim): self.key_norm = nn.RMSNorm(head_dim) self.out_layer = nn.Linear(num_channels, num_channels, bias=True) + if processor is None: + processor = self._default_processor_cls() + self.set_processor(processor) - # Initialize attention processor - self.sdpa_processor = Kandinsky5SDPAAttentionProcessor() - - def get_qkv(self, x, cond): - query = self.to_query(x) - key = self.to_key(cond) - value = self.to_value(cond) - - shape, cond_shape = query.shape[:-1], key.shape[:-1] - query = query.reshape(*shape, self.num_heads, -1) - key = key.reshape(*cond_shape, self.num_heads, -1) - value = value.reshape(*cond_shape, self.num_heads, -1) - - return query, key, value - - def norm_qk(self, q, k): - q = self.query_norm(q.float()).type_as(q) - k = self.key_norm(k.float()).type_as(k) - return q, k - - def attention(self, query, key, value): - # Use the processor - return self.sdpa_processor(attn=self, query=query, key=key, value=value, **{}) - - def out_l(self, x): - return self.out_layer(x) + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + sparse_params: Optional[torch.Tensor] = None, + rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) 
-> torch.Tensor: - def forward(self, x, cond): - query, key, value = self.get_qkv(x, cond) - query, key = self.norm_qk(query, key) + import inspect - out = self.attention(query, key, value) - out = self.out_l(out) - return out + attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) + quiet_attn_parameters = {} + unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] + if len(unused_kwargs) > 0: + logger.warning( + f"attention_processor_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." + ) + kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} + return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states, sparse_params=sparse_params, rotary_emb=rotary_emb, **kwargs) + class Kandinsky5FeedForward(nn.Module): def __init__(self, dim, ff_dim): super().__init__() @@ -589,7 +479,7 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.text_modulation = Kandinsky5Modulation(time_dim, model_dim, 6) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = Kandinsky5MultiheadSelfAttentionEnc(model_dim, head_dim) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) @@ -600,7 +490,7 @@ def forward(self, x, time_embed, rope): ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = (self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) - out = self.self_attention(out, rope) + out = self.self_attention(out, rotary_emb=rope) x = (x.float() + gate.float() * out.float()).type_as(x) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) @@ -617,10 +507,10 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.visual_modulation = Kandinsky5Modulation(time_dim, model_dim, 9) self.self_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.self_attention = Kandinsky5MultiheadSelfAttentionDec(model_dim, head_dim) + self.self_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.cross_attention_norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.cross_attention = Kandinsky5MultiheadCrossAttention(model_dim, head_dim) + self.cross_attention = Kandinsky5Attention(model_dim, head_dim, processor=Kandinsky5AttnProcessor()) self.feed_forward_norm = nn.LayerNorm(model_dim, elementwise_affine=False) self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) @@ -632,12 +522,12 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) - visual_out = self.self_attention(visual_out, rope, sparse_params) + visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) - visual_out = self.cross_attention(visual_out, text_embed) + visual_out = self.cross_attention(visual_out, 
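# Illustrative sketch, not part of the patch: the attention forward above inspects the
# processor's signature so that unexpected kwargs are logged and dropped rather than
# raising a TypeError inside the processor call. Minimal version of that filtering:
import inspect

def call_processor(processor, *args, **kwargs):
    accepted = set(inspect.signature(processor.__call__).parameters.keys())
    unused = [k for k in kwargs if k not in accepted]
    if unused:
        print(f"ignoring unexpected processor kwargs: {unused}")
    kwargs = {k: v for k, v in kwargs.items() if k in accepted}
    return processor(*args, **kwargs)

class EchoProcessor:
    def __call__(self, hidden_states, rotary_emb=None):
        return hidden_states

print(call_processor(EchoProcessor(), "x", rotary_emb=None, sparse_params=None))
# ignoring unexpected processor kwargs: ['sparse_params']
# x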
encoder_hidden_states=text_embed) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) @@ -815,4 +705,4 @@ def forward( if not return_dict: return x - return Transformer2DModelOutput(sample=x) + return Transformer2DModelOutput(sample=x) \ No newline at end of file From 226bbf8ee1c3c1ddc408aaa6664519c36c995176 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:36:09 +0300 Subject: [PATCH 37/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 7a4f85c744ec..7569b8cd8006 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -44,8 +44,6 @@ logger = logging.get_logger(__name__) -def exist(item): - return item is not None def freeze(model): From 9504fb0d63f9ddd59c01e290c9d71304981bf7f5 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:36:32 +0300 Subject: [PATCH 38/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 7569b8cd8006..d85b411caf07 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -46,10 +46,6 @@ -def freeze(model): - for p in model.parameters(): - p.requires_grad = False - return model def get_freqs(dim, max_period=10000.0): From f0eca0849b68d61b7cf98b54e4a95ec9e92157a4 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:37:35 +0300 Subject: [PATCH 39/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index d85b411caf07..03b40e78de55 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -178,9 +178,6 @@ def __init__(self, model_dim, time_dim, max_period=10000.0): assert model_dim % 2 == 0 self.model_dim = model_dim self.max_period = max_period - self.register_buffer( - "freqs", get_freqs(model_dim // 2, max_period), persistent=False - ) self.freqs = get_freqs(self.model_dim // 2, self.max_period) self.in_layer = nn.Linear(model_dim, time_dim, bias=True) self.activation = nn.SiLU() From cc74c1e46e47d2dbd518c40d636e21e20d3bfbc1 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:41:21 +0300 Subject: [PATCH 40/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 03b40e78de55..45bc4849749a 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -237,7 +237,6 @@ def __init__(self, dim, max_pos=1024, max_period=10000.0): pos = torch.arange(max_pos, dtype=freq.dtype) self.register_buffer(f"args", torch.outer(pos, freq), persistent=False) - @torch.autocast(device_type="cuda", enabled=False) def forward(self, pos): args = self.args[pos] cosine = torch.cos(args) From cb915d71adb2bcfef1a30b91774ce19542923c0a Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:41:33 +0300 Subject: [PATCH 41/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 45bc4849749a..6b9f60432503 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -258,7 +258,6 @@ def __init__(self, axes_dims, max_pos=(128, 128, 128), max_period=10000.0): pos = torch.arange(ax_max_pos, dtype=freq.dtype) self.register_buffer(f"args_{i}", torch.outer(pos, freq), persistent=False) - @torch.autocast(device_type="cuda", enabled=False) def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): batch_size, duration, height, width = shape args_t = self.args_0[pos[0]] / scale_factor[0] From 9aa3c2eb20d4e16b3c2db2caef458acaaac32fbf Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:41:56 +0300 Subject: [PATCH 42/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 6b9f60432503..490b64ffdfd1 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -614,7 +614,7 @@ def __init__( def forward( self, - hidden_states: torch.FloatTensor, # x + hidden_states: torch.Tensor, # x encoder_hidden_states: torch.FloatTensor, # text_embed timestep: Union[torch.Tensor, float, int], # time pooled_projections: torch.FloatTensor, # pooled_text_embed From feac8f095ff285bbe9bfd23989567ab27166b2ad Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:45:30 +0300 Subject: [PATCH 43/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 490b64ffdfd1..2c12b0e90b65 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -615,7 +615,7 @@ def __init__( def forward( self, hidden_states: torch.Tensor, # x - encoder_hidden_states: 
torch.FloatTensor, # text_embed + encoder_hidden_states: torch.Tensor, # text_embed timestep: Union[torch.Tensor, float, int], # time pooled_projections: torch.FloatTensor, # pooled_text_embed visual_rope_pos: Tuple[int, int, int], From d3b959750bc3e39e44bcd6910504a9e1b23260bd Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:46:34 +0300 Subject: [PATCH 44/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 2c12b0e90b65..e674a8ba1f2a 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -616,7 +616,7 @@ def forward( self, hidden_states: torch.Tensor, # x encoder_hidden_states: torch.Tensor, # text_embed - timestep: Union[torch.Tensor, float, int], # time + timestep: torch.Tensor, # time pooled_projections: torch.FloatTensor, # pooled_text_embed visual_rope_pos: Tuple[int, int, int], text_rope_pos: torch.LongTensor, From 693b9aa9c2880d9d570d44996bcfcafd9be9cf01 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:47:03 +0300 Subject: [PATCH 45/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index e674a8ba1f2a..ad39a9bed63f 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -617,7 +617,7 @@ def forward( hidden_states: torch.Tensor, # x encoder_hidden_states: torch.Tensor, # text_embed timestep: torch.Tensor, # time - pooled_projections: torch.FloatTensor, # pooled_text_embed + pooled_projections: torch.Tensor, # pooled_text_embed visual_rope_pos: Tuple[int, int, int], text_rope_pos: torch.LongTensor, scale_factor: Tuple[float, float, float] = (1.0, 1.0, 1.0), From e2ed6ec961d8d2a251d71de5345a5012fd302a17 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:47:57 +0300 Subject: [PATCH 46/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 2ff0c1d45d81..5369bc579b67 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -416,7 +416,7 @@ def encode_prompt( "pooled_embed": prompt_embeds_clip, } - return prompt_embeds_dict, prompt_cu_seqlens + return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens def check_inputs( self, From 2925447e3339ca3477144f3814106e87952a0c4a Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:48:35 +0300 Subject: [PATCH 47/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py 
Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 5369bc579b67..988cce6b5e79 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -411,10 +411,6 @@ def encode_prompt( dtype=dtype, ) - prompt_embeds_dict = { - "text_embeds": prompt_embeds_qwen, - "pooled_embed": prompt_embeds_clip, - } return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens From b02ad82513971dfe14c57b9782d0218e9364df97 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:48:55 +0300 Subject: [PATCH 48/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 988cce6b5e79..c1c510dc12c6 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -398,7 +398,6 @@ def encode_prompt( prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( prompt=prompt, device=device, - num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=max_sequence_length, dtype=dtype, ) From dc67c2bb4bb1367c7dc3fd4a9cdc93b452e531e5 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:49:19 +0300 Subject: [PATCH 49/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index c1c510dc12c6..420748873cf3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -406,7 +406,6 @@ def encode_prompt( prompt_embeds_clip = self._encode_prompt_clip( prompt=prompt, device=device, - num_videos_per_prompt=num_videos_per_prompt, dtype=dtype, ) From d0fc426a744172595f194d01687ca1bc54300bd1 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:49:48 +0300 Subject: [PATCH 50/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 420748873cf3..f879f9dc5d09 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -305,7 +305,6 @@ def _encode_prompt_qwen( output_hidden_states=True, )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:] - batch_size = len(prompt) attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) From 222ba4ca4dd2093696937252e21f11c6b04410a6 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 
22:50:06 +0300 Subject: [PATCH 51/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index f879f9dc5d09..1e5a5ac58fa3 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -264,7 +264,6 @@ def _encode_prompt_qwen( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, - num_videos_per_prompt: int = 1, max_sequence_length: int = 256, dtype: Optional[torch.dtype] = None, ): From 3a495058b05dacc7bc2f4eb8982430e4864e8628 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:50:48 +0300 Subject: [PATCH 52/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 1e5a5ac58fa3..6adc611bdc11 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -308,7 +308,6 @@ def _encode_prompt_qwen( attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) - embeds = embeds.repeat_interleave(num_videos_per_prompt, dim=0) return embeds.to(dtype), cu_seqlens From 1e12017008ea693823d08fd9b54a1d54b7f1db56 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:51:08 +0300 Subject: [PATCH 53/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 6adc611bdc11..b700df0e485e 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -315,7 +315,6 @@ def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, - num_videos_per_prompt: int = 1, dtype: Optional[torch.dtype] = None, ): """ From 5a300798efeee38600c9101882144e3d8ff53f16 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:51:40 +0300 Subject: [PATCH 54/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index b700df0e485e..4b5c19a9e3cf 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -346,9 +346,6 @@ def _encode_prompt_clip( pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] - batch_size = len(prompt) - pooled_embed = pooled_embed.repeat(1, num_videos_per_prompt, 1) - pooled_embed = 
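# Illustrative sketch, not part of the patch: how the cumulative sequence lengths used
# alongside the Qwen embeddings are derived from the tokenizer attention mask (after the
# chat-template prefix has been sliced off) -- per-prompt lengths packed into one tensor.
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])          # two prompts, lengths 3 and 5
cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)   # tensor([3, 8])
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(torch.int32)
print(cu_seqlens)  # tensor([0, 3, 8], dtype=torch.int32)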
pooled_embed.view(batch_size * num_videos_per_prompt, -1) return pooled_embed.to(dtype) From 0d96ecfdd53f209bedd29b1df6e661eb03cd8dea Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:51:57 +0300 Subject: [PATCH 55/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 4b5c19a9e3cf..4c880e079a55 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -401,7 +401,11 @@ def encode_prompt( device=device, dtype=dtype, ) - + prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) + + prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) + prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens From aadafc14d20117db514fd70ddadc9d4fb5c5bf05 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:52:15 +0300 Subject: [PATCH 56/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 4c880e079a55..67a49ecaa5e6 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -668,8 +668,6 @@ def __call__( dtype=dtype, ) - negative_prompt_embeds_dict = None - negative_cu_seqlens = None if self.do_classifier_free_guidance: if negative_prompt is None: From 54cf03c7139c26670edd15a781c5e98f6c56ad88 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:52:29 +0300 Subject: [PATCH 57/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 67a49ecaa5e6..a7b8bd117c1a 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -563,7 +563,7 @@ def __call__( num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_qwen: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, From 22c503fb84b60b2c6eed777c3b4f23ee82ea5936 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:52:55 +0300 Subject: [PATCH 58/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- 
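# Illustrative sketch, not part of the patch: duplicating prompt embeddings when several
# videos are generated per prompt, now done once in encode_prompt instead of inside each
# encoder helper. This uses repeat_interleave so each prompt's copies stay adjacent,
# which is one common way to match how latents are batched downstream; the patch itself
# uses a repeat(...).view(...) formulation.
import torch

def expand_for_num_videos(text_embeds, pooled_embed, num_videos_per_prompt):
    # text_embeds: (batch, seq, dim), pooled_embed: (batch, dim)
    text_embeds = text_embeds.repeat_interleave(num_videos_per_prompt, dim=0)
    pooled_embed = pooled_embed.repeat_interleave(num_videos_per_prompt, dim=0)
    return text_embeds, pooled_embed

t, p = expand_for_num_videos(torch.randn(2, 5, 8), torch.randn(2, 8), num_videos_per_prompt=3)
print(t.shape, p.shape)  # torch.Size([6, 5, 8]) torch.Size([6, 8])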
src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a7b8bd117c1a..0ba0bed9e102 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -564,7 +564,11 @@ def __call__( generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, prompt_embeds_qwen: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds_clip: Optional[torch.Tensor] = None, + negative_prompt_embeds_qwen: Optional[torch.Tensor] = None, + negative_prompt_embeds_clip: Optional[torch.Tensor] = None, + prompt_cu_seqlens: Optional[torch.Tensor] = None, + negative_prompt_cu_seqlens: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, callback_on_step_end: Optional[ From 211d3dd3407a413ce414646b0154781a817d9fba Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:53:10 +0300 Subject: [PATCH 59/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- .../pipelines/kandinsky5/pipeline_kandinsky.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 0ba0bed9e102..fcd6bc301ea9 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -664,13 +664,13 @@ def __call__( batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] # 3. Encode input prompt - prompt_embeds_dict, prompt_cu_seqlens = self.encode_prompt( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) if self.do_classifier_free_guidance: From 70cfb9e984344f72f63834670f05a5a328bfb565 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:54:16 +0300 Subject: [PATCH 60/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- .../pipelines/kandinsky5/pipeline_kandinsky.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index fcd6bc301ea9..5ab69420962d 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -684,13 +684,13 @@ def __call__( f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." 
) - negative_prompt_embeds_dict, negative_cu_seqlens = self.encode_prompt( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + if negative_prompt_embeds_qwen is None: + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt( + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) From 6e83133e699855c62824f34cac0dbd8ff86e6f0b Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:54:47 +0300 Subject: [PATCH 61/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 5ab69420962d..1cbf5f84fb94 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -743,7 +743,7 @@ def __call__( # Predict noise residual pred_velocity = self.transformer( hidden_states=latents.to(dtype), - encoder_hidden_states=prompt_embeds_dict["text_embeds"].to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, From 7ad87f3554e1d64d0fcf510698552a7408b810bb Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:55:06 +0300 Subject: [PATCH 62/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 1cbf5f84fb94..a863b49a8f71 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -744,7 +744,7 @@ def __call__( pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_qwen.to(dtype), - pooled_projections=prompt_embeds_dict["pooled_embed"].to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=text_rope_pos, From bf229afa110338bfbd9dd58460605c6670152c02 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:56:04 +0300 Subject: [PATCH 63/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index a863b49a8f71..c12cee5b8027 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -756,7 +756,7 @@ def __call__( if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None: uncond_pred_velocity = self.transformer( 
hidden_states=latents.to(dtype), - encoder_hidden_states=negative_prompt_embeds_dict["text_embeds"].to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, From 06afd9ba19ab5de8a2bfbfb1ff33f6fb1c845c02 Mon Sep 17 00:00:00 2001 From: Lev Novitskiy <57654885+leffff@users.noreply.github.com> Date: Fri, 17 Oct 2025 22:57:04 +0300 Subject: [PATCH 64/70] Update src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py Co-authored-by: YiYi Xu --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index c12cee5b8027..fe5c59cc247b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -757,7 +757,7 @@ def __call__( uncond_pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), - pooled_projections=negative_prompt_embeds_dict["pooled_embed"].to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=negative_text_rope_pos, From e1a635ec7fb0e2b7e29fb9c7e1629ae0fd2ffdea Mon Sep 17 00:00:00 2001 From: leffff Date: Fri, 17 Oct 2025 20:28:06 +0000 Subject: [PATCH 65/70] fixed --- .../kandinsky5/pipeline_kandinsky.py | 175 ++++++++++++++---- 1 file changed, 137 insertions(+), 38 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index fe5c59cc247b..ff6b00d5fb26 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -349,6 +349,66 @@ def _encode_prompt_clip( return pooled_embed.to(dtype) +# def encode_prompt( +# self, +# prompt: Union[str, List[str]], +# num_videos_per_prompt: int = 1, +# max_sequence_length: int = 512, +# device: Optional[torch.device] = None, +# dtype: Optional[torch.dtype] = None, +# ): +# r""" +# Encodes a single prompt (positive or negative) into text encoder hidden states. + +# This method combines embeddings from both Qwen2.5-VL and CLIP text encoders +# to create comprehensive text representations for video generation. + +# Args: +# prompt (`str` or `List[str]`): +# Prompt to be encoded. +# num_videos_per_prompt (`int`, *optional*, defaults to 1): +# Number of videos to generate per prompt. +# max_sequence_length (`int`, *optional*, defaults to 512): +# Maximum sequence length for text encoding. +# device (`torch.device`, *optional*): +# Torch device. +# dtype (`torch.dtype`, *optional*): +# Torch dtype. 
+ +# Returns: +# Tuple[Dict[str, torch.Tensor], torch.Tensor]: +# - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) +# - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings +# """ +# device = device or self._execution_device +# dtype = dtype or self.text_encoder.dtype + +# batch_size = len(prompt) + +# prompt = [prompt_clean(p) for p in prompt] + +# # Encode with Qwen2.5-VL +# prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( +# prompt=prompt, +# device=device, +# max_sequence_length=max_sequence_length, +# dtype=dtype, +# ) + +# # Encode with CLIP +# prompt_embeds_clip = self._encode_prompt_clip( +# prompt=prompt, +# device=device, +# dtype=dtype, +# ) +# prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) +# prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) + +# prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) +# prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) + +# return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens + def encode_prompt( self, prompt: Union[str, List[str]], @@ -376,9 +436,10 @@ def encode_prompt( Torch dtype. Returns: - Tuple[Dict[str, torch.Tensor], torch.Tensor]: - - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) - - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim) + - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * num_videos_per_prompt + 1,) """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -394,6 +455,7 @@ def encode_prompt( max_sequence_length=max_sequence_length, dtype=dtype, ) + # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] # Encode with CLIP prompt_embeds_clip = self._encode_prompt_clip( @@ -401,13 +463,30 @@ def encode_prompt( device=device, dtype=dtype, ) - prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) - prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) - - prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) + # prompt_embeds_clip shape: [batch_size, clip_embed_dim] + + # Repeat embeddings for num_videos_per_prompt + # Qwen embeddings: repeat sequence for each video, then reshape + prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]) + + # CLIP embeddings: repeat for each video + prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) # [batch_size, num_videos_per_prompt, clip_embed_dim] + # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim] prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) - return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens + # Repeat cumulative sequence lengths for num_videos_per_prompt + # Original cu_seqlens: [0, len1, len1+len2, ...] 
+ # Need to repeat the differences and reconstruct for repeated prompts + # Original differences (lengths) for each prompt in the batch + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + # Repeat the lengths for num_videos_per_prompt + repeated_lengths = original_lengths.repeat_interleave(num_videos_per_prompt) # [len1, len1, ..., len2, len2, ...] + # Reconstruct the cumulative lengths + repeated_cu_seqlens = torch.cat([torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]) + + return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens def check_inputs( self, @@ -415,22 +494,30 @@ def check_inputs( negative_prompt, height, width, - prompt_embeds=None, - negative_prompt_embeds=None, + prompt_embeds_qwen=None, + prompt_embeds_clip=None, + negative_prompt_embeds_qwen=None, + negative_prompt_embeds_clip=None, + prompt_cu_seqlens=None, + negative_prompt_cu_seqlens=None, callback_on_step_end_tensor_inputs=None, ): """ Validate input parameters for the pipeline. - + Args: prompt: Input prompt negative_prompt: Negative prompt for guidance height: Video height width: Video width - prompt_embeds: Pre-computed prompt embeddings - negative_prompt_embeds: Pre-computed negative prompt embeddings + prompt_embeds_qwen: Pre-computed Qwen prompt embeddings + prompt_embeds_clip: Pre-computed CLIP prompt embeddings + negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings + negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings + prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt + negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt callback_on_step_end_tensor_inputs: Callback tensor inputs - + Raises: ValueError: If inputs are invalid """ @@ -444,23 +531,32 @@ def check_inputs( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: + # Check for consistency within positive prompt embeddings and sequence lengths + if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: + if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: + raise ValueError( + f"If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + f"all three must be provided." + ) + + # Check for consistency within negative prompt embeddings and sequence lengths + if negative_prompt_embeds_qwen is not None or negative_prompt_embeds_clip is not None or negative_prompt_cu_seqlens is not None: + if negative_prompt_embeds_qwen is None or negative_prompt_embeds_clip is None or negative_prompt_cu_seqlens is None: + raise ValueError( + f"If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + f"all three must be provided." 
+ ) + + # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) + if prompt is None and prompt_embeds_qwen is None: raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + + # Validate types for prompt and negative_prompt if provided + if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - elif negative_prompt is not None and ( + if negative_prompt is not None and ( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") @@ -632,13 +728,17 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( - prompt, - negative_prompt, - height, - width, - prompt_embeds, - negative_prompt_embeds, - callback_on_step_end_tensor_inputs, + prompt=prompt, + negative_prompt=negative_prompt, + height=height, + width=width, + prompt_embeds_qwen=prompt_embeds_qwen, + prompt_embeds_clip=prompt_embeds_clip, + negative_prompt_embeds_qwen=negative_prompt_embeds_qwen, + negative_prompt_embeds_clip=negative_prompt_embeds_clip, + prompt_cu_seqlens=prompt_cu_seqlens, + negative_prompt_cu_seqlens=negative_prompt_cu_seqlens, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, ) if num_frames % self.vae_scale_factor_temporal != 1: @@ -739,7 +839,7 @@ def __call__( continue timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) - + # Predict noise residual pred_velocity = self.transformer( hidden_states=latents.to(dtype), @@ -753,7 +853,7 @@ def __call__( return_dict=True ).sample - if self.do_classifier_free_guidance and negative_prompt_embeds_dict is not None: + if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None: uncond_pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), @@ -769,7 +869,6 @@ def __call__( pred_velocity = uncond_pred_velocity + guidance_scale * ( pred_velocity - uncond_pred_velocity ) - # Compute previous sample using the scheduler latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False From 1bf19f0904d9faa6849c75f0a4a6f9441643be66 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Oct 2025 05:20:06 +0200 Subject: [PATCH 66/70] style +copies --- src/diffusers/__init__.py | 8 +- src/diffusers/loaders/__init__.py | 2 +- src/diffusers/loaders/lora_pipeline.py | 19 +- src/diffusers/models/__init__.py | 4 +- src/diffusers/models/transformers/__init__.py | 2 +- .../transformers/transformer_kandinsky.py | 138 +++---- src/diffusers/pipelines/__init__.py | 2 +- .../pipelines/kandinsky5/__init__.py | 2 +- .../kandinsky5/pipeline_kandinsky.py | 348 ++++++++---------- src/diffusers/utils/dummy_pt_objects.py | 15 + .../dummy_torch_and_transformers_objects.py | 15 + 11 files changed, 258 insertions(+), 297 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 54e33d69514f..aa500b149441 100644 --- a/src/diffusers/__init__.py 
+++ b/src/diffusers/__init__.py @@ -220,6 +220,7 @@ "HunyuanVideoTransformer3DModel", "I2VGenXLUNet", "Kandinsky3UNet", + "Kandinsky5Transformer3DModel", "LatteTransformer3DModel", "LTXVideoTransformer3DModel", "Lumina2Transformer2DModel", @@ -260,7 +261,6 @@ "VQModel", "WanTransformer3DModel", "WanVACETransformer3DModel", - "Kandinsky5Transformer3DModel", "attention_backend", ] ) @@ -475,6 +475,7 @@ "ImageTextPipelineOutput", "Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline", + "Kandinsky5T2VPipeline", "KandinskyCombinedPipeline", "KandinskyImg2ImgCombinedPipeline", "KandinskyImg2ImgPipeline", @@ -623,7 +624,6 @@ "WanPipeline", "WanVACEPipeline", "WanVideoToVideoPipeline", - "Kandinsky5T2VPipeline", "WuerstchenCombinedPipeline", "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", @@ -914,6 +914,7 @@ HunyuanVideoTransformer3DModel, I2VGenXLUNet, Kandinsky3UNet, + Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -953,7 +954,6 @@ VQModel, WanTransformer3DModel, WanVACETransformer3DModel, - Kandinsky5Transformer3DModel, attention_backend, ) from .modular_pipelines import ComponentsManager, ComponentSpec, ModularPipeline, ModularPipelineBlocks @@ -1139,6 +1139,7 @@ ImageTextPipelineOutput, Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, + Kandinsky5T2VPipeline, KandinskyCombinedPipeline, KandinskyImg2ImgCombinedPipeline, KandinskyImg2ImgPipeline, @@ -1286,7 +1287,6 @@ WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline, - Kandinsky5T2VPipeline, WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline, diff --git a/src/diffusers/loaders/__init__.py b/src/diffusers/loaders/__init__.py index 6a48ac1b0deb..48507aae038c 100644 --- a/src/diffusers/loaders/__init__.py +++ b/src/diffusers/loaders/__init__.py @@ -116,6 +116,7 @@ def text_encoder_attn_modules(text_encoder): FluxLoraLoaderMixin, HiDreamImageLoraLoaderMixin, HunyuanVideoLoraLoaderMixin, + KandinskyLoraLoaderMixin, LoraLoaderMixin, LTXVideoLoraLoaderMixin, Lumina2LoraLoaderMixin, @@ -127,7 +128,6 @@ def text_encoder_attn_modules(text_encoder): StableDiffusionLoraLoaderMixin, StableDiffusionXLLoraLoaderMixin, WanLoraLoaderMixin, - KandinskyLoraLoaderMixin ) from .single_file import FromSingleFileMixin from .textual_inversion import TextualInversionLoaderMixin diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index ea1b92c68b59..2bb6c0ea026e 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -3638,7 +3638,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): """ super().unfuse_lora(components=components, **kwargs) - + class KandinskyLoraLoaderMixin(LoraBaseMixin): r""" Load LoRA layers into [`Kandinsky5Transformer3DModel`], @@ -3662,7 +3662,8 @@ def lora_state_dict( Can be either: - A string, the *model id* of a pretrained model hosted on the Hub. - A path to a *directory* containing the model weights. - - A [torch state dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). + - A [torch state + dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory where a downloaded pretrained model configuration is cached. 
@@ -3737,7 +3738,7 @@ def load_lora_weights( ): """ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` - + Parameters: pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. @@ -3746,7 +3747,8 @@ def load_lora_weights( hotswap (`bool`, *optional*): Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place. low_cpu_mem_usage (`bool`, *optional*): - Speed up model loading by only loading the pretrained LoRA weights and not initializing the random weights. + Speed up model loading by only loading the pretrained LoRA weights and not initializing the random + weights. kwargs (`dict`, *optional*): See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`]. """ @@ -3827,7 +3829,6 @@ def load_lora_into_transformer( hotswap=hotswap, ) - @classmethod def save_lora_weights( cls, @@ -3864,9 +3865,7 @@ def save_lora_weights( lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata if not lora_layers: - raise ValueError( - "You must pass at least one of `transformer_lora_layers`" - ) + raise ValueError("You must pass at least one of `transformer_lora_layers`") cls._save_lora_weights( save_directory=save_directory, @@ -3923,7 +3922,7 @@ def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. """ super().unfuse_lora(components=components, **kwargs) - + class WanLoraLoaderMixin(LoraBaseMixin): r""" @@ -5088,4 +5087,4 @@ class LoraLoaderMixin(StableDiffusionLoraLoaderMixin): def __init__(self, *args, **kwargs): deprecation_message = "LoraLoaderMixin is deprecated and this will be removed in a future version. Please use `StableDiffusionLoraLoaderMixin`, instead." 
deprecate("LoraLoaderMixin", "1.0.0", deprecation_message) - super().__init__(*args, **kwargs) \ No newline at end of file + super().__init__(*args, **kwargs) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 89ca9d39774b..8d029bf5d31c 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -91,6 +91,7 @@ _import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"] + _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] @@ -101,7 +102,6 @@ _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["transformers.transformer_wan_vace"] = ["WanVACETransformer3DModel"] - _import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] _import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"] @@ -183,6 +183,7 @@ HunyuanDiT2DModel, HunyuanVideoFramepackTransformer3DModel, HunyuanVideoTransformer3DModel, + Kandinsky5Transformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, Lumina2Transformer2DModel, @@ -201,7 +202,6 @@ TransformerTemporalModel, WanTransformer3DModel, WanVACETransformer3DModel, - Kandinsky5Transformer3DModel, ) from .unets import ( I2VGenXLUNet, diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 4b9911f9cb5d..6b80ea6c82a5 100755 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -27,6 +27,7 @@ from .transformer_hidream_image import HiDreamImageTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_hunyuan_video_framepack import HunyuanVideoFramepackTransformer3DModel + from .transformer_kandinsky import Kandinsky5Transformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel @@ -37,4 +38,3 @@ from .transformer_temporal import TransformerTemporalModel from .transformer_wan import WanTransformer3DModel from .transformer_wan_vace import WanVACETransformer3DModel - from .transformer_kandinsky import Kandinsky5Transformer3DModel diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index ad39a9bed63f..a338922583ca 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -12,48 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import inspect import math -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F -from torch import BoolTensor, IntTensor, Tensor, nn +from torch import Tensor from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import ( - USE_PEFT_BACKEND, - deprecate, logging, - scale_lora_layers, - unscale_lora_layers, ) -from ...utils.torch_utils import maybe_allow_in_graph -from ..attention import AttentionMixin, FeedForward, AttentionModuleMixin +from ..attention import AttentionMixin, AttentionModuleMixin +from ..attention_dispatch import _CAN_USE_FLEX_ATTN, dispatch_attention_fn from ..cache_utils import CacheMixin -from ..embeddings import ( - TimestepEmbedding, - get_1d_rotary_pos_embed, -) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin -from ..normalization import FP32LayerNorm -from ..attention_dispatch import dispatch_attention_fn, _CAN_USE_FLEX_ATTN - -logger = logging.get_logger(__name__) - - +logger = logging.get_logger(__name__) def get_freqs(dim, max_period=10000.0): - freqs = torch.exp( - -math.log(max_period) - * torch.arange(start=0, end=dim, dtype=torch.float32) - / dim - ) + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) return freqs @@ -147,7 +131,7 @@ def nablaT_v2( q = q.transpose(1, 2).contiguous() k = k.transpose(1, 2).contiguous() - + # Map estimation B, h, S, D = q.shape s1 = S // 64 @@ -167,9 +151,7 @@ def nablaT_v2( # BlockMask creation kv_nb = mask.sum(-1).to(torch.int32) kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32) - return BlockMask.from_kv_blocks( - torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None - ) + return BlockMask.from_kv_blocks(torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None) class Kandinsky5TimeEmbeddings(nn.Module): @@ -188,7 +170,7 @@ def forward(self, time): args = torch.outer(time, self.freqs.to(device=time.device)) time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) - return time_embed + return time_embed class Kandinsky5TextEmbeddings(nn.Module): @@ -235,7 +217,7 @@ def __init__(self, dim, max_pos=1024, max_period=10000.0): self.max_pos = max_pos freq = get_freqs(dim // 2, max_period) pos = torch.arange(max_pos, dtype=freq.dtype) - self.register_buffer(f"args", torch.outer(pos, freq), persistent=False) + self.register_buffer("args", torch.outer(pos, freq), persistent=False) def forward(self, pos): args = self.args[pos] @@ -266,15 +248,9 @@ def forward(self, shape, pos, scale_factor=(1.0, 1.0, 1.0)): args = torch.cat( [ - args_t.view(1, duration, 1, 1, -1).repeat( - batch_size, 1, height, width, 1 - ), - args_h.view(1, 1, height, 1, -1).repeat( - batch_size, duration, 1, width, 1 - ), - args_w.view(1, 1, 1, width, -1).repeat( - batch_size, duration, height, 1, 1 - ), + args_t.view(1, duration, 1, 1, -1).repeat(batch_size, 1, height, width, 1), + args_h.view(1, 1, height, 1, -1).repeat(batch_size, duration, 1, width, 1), + args_w.view(1, 1, 1, width, -1).repeat(batch_size, duration, height, 1, 1), ], dim=-1, ) @@ -299,7 +275,6 @@ def forward(self, x): class Kandinsky5AttnProcessor: - _attention_backend = None _parallel_config = None @@ -307,7 +282,6 @@ def __init__(self): if 
not hasattr(F, "scaled_dot_product_attention"): raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") - def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=None, sparse_params=None): # query, key, value = self.get_qkv(x) query = attn.to_query(hidden_states) @@ -320,7 +294,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, rotary_emb=N query = query.reshape(*shape, attn.num_heads, -1) key = key.reshape(*cond_shape, attn.num_heads, -1) value = value.reshape(*cond_shape, attn.num_heads, -1) - + else: key = attn.to_key(hidden_states) value = attn.to_value(hidden_states) @@ -352,10 +326,10 @@ def apply_rotary(x, rope): ) else: attn_mask = None - + hidden_states = dispatch_attention_fn( - query, - key, + query, + key, value, attn_mask=attn_mask, backend=self._attention_backend, @@ -368,11 +342,11 @@ def apply_rotary(x, rope): class Kandinsky5Attention(nn.Module, AttentionModuleMixin): - _default_processor_cls = Kandinsky5AttnProcessor _available_processors = [ Kandinsky5AttnProcessor, ] + def __init__(self, num_channels, head_dim, processor=None): super().__init__() assert num_channels % head_dim == 0 @@ -397,9 +371,6 @@ def forward( rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **kwargs, ) -> torch.Tensor: - - import inspect - attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) quiet_attn_parameters = {} unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] @@ -409,9 +380,16 @@ def forward( ) kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} - return self.processor(self, hidden_states, encoder_hidden_states=encoder_hidden_states, sparse_params=sparse_params, rotary_emb=rotary_emb, **kwargs) + return self.processor( + self, + hidden_states, + encoder_hidden_states=encoder_hidden_states, + sparse_params=sparse_params, + rotary_emb=rotary_emb, + **kwargs, + ) + - class Kandinsky5FeedForward(nn.Module): def __init__(self, dim, ff_dim): super().__init__() @@ -429,16 +407,14 @@ def __init__(self, model_dim, time_dim, visual_dim, patch_size): self.patch_size = patch_size self.modulation = Kandinsky5Modulation(time_dim, model_dim, 2) self.norm = nn.LayerNorm(model_dim, elementwise_affine=False) - self.out_layer = nn.Linear( - model_dim, math.prod(patch_size) * visual_dim, bias=True - ) + self.out_layer = nn.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True) def forward(self, visual_embed, text_embed, time_embed): - shift, scale = torch.chunk( - self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 - ) - - visual_embed = (self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None]).type_as(visual_embed) + shift, scale = torch.chunk(self.modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) + + visual_embed = ( + self.norm(visual_embed.float()) * (scale.float()[:, None, None] + 1.0) + shift.float()[:, None, None] + ).type_as(visual_embed) x = self.out_layer(visual_embed) @@ -474,9 +450,7 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim): self.feed_forward = Kandinsky5FeedForward(model_dim, ff_dim) def forward(self, x, time_embed, rope): - self_attn_params, ff_params = torch.chunk( - self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1 - ) + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed).unsqueeze(dim=1), 2, dim=-1) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) out = 
(self.self_attention_norm(x.float()) * (scale.float() + 1.0) + shift.float()).type_as(x) out = self.self_attention(out, rotary_emb=rope) @@ -510,17 +484,23 @@ def forward(self, visual_embed, text_embed, time_embed, rope, sparse_params): ) shift, scale, gate = torch.chunk(self_attn_params, 3, dim=-1) - visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) + visual_out = (self.self_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) visual_out = self.self_attention(visual_out, rotary_emb=rope, sparse_params=sparse_params) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(cross_attn_params, 3, dim=-1) - visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) + visual_out = (self.cross_attention_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) visual_out = self.cross_attention(visual_out, encoder_hidden_states=text_embed) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) shift, scale, gate = torch.chunk(ff_params, 3, dim=-1) - visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as(visual_embed) + visual_out = (self.feed_forward_norm(visual_embed.float()) * (scale.float() + 1.0) + shift.float()).type_as( + visual_embed + ) visual_out = self.feed_forward(visual_out) visual_embed = (visual_embed.float() + gate.float() * visual_out.float()).type_as(visual_embed) @@ -583,9 +563,7 @@ def __init__( self.time_embeddings = Kandinsky5TimeEmbeddings(model_dim, time_dim) self.text_embeddings = Kandinsky5TextEmbeddings(in_text_dim, model_dim) self.pooled_text_embeddings = Kandinsky5TextEmbeddings(in_text_dim2, time_dim) - self.visual_embeddings = Kandinsky5VisualEmbeddings( - visual_embed_dim, model_dim, patch_size - ) + self.visual_embeddings = Kandinsky5VisualEmbeddings(visual_embed_dim, model_dim, patch_size) # Initialize positional embeddings self.text_rope_embeddings = Kandinsky5RoPE1D(head_dim) @@ -593,10 +571,7 @@ def __init__( # Initialize transformer blocks self.text_transformer_blocks = nn.ModuleList( - [ - Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) - for _ in range(num_text_blocks) - ] + [Kandinsky5TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim) for _ in range(num_text_blocks)] ) self.visual_transformer_blocks = nn.ModuleList( @@ -607,9 +582,7 @@ def __init__( ) # Initialize output layer - self.out_layer = Kandinsky5OutLayer( - model_dim, time_dim, out_visual_dim, patch_size - ) + self.out_layer = Kandinsky5OutLayer(model_dim, time_dim, out_visual_dim, patch_size) self.gradient_checkpointing = False def forward( @@ -639,8 +612,7 @@ def forward( return_dict (`bool`, optional): Whether to return a dictionary Returns: - [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: - The output of the transformer + [`~models.transformer_2d.Transformer2DModelOutput`] or `torch.FloatTensor`: The output of the transformer """ x = hidden_states text_embed = encoder_hidden_states @@ -663,13 +635,9 @@ def forward( text_embed = text_transformer_block(text_embed, time_embed, text_rope) visual_shape = visual_embed.shape[:-1] - visual_rope = self.visual_rope_embeddings( - visual_shape, visual_rope_pos, scale_factor - ) + visual_rope = 
self.visual_rope_embeddings(visual_shape, visual_rope_pos, scale_factor) to_fractal = sparse_params["to_fractal"] if sparse_params is not None else False - visual_embed, visual_rope = fractal_flatten( - visual_embed, visual_rope, visual_shape, block_mask=to_fractal - ) + visual_embed, visual_rope = fractal_flatten(visual_embed, visual_rope, visual_shape, block_mask=to_fractal) for visual_transformer_block in self.visual_transformer_blocks: if torch.is_grad_enabled() and self.gradient_checkpointing: @@ -686,12 +654,10 @@ def forward( visual_embed, text_embed, time_embed, visual_rope, sparse_params ) - visual_embed = fractal_unflatten( - visual_embed, visual_shape, block_mask=to_fractal - ) + visual_embed = fractal_unflatten(visual_embed, visual_shape, block_mask=to_fractal) x = self.out_layer(visual_embed, text_embed, time_embed) if not return_dict: return x - return Transformer2DModelOutput(sample=x) \ No newline at end of file + return Transformer2DModelOutput(sample=x) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 201d92afb07c..c438caed571f 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -672,6 +672,7 @@ Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline, ) + from .kandinsky5 import Kandinsky5T2VPipeline from .latent_consistency_models import ( LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, @@ -788,7 +789,6 @@ ) from .visualcloze import VisualClozeGenerationPipeline, VisualClozePipeline from .wan import WanImageToVideoPipeline, WanPipeline, WanVACEPipeline, WanVideoToVideoPipeline - from .kandinsky5 import Kandinsky5T2VPipeline from .wuerstchen import ( WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, diff --git a/src/diffusers/pipelines/kandinsky5/__init__.py b/src/diffusers/pipelines/kandinsky5/__init__.py index af8e12421740..a7975bdce926 100644 --- a/src/diffusers/pipelines/kandinsky5/__init__.py +++ b/src/diffusers/pipelines/kandinsky5/__init__.py @@ -23,7 +23,7 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["pipeline_kandinsky"] = ["Kandinsky5T2VPipeline"] - + if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index ff6b00d5fb26..3eb706f238ad 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -18,7 +18,7 @@ import regex as re import torch from torch.nn import functional as F -from transformers import Qwen2VLProcessor, Qwen2_5_VLForConditionalGeneration, CLIPTextModel, CLIPTokenizer +from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...loaders import KandinskyLoraLoaderMixin @@ -49,13 +49,13 @@ EXAMPLE_DOC_STRING = """ Examples: - + ```python >>> import torch >>> from diffusers import Kandinsky5T2VPipeline >>> from diffusers.utils import export_to_video - - >>> # Available models: + + >>> # Available models: >>> # ai-forever/Kandinsky-5.0-T2V-Lite-sft-5s-Diffusers >>> # ai-forever/Kandinsky-5.0-T2V-Lite-nocfg-5s-Diffusers >>> # ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers @@ -67,7 +67,7 @@ >>> prompt = "A cat and a dog baking a cake together in a kitchen." 
>>> negative_prompt = "Static, 2D cartoon, cartoon, 2d animation, paintings, images, worst quality, low quality, ugly, deformed, walking backwards" - + >>> output = pipe( ... prompt=prompt, ... negative_prompt=negative_prompt, @@ -77,7 +77,7 @@ ... num_inference_steps=50, ... guidance_scale=5.0, ... ).frames[0] - + >>> export_to_video(output, "output.mp4", fps=24, quality=9) ``` """ @@ -129,7 +129,13 @@ class Kandinsky5T2VPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin): """ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae" - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + _callback_tensor_inputs = [ + "latents", + "prompt_embeds_qwen", + "prompt_embeds_clip", + "negative_prompt_embeds_qwen", + "negative_prompt_embeds_clip", + ] def __init__( self, @@ -152,40 +158,42 @@ def __init__( tokenizer_2=tokenizer_2, scheduler=scheduler, ) - - self.prompt_template = "\n".join(["<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", - "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", - "Describe the location of the video, main characters or objects and their action.", - "Describe the dynamism of the video and presented actions.", - "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", - "Describe the visual effects, postprocessing and transitions if they are presented in the video.", - "Pay attention to the order of key actions shown in the scene.<|im_end|>", - "<|im_start|>user\n{}<|im_end|>"]) + + self.prompt_template = "\n".join( + [ + "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.", + "Describe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.", + "Describe the location of the video, main characters or objects and their action.", + "Describe the dynamism of the video and presented actions.", + "Name the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or scren content.", + "Describe the visual effects, postprocessing and transitions if they are presented in the video.", + "Pay attention to the order of key actions shown in the scene.<|im_end|>", + "<|im_start|>user\n{}<|im_end|>", + ] + ) self.prompt_template_encode_start_idx = 129 self.vae_scale_factor_temporal = vae.config.temporal_compression_ratio self.vae_scale_factor_spatial = vae.config.spatial_compression_ratio self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - + @staticmethod - def fast_sta_nabla( - T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda" - ) -> torch.Tensor: + def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> torch.Tensor: """ Create a sparse temporal attention (STA) mask for efficient video generation. - - This method generates a mask that limits attention to nearby frames and spatial positions, - reducing computational complexity for video generation. - + + This method generates a mask that limits attention to nearby frames and spatial positions, reducing + computational complexity for video generation. 
+ Args: T (int): Number of temporal frames H (int): Height in latent space - W (int): Width in latent space + W (int): Width in latent space wT (int): Temporal attention window size wH (int): Height attention window size wW (int): Width attention window size device (str): Device to create tensor on - + Returns: torch.Tensor: Sparse attention mask of shape (T*H*W, T*H*W) """ @@ -200,30 +208,21 @@ def fast_sta_nabla( sta_t = sta_t <= wT // 2 sta_h = sta_h <= wH // 2 sta_w = sta_w <= wW // 2 - sta_hw = ( - (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)) - .reshape(H, H, W, W) - .transpose(1, 2) - .flatten() - ) - sta = ( - (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)) - .reshape(T, T, H * W, H * W) - .transpose(1, 2) - ) + sta_hw = (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)).reshape(H, H, W, W).transpose(1, 2).flatten() + sta = (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)).reshape(T, T, H * W, H * W).transpose(1, 2) return sta.reshape(T * H * W, T * H * W) - + def get_sparse_params(self, sample, device): """ Generate sparse attention parameters for the transformer based on sample dimensions. - - This method computes the sparse attention configuration needed for efficient - video processing in the transformer model. - + + This method computes the sparse attention configuration needed for efficient video processing in the + transformer model. + Args: sample (torch.Tensor): Input sample tensor device (torch.device): Device to place tensors on - + Returns: Dict: Dictionary containing sparse attention parameters """ @@ -236,13 +235,15 @@ def get_sparse_params(self, sample, device): ) if self.transformer.config.attention_type == "nabla": sta_mask = self.fast_sta_nabla( - T, H // 8, W // 8, - self.transformer.config.attention_wT, - self.transformer.config.attention_wH, - self.transformer.config.attention_wW, - device=device + T, + H // 8, + W // 8, + self.transformer.config.attention_wT, + self.transformer.config.attention_wH, + self.transformer.config.attention_wW, + device=device, ) - + sparse_params = { "sta_mask": sta_mask.unsqueeze_(0).unsqueeze_(0), "attention_type": self.transformer.config.attention_type, @@ -269,17 +270,17 @@ def _encode_prompt_qwen( ): """ Encode prompt using Qwen2.5-VL text encoder. - - This method processes the input prompt through the Qwen2.5-VL model to generate - text embeddings suitable for video generation. - + + This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for + video generation. 
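A small worked example of the `cu_seqlens` bookkeeping returned alongside the Qwen embeddings and later repeated per generated video; the token counts are invented, only the cumsum/pad/`repeat_interleave` pattern follows the code.

```python
import torch
import torch.nn.functional as F

# Two prompts padded to length 6; 1 marks a real (non-padding) token.
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1, 1]])

# Cumulative lengths with a leading zero: [0, len1, len1 + len2].
cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(torch.int32)
print(cu_seqlens)  # tensor([ 0,  4, 10], dtype=torch.int32)

# Repeating for num_videos_per_prompt=2: repeat each prompt's length, then rebuild
# the running sum so every generated video keeps its own prompt length.
num_videos_per_prompt = 2
lengths = cu_seqlens.diff().repeat_interleave(num_videos_per_prompt)
repeated_cu_seqlens = torch.cat(
    [torch.tensor([0], dtype=torch.int32), lengths.cumsum(0).to(torch.int32)]
)
print(repeated_cu_seqlens)  # tensor([ 0,  4,  8, 14, 20], dtype=torch.int32)
```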
+ Args: prompt (Union[str, List[str]]): Input prompt or list of prompts device (torch.device): Device to run encoding on num_videos_per_prompt (int): Number of videos to generate per prompt max_sequence_length (int): Maximum sequence length for tokenization dtype (torch.dtype): Data type for embeddings - + Returns: Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths """ @@ -287,7 +288,7 @@ def _encode_prompt_qwen( dtype = dtype or self.text_encoder.dtype full_texts = [self.prompt_template.format(p) for p in prompt] - + inputs = self.tokenizer( text=full_texts, images=None, @@ -302,13 +303,12 @@ def _encode_prompt_qwen( input_ids=inputs["input_ids"], return_dict=True, output_hidden_states=True, - )["hidden_states"][-1][:, self.prompt_template_encode_start_idx:] + )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] - - attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx:] + attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) - + return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( @@ -319,16 +319,16 @@ def _encode_prompt_clip( ): """ Encode prompt using CLIP text encoder. - - This method processes the input prompt through the CLIP model to generate - pooled embeddings that capture semantic information. - + + This method processes the input prompt through the CLIP model to generate pooled embeddings that capture + semantic information. + Args: prompt (Union[str, List[str]]): Input prompt or list of prompts device (torch.device): Device to run encoding on num_videos_per_prompt (int): Number of videos to generate per prompt dtype (torch.dtype): Data type for embeddings - + Returns: torch.Tensor: Pooled text embeddings from CLIP """ @@ -346,69 +346,8 @@ def _encode_prompt_clip( pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] - return pooled_embed.to(dtype) -# def encode_prompt( -# self, -# prompt: Union[str, List[str]], -# num_videos_per_prompt: int = 1, -# max_sequence_length: int = 512, -# device: Optional[torch.device] = None, -# dtype: Optional[torch.dtype] = None, -# ): -# r""" -# Encodes a single prompt (positive or negative) into text encoder hidden states. - -# This method combines embeddings from both Qwen2.5-VL and CLIP text encoders -# to create comprehensive text representations for video generation. - -# Args: -# prompt (`str` or `List[str]`): -# Prompt to be encoded. -# num_videos_per_prompt (`int`, *optional*, defaults to 1): -# Number of videos to generate per prompt. -# max_sequence_length (`int`, *optional*, defaults to 512): -# Maximum sequence length for text encoding. -# device (`torch.device`, *optional*): -# Torch device. -# dtype (`torch.dtype`, *optional*): -# Torch dtype. 
- -# Returns: -# Tuple[Dict[str, torch.Tensor], torch.Tensor]: -# - A dict with keys `"text_embeds"` (from Qwen) and `"pooled_embed"` (from CLIP) -# - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings -# """ -# device = device or self._execution_device -# dtype = dtype or self.text_encoder.dtype - -# batch_size = len(prompt) - -# prompt = [prompt_clean(p) for p in prompt] - -# # Encode with Qwen2.5-VL -# prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( -# prompt=prompt, -# device=device, -# max_sequence_length=max_sequence_length, -# dtype=dtype, -# ) - -# # Encode with CLIP -# prompt_embeds_clip = self._encode_prompt_clip( -# prompt=prompt, -# device=device, -# dtype=dtype, -# ) -# prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) -# prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1) - -# prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) -# prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) - -# return prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens - def encode_prompt( self, prompt: Union[str, List[str]], @@ -420,8 +359,8 @@ def encode_prompt( r""" Encodes a single prompt (positive or negative) into text encoder hidden states. - This method combines embeddings from both Qwen2.5-VL and CLIP text encoders - to create comprehensive text representations for video generation. + This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text + representations for video generation. Args: prompt (`str` or `List[str]`): @@ -439,7 +378,8 @@ def encode_prompt( Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - Qwen text embeddings of shape (batch_size * num_videos_per_prompt, sequence_length, embedding_dim) - CLIP pooled embeddings of shape (batch_size * num_videos_per_prompt, clip_embedding_dim) - - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * num_videos_per_prompt + 1,) + - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * + num_videos_per_prompt + 1,) """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype @@ -467,12 +407,18 @@ def encode_prompt( # Repeat embeddings for num_videos_per_prompt # Qwen embeddings: repeat sequence for each video, then reshape - prompt_embeds_qwen = prompt_embeds_qwen.repeat(1, num_videos_per_prompt, 1) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] + prompt_embeds_qwen = prompt_embeds_qwen.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, seq_len * num_videos_per_prompt, embed_dim] # Reshape to [batch_size * num_videos_per_prompt, seq_len, embed_dim] - prompt_embeds_qwen = prompt_embeds_qwen.view(batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1]) + prompt_embeds_qwen = prompt_embeds_qwen.view( + batch_size * num_videos_per_prompt, -1, prompt_embeds_qwen.shape[-1] + ) # CLIP embeddings: repeat for each video - prompt_embeds_clip = prompt_embeds_clip.repeat(1, num_videos_per_prompt, 1) # [batch_size, num_videos_per_prompt, clip_embed_dim] + prompt_embeds_clip = prompt_embeds_clip.repeat( + 1, num_videos_per_prompt, 1 + ) # [batch_size, num_videos_per_prompt, clip_embed_dim] # Reshape to [batch_size * num_videos_per_prompt, clip_embed_dim] prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_videos_per_prompt, -1) @@ -480,11 +426,15 @@ def encode_prompt( # Original cu_seqlens: [0, len1, len1+len2, ...] 
# Need to repeat the differences and reconstruct for repeated prompts # Original differences (lengths) for each prompt in the batch - original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] + original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] # Repeat the lengths for num_videos_per_prompt - repeated_lengths = original_lengths.repeat_interleave(num_videos_per_prompt) # [len1, len1, ..., len2, len2, ...] + repeated_lengths = original_lengths.repeat_interleave( + num_videos_per_prompt + ) # [len1, len1, ..., len2, len2, ...] # Reconstruct the cumulative lengths - repeated_cu_seqlens = torch.cat([torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)]) + repeated_cu_seqlens = torch.cat( + [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] + ) return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens @@ -509,7 +459,7 @@ def check_inputs( prompt: Input prompt negative_prompt: Negative prompt for guidance height: Video height - width: Video width + width: Video width prompt_embeds_qwen: Pre-computed Qwen prompt embeddings prompt_embeds_clip: Pre-computed CLIP prompt embeddings negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings @@ -535,16 +485,24 @@ def check_inputs( if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: raise ValueError( - f"If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " - f"all three must be provided." + "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " + "all three must be provided." ) # Check for consistency within negative prompt embeddings and sequence lengths - if negative_prompt_embeds_qwen is not None or negative_prompt_embeds_clip is not None or negative_prompt_cu_seqlens is not None: - if negative_prompt_embeds_qwen is None or negative_prompt_embeds_clip is None or negative_prompt_cu_seqlens is None: + if ( + negative_prompt_embeds_qwen is not None + or negative_prompt_embeds_clip is not None + or negative_prompt_cu_seqlens is not None + ): + if ( + negative_prompt_embeds_qwen is None + or negative_prompt_embeds_clip is None + or negative_prompt_cu_seqlens is None + ): raise ValueError( - f"If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " - f"all three must be provided." + "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " + "all three must be provided." ) # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) @@ -575,21 +533,20 @@ def prepare_latents( ) -> torch.Tensor: """ Prepare initial latent variables for video generation. - - This method creates random noise latents or uses provided latents as starting point - for the denoising process. - + + This method creates random noise latents or uses provided latents as starting point for the denoising process. 
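A quick shape check for the channels-last video latents this method prepares. The frame count, resolution, VAE scale factors (4 temporal, 8 spatial), and 16 latent channels below are illustrative assumptions; only the rounding formula and the `[batch, frames, height, width, channels]` layout come from the pipeline code.

```python
num_frames, height, width = 121, 512, 768                    # illustrative request
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8   # assumed VAE ratios
num_channels_latents = 16                                    # assumed latent channel count

# The pipeline expects num_frames % vae_scale_factor_temporal == 1 and maps it to
# (num_frames - 1) // vae_scale_factor_temporal + 1 latent frames.
assert num_frames % vae_scale_factor_temporal == 1
num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1

latent_shape = (
    1,                                   # batch_size * num_videos_per_prompt
    num_latent_frames,
    height // vae_scale_factor_spatial,
    width // vae_scale_factor_spatial,
    num_channels_latents,
)
print(latent_shape)  # (1, 31, 64, 96, 16)

# With transformer.visual_cond enabled, a zero conditioning copy plus a mask channel are
# concatenated on the last axis, giving 2 * num_channels_latents + 1 channels in total.
```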
+ Args: batch_size (int): Number of videos to generate num_channels_latents (int): Number of channels in latent space height (int): Height of generated video - width (int): Width of generated video + width (int): Width of generated video num_frames (int): Number of frames in video dtype (torch.dtype): Data type for latents device (torch.device): Device to create latents on generator (torch.Generator): Random number generator latents (torch.Tensor): Pre-existing latents to use - + Returns: torch.Tensor: Prepared latent tensor """ @@ -611,14 +568,20 @@ def prepare_latents( ) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - + if self.transformer.visual_cond: # For visual conditioning, concatenate with zeros and mask visual_cond = torch.zeros_like(latents) visual_cond_mask = torch.zeros( - [batch_size, num_latent_frames, int(height) // self.vae_scale_factor_spatial, int(width) // self.vae_scale_factor_spatial, 1], - dtype=latents.dtype, - device=latents.device + [ + batch_size, + num_latent_frames, + int(height) // self.vae_scale_factor_spatial, + int(width) // self.vae_scale_factor_spatial, + 1, + ], + dtype=latents.dtype, + device=latents.device, ) latents = torch.cat([latents, visual_cond, visual_cond_mask], dim=-1) @@ -715,13 +678,13 @@ def __call__( The list of tensor inputs for the `callback_on_step_end` function. max_sequence_length (`int`, defaults to `512`): The maximum sequence length for text encoding. - + Examples: - + Returns: [`~KandinskyPipelineOutput`] or `tuple`: - If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned where - the first element is a list with the generated images. + If `return_dict` is `True`, [`KandinskyPipelineOutput`] is returned, otherwise a `tuple` is returned + where the first element is a list with the generated images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs @@ -761,17 +724,16 @@ def __call__( elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: - batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, (list, tuple)) else prompt_embeds.shape[0] - - # 3. Encode input prompt - if prompt_embeds_qwen is None: - prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( - prompt=prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + batch_size = prompt_embeds_qwen.shape[0] + # 3. Encode input prompt + if prompt_embeds_qwen is None: + prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( + prompt=prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) if self.do_classifier_free_guidance: if negative_prompt is None: @@ -785,12 +747,12 @@ def __call__( ) if negative_prompt_embeds_qwen is None: - negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt( - prompt=negative_prompt, - max_sequence_length=max_sequence_length, - device=device, - dtype=dtype, - ) + negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_cu_seqlens = self.encode_prompt( + prompt=negative_prompt, + max_sequence_length=max_sequence_length, + device=device, + dtype=dtype, + ) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -817,15 +779,15 @@ def __call__( torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] - + text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) - + negative_text_rope_pos = ( torch.arange(negative_cu_seqlens.diff().max().item(), device=device) if negative_cu_seqlens is not None else None ) - + # 7. Sparse Params for efficient attention sparse_params = self.get_sparse_params(latents, device) @@ -839,8 +801,8 @@ def __call__( continue timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) - - # Predict noise residual + + # Predict noise residual pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_qwen.to(dtype), @@ -848,12 +810,12 @@ def __call__( timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=text_rope_pos, - scale_factor=(1, 2, 2), + scale_factor=(1, 2, 2), sparse_params=sparse_params, - return_dict=True + return_dict=True, ).sample - if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None: + if self.do_classifier_free_guidance and negative_prompt_embeds_qwen is not None: uncond_pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), @@ -863,12 +825,10 @@ def __call__( text_rope_pos=negative_text_rope_pos, scale_factor=(1, 2, 2), sparse_params=sparse_params, - return_dict=True + return_dict=True, ).sample - pred_velocity = uncond_pred_velocity + guidance_scale * ( - pred_velocity - uncond_pred_velocity - ) + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) # Compute previous sample using the scheduler latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( pred_velocity, t, latents[:, :, :, :, :num_channels_latents], return_dict=False @@ -881,8 +841,14 @@ def __call__( callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) - prompt_embeds_dict = callback_outputs.pop("prompt_embeds", prompt_embeds_dict) - negative_prompt_embeds_dict = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds_dict) + prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen) + prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip) + negative_prompt_embeds_qwen = callback_outputs.pop( + "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen + ) + negative_prompt_embeds_clip = callback_outputs.pop( + "negative_prompt_embeds_clip", negative_prompt_embeds_clip + ) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() @@ -907,13 +873,13 @@ def __call__( ) video = video.permute(0, 1, 5, 2, 3, 4) # [batch, num_videos, channels, frames, height, width] video = video.reshape( - batch_size * num_videos_per_prompt, - num_channels_latents, - (num_frames - 1) // self.vae_scale_factor_temporal + 1, - height // self.vae_scale_factor_spatial, - width // self.vae_scale_factor_spatial + batch_size * num_videos_per_prompt, + num_channels_latents, + (num_frames - 1) // self.vae_scale_factor_temporal + 1, + height // self.vae_scale_factor_spatial, + width // self.vae_scale_factor_spatial, ) - + # Normalize and decode through VAE video = video / self.vae.config.scaling_factor video 
= self.vae.decode(video).sample diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6e7d22797902..5d62709c28fd 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -918,6 +918,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Kandinsky5Transformer3DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LatteTransformer3DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 9ed625045261..3244ef12ef87 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1247,6 +1247,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Kandinsky5T2VPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class KandinskyCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 1746f6d426dd37541dec98a9c338e0465ced3ead Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 17 Oct 2025 17:22:58 -1000 Subject: [PATCH 67/70] Update src/diffusers/models/transformers/transformer_kandinsky.py Co-authored-by: Charles --- src/diffusers/models/transformers/transformer_kandinsky.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index a338922583ca..86032f5462d1 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -518,7 +518,10 @@ class Kandinsky5Transformer3DModel( """ A 3D Diffusion Transformer model for video-like data. """ - +_repeated_blocks = [ + "Kandinsky5TransformerEncoderBlock", + "Kandinsky5TransformerDecoderBlock", +] _supports_gradient_checkpointing = True @register_to_config From 5bb1657f9efb11d50d3c19cbe367e8086e15623a Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Oct 2025 05:25:17 +0200 Subject: [PATCH 68/70] more --- .../models/transformers/transformer_kandinsky.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_kandinsky.py b/src/diffusers/models/transformers/transformer_kandinsky.py index 86032f5462d1..d4ba92acaf6e 100644 --- a/src/diffusers/models/transformers/transformer_kandinsky.py +++ b/src/diffusers/models/transformers/transformer_kandinsky.py @@ -518,10 +518,11 @@ class Kandinsky5Transformer3DModel( """ A 3D Diffusion Transformer model for video-like data. 
""" -_repeated_blocks = [ - "Kandinsky5TransformerEncoderBlock", - "Kandinsky5TransformerDecoderBlock", -] + + _repeated_blocks = [ + "Kandinsky5TransformerEncoderBlock", + "Kandinsky5TransformerDecoderBlock", + ] _supports_gradient_checkpointing = True @register_to_config From a26300f7335613ae8eaf1ee082038de63dbddfa7 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Fri, 17 Oct 2025 17:32:19 -1000 Subject: [PATCH 69/70] Apply suggestions from code review --- src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 3eb706f238ad..a1122a82565e 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -618,7 +618,6 @@ def __call__( num_frames: int = 121, num_inference_steps: int = 50, guidance_scale: float = 5.0, - scheduler_scale: float = 5.0, num_videos_per_prompt: Optional[int] = 1, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, @@ -656,8 +655,6 @@ def __call__( The number of denoising steps. guidance_scale (`float`, defaults to `5.0`): Guidance scale as defined in classifier-free guidance. - scheduler_scale (`float`, defaults to `10.0`): - Scale factor for the custom flow matching scheduler. num_videos_per_prompt (`int`, *optional*, defaults to 1): The number of videos to generate per prompt. generator (`torch.Generator` or `List[torch.Generator]`, *optional*): From ecbe522399e61b61b2ff26658bd5090d849bb190 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Sat, 18 Oct 2025 05:37:42 +0200 Subject: [PATCH 70/70] add lora loader doc --- docs/source/en/api/loaders/lora.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/en/api/loaders/lora.md b/docs/source/en/api/loaders/lora.md index b1d1ffb63423..8e0326e0c334 100644 --- a/docs/source/en/api/loaders/lora.md +++ b/docs/source/en/api/loaders/lora.md @@ -107,6 +107,9 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi [[autodoc]] loaders.lora_pipeline.QwenImageLoraLoaderMixin +## KandinskyLoraLoaderMixin +[[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin + ## LoraBaseMixin [[autodoc]] loaders.lora_base.LoraBaseMixin \ No newline at end of file