From 43825c46c598d620e0bc117d2a64245a083b3080 Mon Sep 17 00:00:00 2001 From: csuhan Date: Fri, 24 Jan 2025 20:36:25 +0800 Subject: [PATCH 01/25] Add support for lumina2 --- src/diffusers/__init__.py | 4 + src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention_processor.py | 97 ++ src/diffusers/models/embeddings.py | 61 ++ src/diffusers/models/normalization.py | 28 +- .../transformers/transformer_lumina2.py | 415 +++++++++ src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/auto_pipeline.py | 2 + .../pipelines/lumina/pipeline_lumina.py | 2 +- src/diffusers/pipelines/lumina2/__init__.py | 48 + .../pipelines/lumina2/pipeline_lumina2.py | 825 ++++++++++++++++++ 11 files changed, 1474 insertions(+), 12 deletions(-) create mode 100644 src/diffusers/models/transformers/transformer_lumina2.py create mode 100644 src/diffusers/pipelines/lumina2/__init__.py create mode 100644 src/diffusers/pipelines/lumina2/pipeline_lumina2.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b1801fbb2b4b..c8438d8e9c93 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -110,6 +110,7 @@ "LatteTransformer3DModel", "LTXVideoTransformer3DModel", "LuminaNextDiT2DModel", + "Lumina2Transformer2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", @@ -329,6 +330,7 @@ "LTXImageToVideoPipeline", "LTXPipeline", "LuminaText2ImgPipeline", + "Lumina2Text2ImgPipeline", "MarigoldDepthPipeline", "MarigoldNormalsPipeline", "MochiPipeline", @@ -622,6 +624,7 @@ LatteTransformer3DModel, LTXVideoTransformer3DModel, LuminaNextDiT2DModel, + Lumina2Transformer2DModel, MochiTransformer3DModel, ModelMixin, MotionAdapter, @@ -820,6 +823,7 @@ LTXImageToVideoPipeline, LTXPipeline, LuminaText2ImgPipeline, + Lumina2Text2ImgPipeline, MarigoldDepthPipeline, MarigoldNormalsPipeline, MochiPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index e3f291ce2dc7..04618d9a9baf 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -60,6 +60,7 @@ _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"] _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"] + _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"] @@ -139,6 +140,7 @@ LatteTransformer3DModel, LTXVideoTransformer3DModel, LuminaNextDiT2DModel, + Lumina2Transformer2DModel, MochiTransformer3DModel, PixArtTransformer2DModel, PriorTransformer, diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 30e160dd2408..e2f6ae5b56b1 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -4192,6 +4192,102 @@ def __call__( return hidden_states +class Lumina2AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Lumina2Transformer2DModel model. It applies a s normalization layer and rotary embedding on query and key vector. 
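+    It also supports grouped-query attention (key/value heads are repeated to match the number of query heads)
+    and an optional proportional attention scale when `base_sequence_length` is passed.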
+ """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + query_rotary_emb: Optional[torch.Tensor] = None, + key_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + from embeddings import apply_rotary_emb + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key Norm if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if query_rotary_emb is not None: + query = apply_rotary_emb(query, query_rotary_emb, use_real=False) + if key_rotary_emb is not None: + key = apply_rotary_emb(key, key_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Apply proportional attention if true + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # perform Grouped-qurey Attention (GQA) + n_rep = attn.heads // kv_heads + if n_rep >= 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + + return hidden_states + + class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
This is @@ -6183,6 +6279,7 @@ def __call__( PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, + Lumina2AttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index bd3237c24c1c..3db870e07056 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -569,6 +569,35 @@ def forward(self, latent): return (latent + pos_embed).to(latent.dtype) +class Lumina2PosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512)): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.freqs_cis = self.precompute_freqs_cis(axes_dim, axes_lens, theta) + + def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: + freqs_cis = [] + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + emb = get_1d_rotary_pos_embed( + d, + e, + theta=self.theta, + freqs_dtype=torch.float64, + ) + freqs_cis.append(emb) + return freqs_cis + + def forward(self, ids: torch.Tensor) -> torch.Tensor: + result = [] + for i in range(len(self.axes_dim)): + freqs = self.freqs_cis[i].to(ids.device) + index = ids[:, :, i:i+1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1) + + class LuminaPatchEmbed(nn.Module): """ 2D Image to Patch Embedding with support for Lumina-T2X @@ -1766,6 +1795,38 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde return conditioning +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5): + super().__init__() + + from normalization import RMSNorm + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)) + + self.caption_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear( + cap_feat_dim, + hidden_size, + bias=True, + ), + ) + + def forward(self, timestep, caption_feat): + # timestep embedding: + time_freq = self.time_proj(timestep) + time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + + # caption condition embedding: + caption_embed = self.caption_embedder(caption_feat) + + return time_embed, caption_embed + + class LuminaCombinedTimestepCaptionEmbedding(nn.Module): def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256): super().__init__() diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 7db4d3d17d2f..68c90f52c2c2 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -211,15 +211,17 @@ class LuminaRMSNormZero(nn.Module): embedding_dim (`int`): The size of each embedding vector. 
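+        norm_eps (`float`): The epsilon value used by the RMS normalization.
+        modulation (`bool`):
+            Whether timestep modulation is applied. When `False`, only the normalization is performed and the
+            returned `gate_msa`, `scale_mlp` and `gate_mlp` are `None`.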
""" - def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool): + def __init__(self, embedding_dim: int, norm_eps: float, modulation: bool): super().__init__() - self.silu = nn.SiLU() - self.linear = nn.Linear( - min(embedding_dim, 1024), - 4 * embedding_dim, - bias=True, - ) - self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) + self.modulation = modulation + if modulation: + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) + self.norm = RMSNorm(embedding_dim, eps=norm_eps) def forward( self, @@ -227,9 +229,13 @@ def forward( emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # emb = self.emb(timestep, encoder_hidden_states, encoder_mask) - emb = self.linear(self.silu(emb)) - scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - x = self.norm(x) * (1 + scale_msa[:, None]) + if self.modulation: + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) + else: + gate_msa, scale_mlp, gate_mlp = None, None, None + x = self.norm(x) return x, gate_msa, scale_mlp, gate_mlp diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py new file mode 100644 index 000000000000..55087abcfccb --- /dev/null +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -0,0 +1,415 @@ +# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import logging +from ..attention import LuminaFeedForward +from ..attention_processor import Attention, Lumina2AttnProcessor2_0 +from ..embeddings import ( + Lumina2CombinedTimestepCaptionEmbedding, + Lumina2PosEmbed, +) +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class Lumina2TransformerBlock(nn.Module): + """ + A Lumina2TransformerBlock for Lumina2Transformer2DModel. + + Parameters: + dim (`int`): Embedding dimension of the input features. + num_attention_heads (`int`): Number of attention heads. + num_kv_heads (`int`): + Number of attention heads in key and value features (if using GQA), or set to None for the same as query. + multiple_of (`int`): The number of multiple of ffn layer. + ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension. + norm_eps (`float`): The eps for norm layer. + qk_norm (`bool`): normalization for query and key. + cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states. 
+ modulation (`bool`): Whether to use modulation. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + modulation: bool = True, + ) -> None: + super().__init__() + self.head_dim = dim // num_attention_heads + self.modulation = modulation + + # Self-attention + self.attn = Attention( + query_dim=dim, + cross_attention_dim=None, + dim_head=dim // num_attention_heads, + qk_norm="rms_norm", + heads=num_attention_heads, + kv_heads=num_kv_heads, + eps=1e-5, + bias=False, + out_bias=False, + processor=Lumina2AttnProcessor2_0(), + ) + + self.feed_forward = LuminaFeedForward( + dim=dim, + inner_dim=4 * dim, + multiple_of=multiple_of, + ffn_dim_multiplier=ffn_dim_multiplier, + ) + + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + modulation=modulation, + ) + self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) + + self.norm2 = RMSNorm(dim, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, eps=norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + image_rotary_emb: torch.Tensor, + temb: Optional[torch.Tensor] = None, + ): + """ + Perform a forward pass through the Lumina2TransformerBlock. + + Parameters: + hidden_states (`torch.Tensor`): The input of hidden_states for Lumina2TransformerBlock. + attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. + image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. + encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder. + encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask. + temb (`torch.Tensor`): Timestep embedding with text prompt embedding. + """ + + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + query_rotary_emb=image_rotary_emb, + key_rotary_emb=image_rotary_emb, + ) + + if self.modulation: + hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) + hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) + else: + hidden_states = hidden_states + self.norm2(attn_output) + mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) + hidden_states = hidden_states + self.ffn_norm2(mlp_output) + + return hidden_states + + +class Lumina2Transformer2DModel(ModelMixin, ConfigMixin): + """ + LuminaNextDiT: Diffusion model with a Transformer backbone. + + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters: + sample_size (`int`): The width of the latent images. This is fixed during training since + it is used to learn a number of position embeddings. + patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2): + The size of each patch in the image. This parameter defines the resolution of patches fed into the model. + in_channels (`int`, *optional*, defaults to 4): + The number of input channels for the model. Typically, this matches the number of channels in the input + images. + hidden_size (`int`, *optional*, defaults to 4096): + The dimensionality of the hidden layers in the model. 
This parameter determines the width of the model's + hidden representations. + num_layers (`int`, *optional*, default to 32): + The number of layers in the model. This defines the depth of the neural network. + num_attention_heads (`int`, *optional*, defaults to 32): + The number of attention heads in each attention layer. This parameter specifies how many separate attention + mechanisms are used. + num_kv_heads (`int`, *optional*, defaults to 8): + The number of key-value heads in the attention mechanism, if different from the number of attention heads. + If None, it defaults to num_attention_heads. + multiple_of (`int`, *optional*, defaults to 256): + A factor that the hidden size should be a multiple of. This can help optimize certain hardware + configurations. + ffn_dim_multiplier (`float`, *optional*): + A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on + the model configuration. + norm_eps (`float`, *optional*, defaults to 1e-5): + A small value added to the denominator for numerical stability in normalization layers. + scaling_factor (`float`, *optional*, defaults to 1.0): + A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the + overall scale of the model's operations. + """ + + @register_to_config + def __init__( + self, + sample_size: int = 128, + patch_size: Optional[int] = 2, + in_channels: Optional[int] = 16, + out_channels: Optional[int] = None, + hidden_size: Optional[int] = 2304, + num_layers: Optional[int] = 26, + num_refiner_layers: Optional[int] = 2, + num_attention_heads: Optional[int] = 24, + num_kv_heads: Optional[int] = 8, + multiple_of: Optional[int] = 256, + ffn_dim_multiplier: Optional[float] = None, + norm_eps: Optional[float] = 1e-5, + scaling_factor: Optional[float] = 1.0, + axes_dim_rope: Optional[tuple[int, int, int]] = (32, 32, 32), + axes_lens: Optional[tuple[int, int, int]] = (300, 512, 512), + cap_feat_dim: Optional[int] = 1024, + ) -> None: + super().__init__() + self.sample_size = sample_size + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels or in_channels + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.head_dim = hidden_size // num_attention_heads + self.scaling_factor = scaling_factor + + self.rope_embedder = Lumina2PosEmbed( + theta=10000, + axes_dim=axes_dim_rope, + axes_lens=axes_lens, + ) + self.x_embedder = nn.Linear( + in_features=patch_size * patch_size * in_channels, + out_features=hidden_size, + bias=True, + ) + + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( + hidden_size=hidden_size, + cap_feat_dim=cap_feat_dim, + norm_eps=norm_eps, + ) + + self.noise_refiner = nn.ModuleList( + [ + Lumina2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.context_refiner = nn.ModuleList( + [ + Lumina2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=False, + ) + for _ in range(num_refiner_layers) + ] + ) + + self.layers = nn.ModuleList( + [ + Lumina2TransformerBlock( + hidden_size, + num_attention_heads, + num_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + modulation=True, + ) + for _ in range(num_layers) + ] + ) + self.norm_out = LuminaLayerNormContinuous( + embedding_dim=hidden_size, + 
conditioning_embedding_dim=min(hidden_size, 1024), + elementwise_affine=False, + eps=1e-6, + bias=True, + out_dim=patch_size * patch_size * self.out_channels, + ) + # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels) + + assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" + + def patchify_and_embed( + self, + x: list[torch.Tensor] | torch.Tensor, + cap_feats: torch.Tensor, + cap_mask: torch.Tensor, + t: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], list[int], torch.Tensor]: + bsz = len(x) + pH = pW = self.patch_size + device = x[0].device + + l_effective_cap_len = cap_mask.sum(dim=1).tolist() + img_sizes = [(img.size(1), img.size(2)) for img in x] + l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + + max_seq_len = max( + (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) + ) + max_cap_len = max(l_effective_cap_len) + max_img_len = max(l_effective_img_len) + + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + H, W = img_sizes[i] + H_tokens, W_tokens = H // pH, W // pW + assert H_tokens * W_tokens == img_len + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len:cap_len+img_len, 0] = cap_len + row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + position_ids[i, cap_len:cap_len+img_len, 1] = row_ids + position_ids[i, cap_len:cap_len+img_len, 2] = col_ids + + freqs_cis = self.rope_embedder(position_ids) + + cap_freqs_cis_shape = list(freqs_cis.shape) + cap_freqs_cis_shape[1] = cap_feats.shape[1] + cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + img_freqs_cis_shape = list(freqs_cis.shape) + img_freqs_cis_shape[1] = max_img_len + img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + + flat_x = [] + for i in range(bsz): + img = x[i] + C, H, W = img.size() + img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) + flat_x.append(img) + x = flat_x + padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) + padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) + for i in range(bsz): + padded_img_embed[i, :l_effective_img_len[i]] = x[i] + padded_img_mask[i, :l_effective_img_len[i]] = True + + padded_img_embed = self.x_embedder(padded_img_embed) + for layer in self.noise_refiner: + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + + mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + padded_full_embed = torch.zeros(bsz, max_seq_len, self.hidden_size, device=device, dtype=x[0].dtype) + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + mask[i, :cap_len+img_len] = True + padded_full_embed[i, :cap_len] = 
cap_feats[i, :cap_len] + padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + + return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis + + + def forward( + self, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + encoder_hidden_states: torch.Tensor, + encoder_mask: torch.Tensor, + return_dict=True, + ) -> torch.Tensor: + """ + Forward pass of LuminaNextDiT. + + Parameters: + hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W). + timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,). + encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). + encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). + """ + + temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states) + hidden_states, mask, img_size, cap_size, image_rotary_emb = self.patchify_and_embed(hidden_states, encoder_hidden_states, encoder_mask, temb) + image_rotary_emb = image_rotary_emb.to(hidden_states.device) + + for layer in self.layers: + hidden_states = layer( + hidden_states, + mask, + image_rotary_emb, + temb=temb, + ) + + hidden_states = self.norm_out(hidden_states, temb) + + # unpatchify + height_tokens = width_tokens = self.patch_size + output = [] + for i in range(len(img_size)): + height, width = img_size[i] + begin = cap_size[i] + end = begin + (height // height_tokens) * (width // width_tokens) + output.append( + hidden_states[i][begin:end] + .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels) + .permute(4, 0, 2, 1, 3) + .flatten(3, 4) + .flatten(1, 2) + ) + output = torch.stack(output, dim=0) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 5829cf495dcc..e82d92e70e8b 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -256,6 +256,7 @@ _import_structure["latte"] = ["LattePipeline"] _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"] _import_structure["lumina"] = ["LuminaText2ImgPipeline"] + _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( [ "MarigoldDepthPipeline", @@ -596,6 +597,7 @@ ) from .ltx import LTXImageToVideoPipeline, LTXPipeline from .lumina import LuminaText2ImgPipeline + from .lumina2 import Lumina2Text2ImgPipeline from .marigold import ( MarigoldDepthPipeline, MarigoldNormalsPipeline, diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py index a19329431b05..6066836e7a05 100644 --- a/src/diffusers/pipelines/auto_pipeline.py +++ b/src/diffusers/pipelines/auto_pipeline.py @@ -65,6 +65,7 @@ from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline from .lumina import LuminaText2ImgPipeline +from .lumina2 import Lumina2Text2ImgPipeline from .pag import ( HunyuanDiTPAGPipeline, PixArtSigmaPAGPipeline, @@ -135,6 +136,7 @@ ("flux-control", FluxControlPipeline), ("flux-controlnet", FluxControlNetPipeline), ("lumina", LuminaText2ImgPipeline), + ("lumina2", Lumina2Text2ImgPipeline), ("cogview3", CogView3PlusPipeline), ] ) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index 5b37e9a503a8..ac94312bf786 100644 --- 
a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -17,7 +17,7 @@ import math import re import urllib.parse as ul -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Callable, Dict import torch from transformers import AutoModel, AutoTokenizer diff --git a/src/diffusers/pipelines/lumina2/__init__.py b/src/diffusers/pipelines/lumina2/__init__.py new file mode 100644 index 000000000000..0e51a768a785 --- /dev/null +++ b/src/diffusers/pipelines/lumina2/__init__.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["pipeline_lumina2"] = ["Lumina2Text2ImgPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + else: + from .pipeline_lumina2 import Lumina2Text2ImgPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py new file mode 100644 index 000000000000..ce4eb36702fc --- /dev/null +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -0,0 +1,825 @@ +# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import html +import inspect +import math +import numpy as np +import re +import urllib.parse as ul +from typing import List, Optional, Tuple, Union, Callable, Dict + +import torch +from transformers import AutoModel, AutoTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import AutoencoderKL +from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import ( + BACKENDS_MAPPING, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Lumina2Text2ImgPipeline + + >>> pipe = Lumina2Text2ImgPipeline.from_pretrained( + ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 + ... ) + >>> # Enable memory optimizations. + >>> pipe.enable_model_cpu_offload() + + >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures" + >>> image = pipe(prompt).images[0] + ``` +""" + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class Lumina2Text2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Lumina-T2I. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`AutoModel`]): + Frozen text-encoder. Lumina-T2I uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant. + tokenizer (`AutoModel`): + Tokenizer of class + [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel). + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
+ """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = [] + _callback_tensor_inputs = ["latents", "prompt_embeds"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + transformer: Lumina2Transformer2DModel, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKL, + text_encoder: AutoModel, + tokenizer: AutoTokenizer, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 8 + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + self.max_sequence_length = 256 + self.default_sample_size = ( + self.transformer.config.sample_size + if hasattr(self, "transformer") and self.transformer is not None + else 128 + ) + self.default_image_size = self.default_sample_size * self.vae_scale_factor + self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts." + + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + max_length: Optional[int] = None, + ): + device = device or self._execution_device + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = self.tokenizer( + prompt, + pad_to_multiple_of=8, + max_length=self.max_sequence_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because Gemma can only handle sequences up to" + f" {self.max_sequence_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = self.text_encoder( + text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True + ) + prompt_embeds = prompt_embeds.hidden_states[-2] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, prompt_attention_mask + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: Union[str, List[str]] = None, + num_images_per_prompt: int = 1, + device: 
Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + system_prompt: Optional[str] = None, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + Lumina-T2I, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. + max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + """ + if device is None: + device = self._execution_device + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if system_prompt is None: + system_prompt = self.system_prompt + prompt = [system_prompt + ' ' + p for p in prompt] + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + ) + + # Get negative embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt if negative_prompt is not None else "" + + # Normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." 
+ ) + # Padding negative prompt to the same length with prompt + prompt_max_length = prompt_embeds.shape[1] + negative_text_inputs = self.tokenizer( + negative_prompt, + padding="max_length", + max_length=prompt_max_length, + truncation=True, + return_tensors="pt", + ) + negative_text_input_ids = negative_text_inputs.input_ids.to(device) + negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device) + # Get the negative prompt embeddings + negative_prompt_embeds = self.text_encoder( + negative_text_input_ids, + attention_mask=negative_prompt_attention_mask, + output_hidden_states=True, + ) + + negative_dtype = self.text_encoder.dtype + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + _, seq_len, _ = negative_prompt_embeds.shape + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device) + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + callback_on_step_end_tensor_inputs=None, + max_sequence_length=None, + ): + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. 
Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.") + + if negative_prompt_embeds is not None and negative_prompt_attention_mask is None: + raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.") + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + if max_sequence_length is not None and max_sequence_length > 512: + raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
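+        # For example, with the vae_scale_factor of 8 set in __init__, a requested height of 1024 px becomes
+        # 2 * (1024 // 16) = 128 latent rows, while 1000 px is rounded down to 2 * (1000 // 16) = 124 rows
+        # (i.e. an effective 992 px).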
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + + shape = (batch_size, num_channels_latents, height, width) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + width: Optional[int] = None, + height: Optional[int] = None, + num_inference_steps: int = 30, + guidance_scale: float = 4.0, + negative_prompt: Union[str, List[str]] = None, + sigmas: List[float] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + max_sequence_length: int = 256, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + system_prompt: Optional[str] = None, + cfg_trunc_ratio: float = 1.0, + cfg_normalization: bool = True, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 30): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + max_sequence_length (`int` defaults to 120): + Maximum sequence length to use with the `prompt`. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + system_prompt (`str`, *optional*): + The system prompt to use for the image generation. + cfg_trunc_ratio (`float`, *optional*, defaults to 1.0): + The ratio of the timestep interval to apply normalization-based guidance scale. + cfg_normalization (`bool`, *optional*, defaults to True): + Whether to apply normalization-based guidance scale. 
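+                When enabled, the guided noise prediction is rescaled so that its norm matches the norm of the
+                conditional prediction. Both the guidance and this normalization are only applied while
+                `1 - t / num_train_timesteps` is below `cfg_trunc_ratio`.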
+ + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + scaling_factor = math.sqrt(width * height / self.default_image_size**2) + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + self.tokenizer.padding_side = "right" + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + max_sequence_length=max_sequence_length, + system_prompt=system_prompt, + ) + if do_classifier_free_guidance: + prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) + prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0) + + # 4. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + image_seq_len = latents.shape[1] + mu = calculate_shift( + image_seq_len, + self.scheduler.config.get("base_image_seq_len", 256), + self.scheduler.config.get("max_image_seq_len", 4096), + self.scheduler.config.get("base_shift", 0.5), + self.scheduler.config.get("max_shift", 1.16), + ) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 6. 
Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance and 1 - t.item() / self.scheduler.config.num_train_timesteps < cfg_trunc_ratio else latents + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor( + [current_timestep], + dtype=dtype, + device=latent_model_input.device, + ) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latent_model_input.shape[0]) + + # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image + current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps + + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=current_timestep, + encoder_hidden_states=prompt_embeds, + encoder_mask=prompt_attention_mask, + return_dict=False, + )[0] + + # perform normalization-based guidance scale on a truncated timestep interval + if do_classifier_free_guidance and current_timestep[0] < cfg_trunc_ratio: + noise_pred_cond, noise_pred_uncond = torch.split( + noise_pred, len(noise_pred) // 2, dim=0 + ) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + # apply normalization after classifier-free guidance + if cfg_normalization: + cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) + noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) + noise_pred = noise_pred * (cond_norm / noise_norm) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + noise_pred = -noise_pred + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + progress_bar.update() + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if not output_type == "latent": + latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) \ No newline at end of file From 81f47df6f2c079ebf1c73ae55620782fd9c5c767 Mon Sep 17 00:00:00 2001 From: Le Zhuo <53815869+zhuole1025@users.noreply.github.com> Date: Sat, 25 Jan 2025 10:34:36 +0800 Subject: [PATCH 02/25] Update src/diffusers/models/transformers/transformer_lumina2.py Co-authored-by: YiYi Xu --- src/diffusers/models/transformers/transformer_lumina2.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 55087abcfccb..cf7fbf34ec5f 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -85,11 +85,7 @@ def __init__( ffn_dim_multiplier=ffn_dim_multiplier, ) - self.norm1 = LuminaRMSNormZero( - embedding_dim=dim, - norm_eps=norm_eps, - modulation=modulation, - ) + self.norm1 = LuminaRMSNormZero() if modulation else RMSNorm(...) self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) self.norm2 = RMSNorm(dim, eps=norm_eps) From 66ef8d700ba17276efc8aa4832bbc38d3ed10312 Mon Sep 17 00:00:00 2001 From: csuhan Date: Sat, 25 Jan 2025 19:55:43 +0800 Subject: [PATCH 03/25] fix LuminaRMSNormZero --- src/diffusers/models/normalization.py | 27 +++++++------------ .../transformers/transformer_lumina2.py | 12 +++++---- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 68c90f52c2c2..43db18233cc3 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -211,16 +211,14 @@ class LuminaRMSNormZero(nn.Module): embedding_dim (`int`): The size of each embedding vector. 
""" - def __init__(self, embedding_dim: int, norm_eps: float, modulation: bool): + def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool): super().__init__() - self.modulation = modulation - if modulation: - self.silu = nn.SiLU() - self.linear = nn.Linear( - min(embedding_dim, 1024), - 4 * embedding_dim, - bias=True, - ) + self.silu = nn.SiLU() + self.linear = nn.Linear( + min(embedding_dim, 1024), + 4 * embedding_dim, + bias=True, + ) self.norm = RMSNorm(embedding_dim, eps=norm_eps) def forward( @@ -228,14 +226,9 @@ def forward( x: torch.Tensor, emb: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - # emb = self.emb(timestep, encoder_hidden_states, encoder_mask) - if self.modulation: - emb = self.linear(self.silu(emb)) - scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) - x = self.norm(x) * (1 + scale_msa[:, None]) - else: - gate_msa, scale_mlp, gate_mlp = None, None, None - x = self.norm(x) + emb = self.linear(self.silu(emb)) + scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1) + x = self.norm(x) * (1 + scale_msa[:, None]) return x, gate_msa, scale_mlp, gate_mlp diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 55087abcfccb..9b3ce467bd51 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -85,11 +85,13 @@ def __init__( ffn_dim_multiplier=ffn_dim_multiplier, ) - self.norm1 = LuminaRMSNormZero( - embedding_dim=dim, - norm_eps=norm_eps, - modulation=modulation, - ) + if modulation: + self.norm1 = LuminaRMSNormZero( + embedding_dim=dim, + norm_eps=norm_eps, + ) + else: + self.norm1 = RMSNorm(dim, eps=norm_eps) self.ffn_norm1 = RMSNorm(dim, eps=norm_eps) self.norm2 = RMSNorm(dim, eps=norm_eps) From f64f7e225f5204fc75c60b72d5357245e5396178 Mon Sep 17 00:00:00 2001 From: csuhan Date: Wed, 29 Jan 2025 18:35:16 +0800 Subject: [PATCH 04/25] refactor refiner --- .../transformers/transformer_lumina2.py | 32 +++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index dce9ebf843c8..3370d70fc493 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -332,9 +332,6 @@ def patchify_and_embed( cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] - for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) - flat_x = [] for i in range(bsz): img = x[i] @@ -349,9 +346,19 @@ def patchify_and_embed( padded_img_mask[i, :l_effective_img_len[i]] = True padded_img_embed = self.x_embedder(padded_img_embed) - for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) - + + return cap_feats, padded_img_embed, img_sizes, l_effective_cap_len, l_effective_img_len, freqs_cis, max_seq_len, cap_mask, padded_img_mask + + def prepare_joint_input( + self, + cap_feats: torch.Tensor, + padded_img_embed: torch.Tensor, + max_seq_len: int, + l_effective_cap_len: list[int], + l_effective_img_len: list[int], + ) -> torch.Tensor: + bsz = cap_feats.size(0) + device = cap_feats.device mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) padded_full_embed = torch.zeros(bsz, 
max_seq_len, self.hidden_size, device=device, dtype=x[0].dtype) for i in range(bsz): @@ -361,8 +368,7 @@ def patchify_and_embed( padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] - return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis - + return padded_full_embed, mask def forward( self, @@ -383,8 +389,16 @@ def forward( """ temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states) - hidden_states, mask, img_size, cap_size, image_rotary_emb = self.patchify_and_embed(hidden_states, encoder_hidden_states, encoder_mask, temb) + cap_feats, padded_img_embed, img_size, l_effective_cap_len, l_effective_img_len, image_rotary_emb, max_seq_len, cap_mask, padded_img_mask = self.patchify_and_embed(hidden_states, encoder_hidden_states, encoder_mask, temb) image_rotary_emb = image_rotary_emb.to(hidden_states.device) + + for layer in self.context_refiner: + cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + + for layer in self.noise_refiner: + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + + hidden_states, mask = self.prepare_joint_input(cap_feats, padded_img_embed, max_seq_len, l_effective_cap_len, l_effective_img_len) for layer in self.layers: hidden_states = layer( From e26faef8230d01954ffcfeca5f58758340b9bf14 Mon Sep 17 00:00:00 2001 From: csuhan Date: Sun, 2 Feb 2025 17:23:41 +0800 Subject: [PATCH 05/25] fix bugs and reformat rope into class --- src/diffusers/models/attention.py | 1 - src/diffusers/models/embeddings.py | 29 -- .../models/transformers/lumina_nextdit2d.py | 2 +- .../transformers/transformer_lumina2.py | 251 ++++++++++-------- 4 files changed, 136 insertions(+), 147 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 4d1dae879f11..93b11c2b43f0 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -612,7 +612,6 @@ def __init__( ffn_dim_multiplier: Optional[float] = None, ): super().__init__() - inner_dim = int(2 * inner_dim / 3) # custom hidden_size factor multiplier if ffn_dim_multiplier is not None: inner_dim = int(ffn_dim_multiplier * inner_dim) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 3db870e07056..afc1a170bd2d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -569,35 +569,6 @@ def forward(self, latent): return (latent + pos_embed).to(latent.dtype) -class Lumina2PosEmbed(nn.Module): - def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512)): - super().__init__() - self.theta = theta - self.axes_dim = axes_dim - self.axes_lens = axes_lens - self.freqs_cis = self.precompute_freqs_cis(axes_dim, axes_lens, theta) - - def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: - freqs_cis = [] - for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): - emb = get_1d_rotary_pos_embed( - d, - e, - theta=self.theta, - freqs_dtype=torch.float64, - ) - freqs_cis.append(emb) - return freqs_cis - - def forward(self, ids: torch.Tensor) -> torch.Tensor: - result = [] - for i in range(len(self.axes_dim)): - freqs = self.freqs_cis[i].to(ids.device) - index = ids[:, :, i:i+1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) - result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) - return torch.cat(result, dim=-1) - - class 
LuminaPatchEmbed(nn.Module): """ 2D Image to Patch Embedding with support for Lumina-T2X diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py index fb2b3815bcd5..320950866c4a 100644 --- a/src/diffusers/models/transformers/lumina_nextdit2d.py +++ b/src/diffusers/models/transformers/lumina_nextdit2d.py @@ -98,7 +98,7 @@ def __init__( self.feed_forward = LuminaFeedForward( dim=dim, - inner_dim=4 * dim, + inner_dim=int(4 * 2 * dim / 3), multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, ) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 3370d70fc493..915aae615aff 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List import torch import torch.nn as nn @@ -21,10 +21,7 @@ from ...utils import logging from ..attention import LuminaFeedForward from ..attention_processor import Attention, Lumina2AttnProcessor2_0 -from ..embeddings import ( - Lumina2CombinedTimestepCaptionEmbedding, - Lumina2PosEmbed, -) +from ..embeddings import Lumina2CombinedTimestepCaptionEmbedding, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm @@ -102,7 +99,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, - image_rotary_emb: torch.Tensor, + rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, ): """ @@ -111,31 +108,134 @@ def forward( Parameters: hidden_states (`torch.Tensor`): The input of hidden_states for Lumina2TransformerBlock. attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. - image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. + rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder. encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask. temb (`torch.Tensor`): Timestep embedding with text prompt embedding. 
""" - - norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) - attn_output = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_hidden_states, - attention_mask=attention_mask, - query_rotary_emb=image_rotary_emb, - key_rotary_emb=image_rotary_emb, - ) - if self.modulation: + norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + query_rotary_emb=rotary_emb, + key_rotary_emb=rotary_emb, + ) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output) else: + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_hidden_states, + attention_mask=attention_mask, + query_rotary_emb=rotary_emb, + key_rotary_emb=rotary_emb, + ) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) hidden_states = hidden_states + self.ffn_norm2(mlp_output) return hidden_states + + +class Lumina2PosEmbed(nn.Module): + def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2): + super().__init__() + self.theta = theta + self.axes_dim = axes_dim + self.axes_lens = axes_lens + self.patch_size = patch_size + self.freqs_cis = self.precompute_freqs_cis(axes_dim, axes_lens, theta) + + def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: + freqs_cis = [] + for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): + emb = get_1d_rotary_pos_embed( + d, + e, + theta=self.theta, + freqs_dtype=torch.float64, + ) + freqs_cis.append(emb) + return freqs_cis + + def get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: + result = [] + for i in range(len(self.axes_dim)): + freqs = self.freqs_cis[i].to(ids.device) + index = ids[:, :, i:i+1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) + return torch.cat(result, dim=-1) + + def forward( + self, + x: list[torch.Tensor] | torch.Tensor, + cap_mask: torch.Tensor, + t: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], list[int], torch.Tensor]: + bsz = len(x) + pH = pW = self.patch_size + device = x[0].device + + l_effective_cap_len = cap_mask.sum(dim=1).tolist() + img_sizes = [(img.size(1), img.size(2)) for img in x] + l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + + max_seq_len = max( + (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) + ) + max_cap_len = max(l_effective_cap_len) + max_img_len = max(l_effective_img_len) + + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + H, W = img_sizes[i] + H_tokens, W_tokens = H // pH, W // pW + assert H_tokens * W_tokens == img_len + + position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) + position_ids[i, cap_len:cap_len+img_len, 0] = cap_len + row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 
1).repeat(1, W_tokens).flatten() + col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + position_ids[i, cap_len:cap_len+img_len, 1] = row_ids + position_ids[i, cap_len:cap_len+img_len, 2] = col_ids + + freqs_cis = self.get_freqs_cis(position_ids) + + cap_freqs_cis_shape = list(freqs_cis.shape) + cap_freqs_cis_shape[1] = cap_mask.shape[1] + cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + img_freqs_cis_shape = list(freqs_cis.shape) + img_freqs_cis_shape[1] = max_img_len + img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) + + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + + flat_x = [] + for i in range(bsz): + img = x[i] + C, H, W = img.size() + img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) + flat_x.append(img) + x = flat_x + padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) + padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) + for i in range(bsz): + padded_img_embed[i, :l_effective_img_len[i]] = x[i] + padded_img_mask[i, :l_effective_img_len[i]] = True + + return padded_img_embed, padded_img_mask, img_sizes, l_effective_cap_len, l_effective_img_len, freqs_cis, cap_freqs_cis, img_freqs_cis, max_seq_len class Lumina2Transformer2DModel(ModelMixin, ConfigMixin): @@ -210,6 +310,7 @@ def __init__( theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, + patch_size=patch_size, ) self.x_embedder = nn.Linear( in_features=patch_size * patch_size * in_channels, @@ -278,97 +379,6 @@ def __init__( # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels) assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" - - def patchify_and_embed( - self, - x: list[torch.Tensor] | torch.Tensor, - cap_feats: torch.Tensor, - cap_mask: torch.Tensor, - t: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], list[int], torch.Tensor]: - bsz = len(x) - pH = pW = self.patch_size - device = x[0].device - - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] - - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) - - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len - - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids - - freqs_cis = self.rope_embedder(position_ids) - - cap_freqs_cis_shape = 
list(freqs_cis.shape) - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] - - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) - for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, :l_effective_img_len[i]] = True - - padded_img_embed = self.x_embedder(padded_img_embed) - - return cap_feats, padded_img_embed, img_sizes, l_effective_cap_len, l_effective_img_len, freqs_cis, max_seq_len, cap_mask, padded_img_mask - - def prepare_joint_input( - self, - cap_feats: torch.Tensor, - padded_img_embed: torch.Tensor, - max_seq_len: int, - l_effective_cap_len: list[int], - l_effective_img_len: list[int], - ) -> torch.Tensor: - bsz = cap_feats.size(0) - device = cap_feats.device - mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) - padded_full_embed = torch.zeros(bsz, max_seq_len, self.hidden_size, device=device, dtype=x[0].dtype) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - mask[i, :cap_len+img_len] = True - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] - - return padded_full_embed, mask def forward( self, @@ -387,24 +397,33 @@ def forward( encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). 
""" - + bsz = hidden_states.size(0) + device = hidden_states.device temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states) - cap_feats, padded_img_embed, img_size, l_effective_cap_len, l_effective_img_len, image_rotary_emb, max_seq_len, cap_mask, padded_img_mask = self.patchify_and_embed(hidden_states, encoder_hidden_states, encoder_mask, temb) - image_rotary_emb = image_rotary_emb.to(hidden_states.device) + padded_img_embed, padded_img_mask, img_size, l_effective_cap_len, l_effective_img_len, joint_rotary_emb, cap_rotary_emb, img_rotary_emb, max_seq_len = self.rope_embedder(hidden_states, encoder_mask, temb) + padded_img_embed = self.x_embedder(padded_img_embed) for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis) + encoder_hidden_states = layer(encoder_hidden_states, encoder_mask, cap_rotary_emb) for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t) + padded_img_embed = layer(padded_img_embed, padded_img_mask, img_rotary_emb, temb) - hidden_states, mask = self.prepare_joint_input(cap_feats, padded_img_embed, max_seq_len, l_effective_cap_len, l_effective_img_len) + + mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) + hidden_states = torch.zeros(bsz, max_seq_len, self.hidden_size, device=device, dtype=hidden_states.dtype) + for i in range(bsz): + cap_len = l_effective_cap_len[i] + img_len = l_effective_img_len[i] + mask[i, :cap_len+img_len] = True + hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] + hidden_states[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] for layer in self.layers: hidden_states = layer( hidden_states, mask, - image_rotary_emb, + joint_rotary_emb, temb=temb, ) @@ -415,7 +434,7 @@ def forward( output = [] for i in range(len(img_size)): height, width = img_size[i] - begin = cap_size[i] + begin = l_effective_cap_len[i] end = begin + (height // height_tokens) * (width // width_tokens) output.append( hidden_states[i][begin:end] From 1ff61fdab3497db68917ec66775ee14911658183 Mon Sep 17 00:00:00 2001 From: csuhan Date: Mon, 3 Feb 2025 18:41:59 +0800 Subject: [PATCH 06/25] fix import normalization --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index afc1a170bd2d..b283c066bdfc 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1770,7 +1770,7 @@ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5): super().__init__() - from normalization import RMSNorm + from .normalization import RMSNorm self.time_proj = Timesteps( num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 From db6fcf1c702aa4f5be79cd1a7f3ffa0bc35b35e1 Mon Sep 17 00:00:00 2001 From: csuhan Date: Mon, 3 Feb 2025 19:05:35 +0800 Subject: [PATCH 07/25] fix relateive import --- src/diffusers/models/attention_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index e2f6ae5b56b1..5c3217125666 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -4212,7 +4212,7 @@ def __call__( key_rotary_emb: Optional[torch.Tensor] = None, base_sequence_length: Optional[int] = None, ) -> 
torch.Tensor: - from embeddings import apply_rotary_emb + from .embeddings import apply_rotary_emb input_ndim = hidden_states.ndim From 03ce67445ea88fdb89e58f45ea3b1d7039687692 Mon Sep 17 00:00:00 2001 From: csuhan Date: Wed, 5 Feb 2025 13:27:35 +0800 Subject: [PATCH 08/25] add a lot of changes --- src/diffusers/models/attention_processor.py | 97 ------- src/diffusers/models/embeddings.py | 32 --- .../transformers/transformer_lumina2.py | 258 +++++++++++++----- .../pipelines/lumina2/pipeline_lumina2.py | 66 ++--- 4 files changed, 223 insertions(+), 230 deletions(-) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 5c3217125666..30e160dd2408 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -4192,102 +4192,6 @@ def __call__( return hidden_states -class Lumina2AttnProcessor2_0: - r""" - Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the Lumina2Transformer2DModel model. It applies a s normalization layer and rotary embedding on query and key vector. - """ - - def __init__(self): - if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") - - def __call__( - self, - attn: Attention, - hidden_states: torch.Tensor, - encoder_hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - query_rotary_emb: Optional[torch.Tensor] = None, - key_rotary_emb: Optional[torch.Tensor] = None, - base_sequence_length: Optional[int] = None, - ) -> torch.Tensor: - from .embeddings import apply_rotary_emb - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = hidden_states.shape - - # Get Query-Key-Value Pair - query = attn.to_q(hidden_states) - key = attn.to_k(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) - - query_dim = query.shape[-1] - inner_dim = key.shape[-1] - head_dim = query_dim // attn.heads - dtype = query.dtype - - # Get key-value heads - kv_heads = inner_dim // head_dim - - query = query.view(batch_size, -1, attn.heads, head_dim) - key = key.view(batch_size, -1, kv_heads, head_dim) - value = value.view(batch_size, -1, kv_heads, head_dim) - - # Apply Query-Key Norm if needed - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # Apply RoPE if needed - if query_rotary_emb is not None: - query = apply_rotary_emb(query, query_rotary_emb, use_real=False) - if key_rotary_emb is not None: - key = apply_rotary_emb(key, key_rotary_emb, use_real=False) - - query, key = query.to(dtype), key.to(dtype) - - # Apply proportional attention if true - if base_sequence_length is not None: - softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale - else: - softmax_scale = attn.scale - - # perform Grouped-qurey Attention (GQA) - n_rep = attn.heads // kv_heads - if n_rep >= 1: - key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) - - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) - attention_mask = 
attention_mask.expand(-1, attn.heads, sequence_length, -1) - - query = query.transpose(1, 2) - key = key.transpose(1, 2) - value = value.transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, scale=softmax_scale - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(dtype) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - - return hidden_states - - class LuminaAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is @@ -6279,7 +6183,6 @@ def __call__( PAGHunyuanAttnProcessor2_0, PAGCFGHunyuanAttnProcessor2_0, LuminaAttnProcessor2_0, - Lumina2AttnProcessor2_0, FusedAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0, diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b283c066bdfc..bd3237c24c1c 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -1766,38 +1766,6 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde return conditioning -class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): - def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5): - super().__init__() - - from .normalization import RMSNorm - - self.time_proj = Timesteps( - num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 - ) - - self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)) - - self.caption_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear( - cap_feat_dim, - hidden_size, - bias=True, - ), - ) - - def forward(self, timestep, caption_feat): - # timestep embedding: - time_freq = self.time_proj(timestep) - time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) - - # caption condition embedding: - caption_embed = self.caption_embedder(caption_feat) - - return time_embed, caption_embed - - class LuminaCombinedTimestepCaptionEmbedding(nn.Module): def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256): super().__init__() diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 915aae615aff..6729b9705862 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -16,12 +16,14 @@ import torch import torch.nn as nn +import torch.nn.functional as F +from ...loaders import PeftAdapterMixin from ...configuration_utils import ConfigMixin, register_to_config from ...utils import logging from ..attention import LuminaFeedForward -from ..attention_processor import Attention, Lumina2AttnProcessor2_0 -from ..embeddings import Lumina2CombinedTimestepCaptionEmbedding, get_1d_rotary_pos_embed +from ..attention_processor import Attention +from ..embeddings import get_1d_rotary_pos_embed, apply_rotary_emb, Timesteps, TimestepEmbedding from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm @@ -30,6 +32,124 @@ 
logger = logging.get_logger(__name__) # pylint: disable=invalid-name +class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): + def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5): + super().__init__() + + self.time_proj = Timesteps( + num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 + ) + + self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)) + + self.caption_embedder = nn.Sequential( + RMSNorm(cap_feat_dim, eps=norm_eps), + nn.Linear( + cap_feat_dim, + hidden_size, + bias=True, + ), + ) + + def forward(self, timestep, caption_feat): + # timestep embedding: + time_freq = self.time_proj(timestep) + time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) + + # caption condition embedding: + caption_embed = self.caption_embedder(caption_feat) + + return time_embed, caption_embed + + +class Lumina2AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Lumina2Transformer2DModel model. It applies a s normalization layer and rotary embedding on query and key vector. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + base_sequence_length: Optional[int] = None, + ) -> torch.Tensor: + + input_ndim = hidden_states.ndim + batch_size, sequence_length, _ = hidden_states.shape + + # Get Query-Key-Value Pair + query = attn.to_q(hidden_states) + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query_dim = query.shape[-1] + inner_dim = key.shape[-1] + head_dim = query_dim // attn.heads + dtype = query.dtype + + # Get key-value heads + kv_heads = inner_dim // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim) + key = key.view(batch_size, -1, kv_heads, head_dim) + value = value.view(batch_size, -1, kv_heads, head_dim) + + # Apply Query-Key Norm if needed + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if image_rotary_emb is not None: + query = apply_rotary_emb(query, image_rotary_emb, use_real=False) + key = apply_rotary_emb(key, image_rotary_emb, use_real=False) + + query, key = query.to(dtype), key.to(dtype) + + # Apply proportional attention if true + if base_sequence_length is not None: + softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale + else: + softmax_scale = attn.scale + + # perform Grouped-qurey Attention (GQA) + n_rep = attn.heads // kv_heads + if n_rep >= 1: + key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3) + + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + # the output of sdp = 
(batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, scale=softmax_scale + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + + return hidden_states + + class Lumina2TransformerBlock(nn.Module): """ A Lumina2TransformerBlock for Lumina2Transformer2DModel. @@ -99,7 +219,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, - rotary_emb: torch.Tensor, + image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, ): """ @@ -108,7 +228,7 @@ def forward( Parameters: hidden_states (`torch.Tensor`): The input of hidden_states for Lumina2TransformerBlock. attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. - rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. + image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder. encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask. temb (`torch.Tensor`): Timestep embedding with text prompt embedding. @@ -119,8 +239,7 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, - query_rotary_emb=rotary_emb, - key_rotary_emb=rotary_emb, + image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1))) @@ -131,8 +250,7 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_hidden_states, attention_mask=attention_mask, - query_rotary_emb=rotary_emb, - key_rotary_emb=rotary_emb, + image_rotary_emb=image_rotary_emb, ) hidden_states = hidden_states + self.norm2(attn_output) mlp_output = self.feed_forward(self.ffn_norm1(hidden_states)) @@ -141,7 +259,7 @@ def forward( return hidden_states -class Lumina2PosEmbed(nn.Module): +class Lumina2RotaryPosEmbed(nn.Module): def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2): super().__init__() self.theta = theta @@ -172,16 +290,15 @@ def get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: def forward( self, - x: list[torch.Tensor] | torch.Tensor, - cap_mask: torch.Tensor, - t: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor, list[tuple[int, int]], list[int], torch.Tensor]: - bsz = len(x) + hidden_states: torch.Tensor, + encoder_mask: torch.Tensor, + ): + bsz = len(hidden_states) pH = pW = self.patch_size - device = x[0].device + device = hidden_states[0].device - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - img_sizes = [(img.size(1), img.size(2)) for img in x] + l_effective_cap_len = encoder_mask.sum(dim=1).tolist() + img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] max_seq_len = max( @@ -209,7 +326,7 @@ def forward( freqs_cis = self.get_freqs_cis(position_ids) cap_freqs_cis_shape = list(freqs_cis.shape) - cap_freqs_cis_shape[1] = cap_mask.shape[1] + cap_freqs_cis_shape[1] = encoder_mask.shape[1] cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) 
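        # Editorial note: `freqs_cis` above is gathered per sample from the three 1D frequency
        # tables via `position_ids` (axis 0: caption token index, held at `cap_len` for image
        # tokens; axes 1/2: image token row/column). `cap_freqs_cis` and the `img_freqs_cis`
        # allocated below hold its caption-only and image-only slices, which are later consumed
        # by the context refiner and noise refiner blocks respectively.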
img_freqs_cis_shape = list(freqs_cis.shape) @@ -222,27 +339,25 @@ def forward( cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] - flat_x = [] + flat_hidden_states = [] for i in range(bsz): - img = x[i] + img = hidden_states[i] C, H, W = img.size() img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) + flat_hidden_states.append(img) + hidden_states = flat_hidden_states + padded_img_embed = torch.zeros(bsz, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype) padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] + padded_img_embed[i, :l_effective_img_len[i]] = hidden_states[i] padded_img_mask[i, :l_effective_img_len[i]] = True return padded_img_embed, padded_img_mask, img_sizes, l_effective_cap_len, l_effective_img_len, freqs_cis, cap_freqs_cis, img_freqs_cis, max_seq_len -class Lumina2Transformer2DModel(ModelMixin, ConfigMixin): +class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ - LuminaNextDiT: Diffusion model with a Transformer backbone. - - Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + Lumina2NextDiT: Diffusion model with a Transformer backbone. Parameters: sample_size (`int`): The width of the latent images. This is fixed during training since @@ -276,6 +391,8 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin): overall scale of the model's operations. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -297,16 +414,9 @@ def __init__( cap_feat_dim: Optional[int] = 1024, ) -> None: super().__init__() - self.sample_size = sample_size - self.patch_size = patch_size - self.in_channels = in_channels self.out_channels = out_channels or in_channels - self.hidden_size = hidden_size - self.num_attention_heads = num_attention_heads - self.head_dim = hidden_size // num_attention_heads - self.scaling_factor = scaling_factor - self.rope_embedder = Lumina2PosEmbed( + self.rope_embedder = Lumina2RotaryPosEmbed( theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, @@ -376,9 +486,12 @@ def __init__( bias=True, out_dim=patch_size * patch_size * self.out_channels, ) - # self.final_layer = LuminaFinalLayer(hidden_size, patch_size, self.out_channels) - - assert (hidden_size // num_attention_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4" + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value def forward( self, @@ -400,41 +513,62 @@ def forward( bsz = hidden_states.size(0) device = hidden_states.device temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states) - padded_img_embed, padded_img_mask, img_size, l_effective_cap_len, l_effective_img_len, joint_rotary_emb, cap_rotary_emb, img_rotary_emb, max_seq_len = self.rope_embedder(hidden_states, encoder_mask, temb) + hidden_states, hidden_mask, hidden_sizes, encoder_hidden_len, hidden_len, joint_rotary_emb, encoder_rotary_emb, hidden_rotary_emb, max_seq_len = self.rope_embedder(hidden_states, encoder_mask) - padded_img_embed = self.x_embedder(padded_img_embed) + 
hidden_states = self.x_embedder(hidden_states) for layer in self.context_refiner: - encoder_hidden_states = layer(encoder_hidden_states, encoder_mask, cap_rotary_emb) + encoder_hidden_states = layer(encoder_hidden_states, encoder_mask, encoder_rotary_emb) for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_rotary_emb, temb) - + hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb) mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) - hidden_states = torch.zeros(bsz, max_seq_len, self.hidden_size, device=device, dtype=hidden_states.dtype) + padded_hidden_states = torch.zeros(bsz, max_seq_len, self.config.hidden_size, device=device, dtype=hidden_states.dtype) for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] + cap_len = encoder_hidden_len[i] + img_len = hidden_len[i] mask[i, :cap_len+img_len] = True - hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] - hidden_states[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] - - for layer in self.layers: - hidden_states = layer( - hidden_states, - mask, - joint_rotary_emb, - temb=temb, - ) + padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] + padded_hidden_states[i, cap_len:cap_len+img_len] = hidden_states[i, :img_len] + hidden_states = padded_hidden_states + + if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + for layer in self.layers: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + mask, + joint_rotary_emb, + temb=temb, + ) + else: + for layer in self.layers: + hidden_states = layer( + hidden_states, + mask, + joint_rotary_emb, + temb=temb, + ) hidden_states = self.norm_out(hidden_states, temb) - # unpatchify - height_tokens = width_tokens = self.patch_size + # uspatchify + height_tokens = width_tokens = self.config.patch_size output = [] - for i in range(len(img_size)): - height, width = img_size[i] - begin = l_effective_cap_len[i] + for i in range(len(hidden_sizes)): + height, width = hidden_sizes[i] + begin = encoder_hidden_len[i] end = begin + (height // height_tokens) * (width // width_tokens) output.append( hidden_states[i][begin:end] diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index ce4eb36702fc..9561f80b41a2 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -62,7 +62,7 @@ >>> from diffusers import Lumina2Text2ImgPipeline >>> pipe = Lumina2Text2ImgPipeline.from_pretrained( - ... "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16 + ... "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16 ... ) >>> # Enable memory optimizations. 
>>> pipe.enable_model_cpu_offload() @@ -72,7 +72,7 @@ ``` """ - +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, base_seq_len: int = 256, @@ -228,14 +228,23 @@ def _get_gemma_prompt_embeds( prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - text_inputs = self.tokenizer( - prompt, - pad_to_multiple_of=8, - max_length=self.max_sequence_length, - truncation=True, - padding=True, - return_tensors="pt", - ) + if max_length is None: + text_inputs = self.tokenizer( + prompt, + pad_to_multiple_of=8, + max_length=self.max_sequence_length, + truncation=True, + padding=True, + return_tensors="pt", + ) + else: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) text_input_ids = text_inputs.input_ids.to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) @@ -348,35 +357,11 @@ def encode_prompt( f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" " the batch size of `prompt`." ) - # Padding negative prompt to the same length with prompt - prompt_max_length = prompt_embeds.shape[1] - negative_text_inputs = self.tokenizer( - negative_prompt, - padding="max_length", - max_length=prompt_max_length, - truncation=True, - return_tensors="pt", - ) - negative_text_input_ids = negative_text_inputs.input_ids.to(device) - negative_prompt_attention_mask = negative_text_inputs.attention_mask.to(device) - # Get the negative prompt embeddings - negative_prompt_embeds = self.text_encoder( - negative_text_input_ids, - attention_mask=negative_prompt_attention_mask, - output_hidden_states=True, - ) - - negative_dtype = self.text_encoder.dtype - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - _, seq_len, _ = negative_prompt_embeds.shape - - negative_prompt_embeds = negative_prompt_embeds.to(dtype=negative_dtype, device=device) - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) - negative_prompt_attention_mask = negative_prompt_attention_mask.view( - batch_size * num_images_per_prompt, -1 + negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + max_length=prompt_embeds.shape[1], ) return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask @@ -518,6 +503,9 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents.to(device) return latents + + def enable_sequential_cpu_offload(self, *args, **kwargs): + super().enable_sequential_cpu_offload(*args, **kwargs) @property def guidance_scale(self): From 975da6a6125f63a7f51c9266a0b694266e9dce28 Mon Sep 17 00:00:00 2001 From: csuhan Date: Wed, 5 Feb 2025 16:31:13 +0800 Subject: [PATCH 09/25] add tests --- .../transformers/transformer_lumina2.py | 6 +- .../pipelines/lumina2/pipeline_lumina2.py | 4 +- .../test_models_transformer_lumina2.py | 114 ++++++++++ tests/pipelines/lumina2/__init__.py | 0 .../lumina2/test_pipeline_lumina2.py | 207 
++++++++++++++++++ 5 files changed, 327 insertions(+), 4 deletions(-) create mode 100644 tests/models/transformers/test_models_transformer_lumina2.py create mode 100644 tests/pipelines/lumina2/__init__.py create mode 100644 tests/pipelines/lumina2/test_pipeline_lumina2.py diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 6729b9705862..a0055819a184 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -20,7 +20,7 @@ from ...loaders import PeftAdapterMixin from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging +from ...utils import logging, is_torch_version from ..attention import LuminaFeedForward from ..attention_processor import Attention from ..embeddings import get_1d_rotary_pos_embed, apply_rotary_emb, Timesteps, TimestepEmbedding @@ -550,7 +550,7 @@ def custom_forward(*inputs): hidden_states, mask, joint_rotary_emb, - temb=temb, + temb, ) else: for layer in self.layers: @@ -558,7 +558,7 @@ def custom_forward(*inputs): hidden_states, mask, joint_rotary_emb, - temb=temb, + temb, ) hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 9561f80b41a2..67bc4ad7990f 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -328,7 +328,8 @@ def encode_prompt( if system_prompt is None: system_prompt = self.system_prompt - prompt = [system_prompt + ' ' + p for p in prompt] + if prompt is not None: + prompt = [system_prompt + ' ' + p for p in prompt] if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( @@ -630,6 +631,7 @@ def __call__( """ height = height or self.default_sample_size * self.vae_scale_factor width = width or self.default_sample_size * self.vae_scale_factor + self._guidance_scale = guidance_scale # 1. Check inputs. Raise error if not correct self.check_inputs( diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py new file mode 100644 index 000000000000..4b01c4726dfa --- /dev/null +++ b/tests/models/transformers/test_models_transformer_lumina2.py @@ -0,0 +1,114 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import torch + +from diffusers import Lumina2Transformer2DModel +from diffusers.utils.testing_utils import ( + enable_full_determinism, + torch_device, +) + +from ..test_modeling_common import ModelTesterMixin + + +enable_full_determinism() + + +class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase): + model_class = Lumina2Transformer2DModel + main_input_name = "hidden_states" + uses_custom_attn_processor = True + + @property + def dummy_input(self): + """ + Args: + None + Returns: + Dict: Dictionary of dummy input tensors + """ + batch_size = 2 # N + num_channels = 4 # C + height = width = 16 # H, W + embedding_dim = 32 # D + sequence_length = 16 # L + + hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) + timestep = torch.rand(size=(batch_size,)).to(torch_device) + encoder_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "timestep": timestep, + "encoder_mask": encoder_mask, + } + + @property + def input_shape(self): + """ + Args: + None + Returns: + Tuple: (int, int, int) + """ + return (4, 16, 16) + + @property + def output_shape(self): + """ + Args: + None + Returns: + Tuple: (int, int, int) + """ + return (4, 16, 16) + + def prepare_init_args_and_inputs_for_common(self): + """ + Args: + None + + Returns: + Tuple: (Dict, Dict) + """ + init_dict = { + "sample_size": 16, + "patch_size": 2, + "in_channels": 4, + "hidden_size": 24, + "num_layers": 2, + "num_refiner_layers": 1, + "num_attention_heads": 3, + "num_kv_heads": 1, + "multiple_of": 2, + "ffn_dim_multiplier": None, + "norm_eps": 1e-5, + "scaling_factor": 1.0, + "axes_dim_rope": (4, 2, 2), + "axes_lens": (128, 128, 128), + "cap_feat_dim": 32, + } + + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing_is_applied(self): + expected_set = {"Lumina2Transformer2DModel"} + super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/lumina2/__init__.py b/tests/pipelines/lumina2/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py new file mode 100644 index 000000000000..91a526b91dae --- /dev/null +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -0,0 +1,207 @@ +import gc +import unittest + +import numpy as np +import pytest +import torch +from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM + +from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, Lumina2Transformer2DModel, Lumina2Text2ImgPipeline +from diffusers.utils.testing_utils import ( + nightly, + numpy_cosine_similarity_distance, + require_big_gpu_with_torch_cuda, + torch_device, +) + +from ..test_pipelines_common import PipelineTesterMixin + + +class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin): + pipeline_class = Lumina2Text2ImgPipeline + params = frozenset( + [ + "prompt", + "height", + "width", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + ] + ) + batch_params = frozenset(["prompt", "negative_prompt"]) + required_optional_params = frozenset( + [ + "num_inference_steps", + "generator", + "latents", + "return_dict", + 
"callback_on_step_end", + "callback_on_step_end_tensor_inputs", + ] + ) + + test_xformers_attention = False + test_layerwise_casting = True + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = Lumina2Transformer2DModel( + sample_size=4, + patch_size=2, + in_channels=4, + hidden_size=8, + num_layers=2, + num_attention_heads=1, + num_kv_heads=1, + multiple_of=16, + ffn_dim_multiplier=None, + norm_eps=1e-5, + scaling_factor=1.0, + axes_dim_rope=[4, 2, 2], + cap_feat_dim=8, + ) + + torch.manual_seed(0) + vae = AutoencoderKL( + sample_size=32, + in_channels=3, + out_channels=3, + block_out_channels=(4,), + layers_per_block=1, + latent_channels=4, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + shift_factor=0.0609, + scaling_factor=1.5035, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma") + + torch.manual_seed(0) + config = GemmaConfig( + head_dim=2, + hidden_size=8, + intermediate_size=37, + num_attention_heads=4, + num_hidden_layers=2, + num_key_value_heads=4, + ) + text_encoder = GemmaForCausalLM(config) + + components = { + "transformer": transformer.eval(), + "vae": vae.eval(), + "scheduler": scheduler, + "text_encoder": text_encoder.eval(), + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "output_type": "np", + } + return inputs + + def test_lumina_prompt_embeds(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + output_with_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = inputs.pop("prompt") + + do_classifier_free_guidance = inputs["guidance_scale"] > 1 + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = pipe.encode_prompt( + prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + device=torch_device, + ) + output_with_embeds = pipe( + prompt_embeds=prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + **inputs, + ).images[0] + + max_diff = np.abs(output_with_prompt - output_with_embeds).max() + assert max_diff < 1e-4 + + +@nightly +@require_big_gpu_with_torch_cuda +@pytest.mark.big_gpu_with_torch_cuda +class Lumina2Text2ImgPipelineSlowTests(unittest.TestCase): + pipeline_class = Lumina2Text2ImgPipeline + repo_id = "Alpha-VLLM/Lumina-Image-2.0" + + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + return { + "prompt": "A photo of a cat", + "num_inference_steps": 2, + "guidance_scale": 5.0, + "output_type": "np", + "generator": generator, + } + + def test_lumina_inference(self): + pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) + pipe.enable_model_cpu_offload() + + inputs = self.get_inputs(torch_device) + image = pipe(**inputs).images[0] + 
image_slice = image[0, :10, :10] + expected_slice = np.array( + [ + [0.17773438, 0.18554688, 0.22070312], + [0.046875, 0.06640625, 0.10351562], + [0.0, 0.0, 0.02148438], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + dtype=np.float32, + ) + + max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) + + assert max_diff < 1e-4 From 174077b3381dfc746d38f3e151c26bf32f99d208 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 06:43:07 +0100 Subject: [PATCH 10/25] make style --- src/diffusers/__init__.py | 8 +- src/diffusers/models/__init__.py | 4 +- .../transformers/transformer_lumina2.py | 144 +++++++++++------- .../pipelines/lumina/pipeline_lumina.py | 2 +- .../pipelines/lumina2/pipeline_lumina2.py | 59 ++++--- .../test_models_transformer_lumina2.py | 2 +- .../lumina2/test_pipeline_lumina2.py | 9 +- 7 files changed, 128 insertions(+), 100 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c8438d8e9c93..b93ee537f7dd 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -109,8 +109,8 @@ "Kandinsky3UNet", "LatteTransformer3DModel", "LTXVideoTransformer3DModel", - "LuminaNextDiT2DModel", "Lumina2Transformer2DModel", + "LuminaNextDiT2DModel", "MochiTransformer3DModel", "ModelMixin", "MotionAdapter", @@ -329,8 +329,8 @@ "LEditsPPPipelineStableDiffusionXL", "LTXImageToVideoPipeline", "LTXPipeline", - "LuminaText2ImgPipeline", "Lumina2Text2ImgPipeline", + "LuminaText2ImgPipeline", "MarigoldDepthPipeline", "MarigoldNormalsPipeline", "MochiPipeline", @@ -623,8 +623,8 @@ Kandinsky3UNet, LatteTransformer3DModel, LTXVideoTransformer3DModel, - LuminaNextDiT2DModel, Lumina2Transformer2DModel, + LuminaNextDiT2DModel, MochiTransformer3DModel, ModelMixin, MotionAdapter, @@ -822,8 +822,8 @@ LEditsPPPipelineStableDiffusionXL, LTXImageToVideoPipeline, LTXPipeline, - LuminaText2ImgPipeline, Lumina2Text2ImgPipeline, + LuminaText2ImgPipeline, MarigoldDepthPipeline, MarigoldNormalsPipeline, MochiPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 04618d9a9baf..3763252747bd 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -60,7 +60,6 @@ _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"] _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"] - _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"] @@ -72,6 +71,7 @@ _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] + _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] @@ -139,8 +139,8 @@ 
HunyuanVideoTransformer3DModel, LatteTransformer3DModel, LTXVideoTransformer3DModel, - LuminaNextDiT2DModel, Lumina2Transformer2DModel, + LuminaNextDiT2DModel, MochiTransformer3DModel, PixArtTransformer2DModel, PriorTransformer, diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index a0055819a184..2270acf79e02 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, List +import math +from typing import List, Optional import torch import torch.nn as nn import torch.nn.functional as F -from ...loaders import PeftAdapterMixin from ...configuration_utils import ConfigMixin, register_to_config -from ...utils import logging, is_torch_version +from ...loaders import PeftAdapterMixin +from ...utils import logging from ..attention import LuminaFeedForward from ..attention_processor import Attention -from ..embeddings import get_1d_rotary_pos_embed, apply_rotary_emb, Timesteps, TimestepEmbedding +from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm @@ -35,12 +36,14 @@ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5): super().__init__() - + self.time_proj = Timesteps( num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0 ) - self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)) + self.timestep_embedder = TimestepEmbedding( + in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024) + ) self.caption_embedder = nn.Sequential( RMSNorm(cap_feat_dim, eps=norm_eps), @@ -60,12 +63,13 @@ def forward(self, timestep, caption_feat): caption_embed = self.caption_embedder(caption_feat) return time_embed, caption_embed - - + + class Lumina2AttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the Lumina2Transformer2DModel model. It applies a s normalization layer and rotary embedding on query and key vector. + used in the Lumina2Transformer2DModel model. It applies a s normalization layer and rotary embedding on query and + key vector. 
""" def __init__(self): @@ -81,8 +85,6 @@ def __call__( image_rotary_emb: Optional[torch.Tensor] = None, base_sequence_length: Optional[int] = None, ) -> torch.Tensor: - - input_ndim = hidden_states.ndim batch_size, sequence_length, _ = hidden_states.shape # Get Query-Key-Value Pair @@ -101,13 +103,13 @@ def __call__( query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, kv_heads, head_dim) value = value.view(batch_size, -1, kv_heads, head_dim) - + # Apply Query-Key Norm if needed if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) - + # Apply RoPE if needed if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, use_real=False) @@ -146,10 +148,10 @@ def __call__( # linear proj hidden_states = attn.to_out[0](hidden_states) - + return hidden_states - - + + class Lumina2TransformerBlock(nn.Module): """ A Lumina2TransformerBlock for Lumina2Transformer2DModel. @@ -194,14 +196,14 @@ def __init__( out_bias=False, processor=Lumina2AttnProcessor2_0(), ) - + self.feed_forward = LuminaFeedForward( dim=dim, inner_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, ) - + if modulation: self.norm1 = LuminaRMSNormZero( embedding_dim=dim, @@ -257,8 +259,8 @@ def forward( hidden_states = hidden_states + self.ffn_norm2(mlp_output) return hidden_states - - + + class Lumina2RotaryPosEmbed(nn.Module): def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2): super().__init__() @@ -267,7 +269,7 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, self.axes_lens = axes_lens self.patch_size = patch_size self.freqs_cis = self.precompute_freqs_cis(axes_dim, axes_lens, theta) - + def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: freqs_cis = [] for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): @@ -279,15 +281,15 @@ def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: ) freqs_cis.append(emb) return freqs_cis - + def get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: result = [] for i in range(len(self.axes_dim)): freqs = self.freqs_cis[i].to(ids.device) - index = ids[:, :, i:i+1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) + index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64) result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) return torch.cat(result, dim=-1) - + def forward( self, hidden_states: torch.Tensor, @@ -301,14 +303,11 @@ def forward( img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = max(l_effective_cap_len) + max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))) max_img_len = max(l_effective_img_len) - + position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) - + for i in range(bsz): cap_len = l_effective_cap_len[i] img_len = l_effective_img_len[i] @@ -317,14 +316,18 @@ def forward( assert H_tokens * W_tokens == img_len position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = torch.arange(H_tokens, dtype=torch.int32, 
device=device).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids + position_ids[i, cap_len : cap_len + img_len, 0] = cap_len + row_ids = ( + torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() + ) + col_ids = ( + torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() + ) + position_ids[i, cap_len : cap_len + img_len, 1] = row_ids + position_ids[i, cap_len : cap_len + img_len, 2] = col_ids freqs_cis = self.get_freqs_cis(position_ids) - + cap_freqs_cis_shape = list(freqs_cis.shape) cap_freqs_cis_shape[1] = encoder_mask.shape[1] cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) @@ -332,13 +335,13 @@ def forward( img_freqs_cis_shape = list(freqs_cis.shape) img_freqs_cis_shape[1] = max_img_len img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - + for i in range(bsz): cap_len = l_effective_cap_len[i] img_len = l_effective_img_len[i] cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] - + img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] + flat_hidden_states = [] for i in range(bsz): img = hidden_states[i] @@ -346,13 +349,25 @@ def forward( img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) flat_hidden_states.append(img) hidden_states = flat_hidden_states - padded_img_embed = torch.zeros(bsz, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype) + padded_img_embed = torch.zeros( + bsz, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype + ) padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = hidden_states[i] - padded_img_mask[i, :l_effective_img_len[i]] = True - - return padded_img_embed, padded_img_mask, img_sizes, l_effective_cap_len, l_effective_img_len, freqs_cis, cap_freqs_cis, img_freqs_cis, max_seq_len + padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i] + padded_img_mask[i, : l_effective_img_len[i]] = True + + return ( + padded_img_embed, + padded_img_mask, + img_sizes, + l_effective_cap_len, + l_effective_img_len, + freqs_cis, + cap_freqs_cis, + img_freqs_cis, + max_seq_len, + ) class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): @@ -392,7 +407,7 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True - + @register_to_config def __init__( self, @@ -448,7 +463,7 @@ def __init__( for _ in range(num_refiner_layers) ] ) - + self.context_refiner = nn.ModuleList( [ Lumina2TransformerBlock( @@ -463,7 +478,7 @@ def __init__( for _ in range(num_refiner_layers) ] ) - + self.layers = nn.ModuleList( [ Lumina2TransformerBlock( @@ -486,9 +501,9 @@ def __init__( bias=True, out_dim=patch_size * patch_size * self.out_channels, ) - + self.gradient_checkpointing = False - + def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @@ -513,26 +528,39 @@ def forward( bsz = hidden_states.size(0) device = hidden_states.device temb, encoder_hidden_states = 
self.time_caption_embed(timestep, encoder_hidden_states) - hidden_states, hidden_mask, hidden_sizes, encoder_hidden_len, hidden_len, joint_rotary_emb, encoder_rotary_emb, hidden_rotary_emb, max_seq_len = self.rope_embedder(hidden_states, encoder_mask) - + ( + hidden_states, + hidden_mask, + hidden_sizes, + encoder_hidden_len, + hidden_len, + joint_rotary_emb, + encoder_rotary_emb, + hidden_rotary_emb, + max_seq_len, + ) = self.rope_embedder(hidden_states, encoder_mask) + hidden_states = self.x_embedder(hidden_states) for layer in self.context_refiner: encoder_hidden_states = layer(encoder_hidden_states, encoder_mask, encoder_rotary_emb) - + for layer in self.noise_refiner: hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb) - + mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) - padded_hidden_states = torch.zeros(bsz, max_seq_len, self.config.hidden_size, device=device, dtype=hidden_states.dtype) + padded_hidden_states = torch.zeros( + bsz, max_seq_len, self.config.hidden_size, device=device, dtype=hidden_states.dtype + ) for i in range(bsz): cap_len = encoder_hidden_len[i] img_len = hidden_len[i] - mask[i, :cap_len+img_len] = True + mask[i, : cap_len + img_len] = True padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len] - padded_hidden_states[i, cap_len:cap_len+img_len] = hidden_states[i, :img_len] + padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len] hidden_states = padded_hidden_states if torch.is_grad_enabled() and self.gradient_checkpointing: + def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: @@ -542,8 +570,6 @@ def custom_forward(*inputs): return custom_forward - ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} - for layer in self.layers: hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(layer), @@ -582,4 +608,4 @@ def custom_forward(*inputs): if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) \ No newline at end of file + return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index ac94312bf786..5b37e9a503a8 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -17,7 +17,7 @@ import math import re import urllib.parse as ul -from typing import List, Optional, Tuple, Union, Callable, Dict +from typing import List, Optional, Tuple, Union import torch from transformers import AutoModel, AutoTokenizer diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 67bc4ad7990f..78b4206555b8 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import html import inspect import math -import numpy as np import re -import urllib.parse as ul -from typing import List, Optional, Tuple, Union, Callable, Dict +from typing import Callable, Dict, List, Optional, Tuple, Union +import numpy as np import torch from transformers import AutoModel, AutoTokenizer @@ -28,7 +26,6 @@ from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import ( - BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, is_torch_xla_available, @@ -50,10 +47,10 @@ if is_bs4_available(): - from bs4 import BeautifulSoup + pass if is_ftfy_available(): - import ftfy + pass EXAMPLE_DOC_STRING = """ Examples: @@ -61,9 +58,7 @@ >>> import torch >>> from diffusers import Lumina2Text2ImgPipeline - >>> pipe = Lumina2Text2ImgPipeline.from_pretrained( - ... "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16 - ... ) + >>> pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16) >>> # Enable memory optimizations. >>> pipe.enable_model_cpu_offload() @@ -72,6 +67,7 @@ ``` """ + # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift def calculate_shift( image_seq_len, @@ -325,11 +321,11 @@ def encode_prompt( batch_size = len(prompt) else: batch_size = prompt_embeds.shape[0] - + if system_prompt is None: system_prompt = self.system_prompt if prompt is not None: - prompt = [system_prompt + ' ' + p for p in prompt] + prompt = [system_prompt + " " + p for p in prompt] if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( @@ -407,7 +403,7 @@ def check_inputs( raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) - + if prompt is not None and prompt_embeds is not None: raise ValueError( f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" @@ -451,10 +447,10 @@ def check_inputs( f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" f" {negative_prompt_attention_mask.shape}." ) - + if max_sequence_length is not None and max_sequence_length > 512: raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}") - + def enable_vae_slicing(self): r""" Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to @@ -483,7 +479,7 @@ def disable_vae_tiling(self): computing decoding in one step. """ self.vae.disable_tiling() - + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. 
@@ -491,7 +487,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) - + if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -504,7 +500,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents.to(device) return latents - + def enable_sequential_cpu_offload(self, *args, **kwargs): super().enable_sequential_cpu_offload(*args, **kwargs) @@ -663,7 +659,7 @@ def __call__( # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - + self.tokenizer.padding_side = "right" # 3. Encode input prompt @@ -721,12 +717,17 @@ def __call__( ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - + # 6. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance and 1 - t.item() / self.scheduler.config.num_train_timesteps < cfg_trunc_ratio else latents + latent_model_input = ( + torch.cat([latents] * 2) + if do_classifier_free_guidance + and 1 - t.item() / self.scheduler.config.num_train_timesteps < cfg_trunc_ratio + else latents + ) current_timestep = t if not torch.is_tensor(current_timestep): # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can @@ -748,7 +749,7 @@ def __call__( # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps - + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=current_timestep, @@ -759,18 +760,14 @@ def __call__( # perform normalization-based guidance scale on a truncated timestep interval if do_classifier_free_guidance and current_timestep[0] < cfg_trunc_ratio: - noise_pred_cond, noise_pred_uncond = torch.split( - noise_pred, len(noise_pred) // 2, dim=0 - ) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_cond - noise_pred_uncond - ) + noise_pred_cond, noise_pred_uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # apply normalization after classifier-free guidance if cfg_normalization: cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) noise_pred = noise_pred * (cond_norm / noise_norm) - + # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype noise_pred = -noise_pred @@ -798,7 +795,7 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() - + if not output_type == "latent": latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor image = self.vae.decode(latents, return_dict=False)[0] @@ -812,4 +809,4 @@ def __call__( if not return_dict: return (image,) - return ImagePipelineOutput(images=image) \ No newline at end of file + return ImagePipelineOutput(images=image) diff --git a/tests/models/transformers/test_models_transformer_lumina2.py 
b/tests/models/transformers/test_models_transformer_lumina2.py index 4b01c4726dfa..57af6b537506 100644 --- a/tests/models/transformers/test_models_transformer_lumina2.py +++ b/tests/models/transformers/test_models_transformer_lumina2.py @@ -108,7 +108,7 @@ def prepare_init_args_and_inputs_for_common(self): inputs_dict = self.dummy_input return init_dict, inputs_dict - + def test_gradient_checkpointing_is_applied(self): expected_set = {"Lumina2Transformer2DModel"} super().test_gradient_checkpointing_is_applied(expected_set=expected_set) diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py index 91a526b91dae..b046686b6a00 100644 --- a/tests/pipelines/lumina2/test_pipeline_lumina2.py +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -6,7 +6,12 @@ import torch from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM -from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, Lumina2Transformer2DModel, Lumina2Text2ImgPipeline +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + Lumina2Text2ImgPipeline, + Lumina2Transformer2DModel, +) from diffusers.utils.testing_utils import ( nightly, numpy_cosine_similarity_distance, @@ -62,7 +67,7 @@ def get_dummy_components(self): axes_dim_rope=[4, 2, 2], cap_feat_dim=8, ) - + torch.manual_seed(0) vae = AutoencoderKL( sample_size=32, From 75e5c31a6792a04b16fe1d55de0305303b845bf8 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 07:25:37 +0100 Subject: [PATCH 11/25] refactor transformer --- .../transformers/transformer_lumina2.py | 224 ++++++------------ .../pipelines/lumina2/pipeline_lumina2.py | 74 +++--- 2 files changed, 106 insertions(+), 192 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 2270acf79e02..3e5f12ed828f 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -13,7 +13,7 @@ # limitations under the License. 
import math -from typing import List, Optional +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -34,7 +34,13 @@ class Lumina2CombinedTimestepCaptionEmbedding(nn.Module): - def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5): + def __init__( + self, + hidden_size: int = 4096, + cap_feat_dim: int = 2048, + frequency_embedding_size: int = 256, + norm_eps: float = 1e-5, + ) -> None: super().__init__() self.time_proj = Timesteps( @@ -46,30 +52,22 @@ def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size ) self.caption_embedder = nn.Sequential( - RMSNorm(cap_feat_dim, eps=norm_eps), - nn.Linear( - cap_feat_dim, - hidden_size, - bias=True, - ), + RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True) ) - def forward(self, timestep, caption_feat): - # timestep embedding: - time_freq = self.time_proj(timestep) - time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype)) - - # caption condition embedding: - caption_embed = self.caption_embedder(caption_feat) - + def forward( + self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + timestep_proj = self.time_proj(timestep).type_as(hidden_states) + time_embed = self.timestep_embedder(timestep_proj) + caption_embed = self.caption_embedder(encoder_hidden_states) return time_embed, caption_embed class Lumina2AttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is - used in the Lumina2Transformer2DModel model. It applies a s normalization layer and rotary embedding on query and - key vector. + used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors. """ def __init__(self): @@ -138,37 +136,19 @@ def __call__( key = key.transpose(1, 2) value = value.transpose(1, 2) - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( query, key, value, attn_mask=attention_mask, scale=softmax_scale ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(dtype) + hidden_states = hidden_states.type_as(query) # linear proj hidden_states = attn.to_out[0](hidden_states) - + hidden_states = attn.to_out[1](hidden_states) return hidden_states class Lumina2TransformerBlock(nn.Module): - """ - A Lumina2TransformerBlock for Lumina2Transformer2DModel. - - Parameters: - dim (`int`): Embedding dimension of the input features. - num_attention_heads (`int`): Number of attention heads. - num_kv_heads (`int`): - Number of attention heads in key and value features (if using GQA), or set to None for the same as query. - multiple_of (`int`): The number of multiple of ffn layer. - ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension. - norm_eps (`float`): The eps for norm layer. - qk_norm (`bool`): normalization for query and key. - cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states. - modulation (`bool`): Whether to use modulation. 
- """ - def __init__( self, dim: int, @@ -183,7 +163,6 @@ def __init__( self.head_dim = dim // num_attention_heads self.modulation = modulation - # Self-attention self.attn = Attention( query_dim=dim, cross_attention_dim=None, @@ -223,18 +202,7 @@ def forward( attention_mask: torch.Tensor, image_rotary_emb: torch.Tensor, temb: Optional[torch.Tensor] = None, - ): - """ - Perform a forward pass through the Lumina2TransformerBlock. - - Parameters: - hidden_states (`torch.Tensor`): The input of hidden_states for Lumina2TransformerBlock. - attention_mask (`torch.Tensor): The input of hidden_states corresponse attention mask. - image_rotary_emb (`torch.Tensor`): Precomputed cosine and sine frequencies. - encoder_hidden_states: (`torch.Tensor`): The hidden_states of text prompt are processed by Gemma encoder. - encoder_mask (`torch.Tensor`): The hidden_states of text prompt attention mask. - temb (`torch.Tensor`): Timestep embedding with text prompt embedding. - """ + ) -> torch.Tensor: if self.modulation: norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb) attn_output = self.attn( @@ -268,21 +236,17 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, self.axes_dim = axes_dim self.axes_lens = axes_lens self.patch_size = patch_size - self.freqs_cis = self.precompute_freqs_cis(axes_dim, axes_lens, theta) - def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: + self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta) + + def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: freqs_cis = [] for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): - emb = get_1d_rotary_pos_embed( - d, - e, - theta=self.theta, - freqs_dtype=torch.float64, - ) + emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64) freqs_cis.append(emb) return freqs_cis - def get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: + def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: result = [] for i in range(len(self.axes_dim)): freqs = self.freqs_cis[i].to(ids.device) @@ -290,29 +254,26 @@ def get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor: result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index)) return torch.cat(result, dim=-1) - def forward( - self, - hidden_states: torch.Tensor, - encoder_mask: torch.Tensor, - ): - bsz = len(hidden_states) - pH = pW = self.patch_size + def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor): + batch_size = len(hidden_states) + p_h = p_w = self.patch_size device = hidden_states[0].device - l_effective_cap_len = encoder_mask.sum(dim=1).tolist() + l_effective_cap_len = attention_mask.sum(dim=1).tolist() + # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape img_sizes = [(img.size(1), img.size(2)) for img in hidden_states] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes] max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))) max_img_len = max(l_effective_img_len) - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device) - for i in range(bsz): + for i in range(batch_size): cap_len = 
l_effective_cap_len[i] img_len = l_effective_img_len[i] H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW + H_tokens, W_tokens = H // p_h, W // p_w assert H_tokens * W_tokens == img_len position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) @@ -326,34 +287,34 @@ def forward( position_ids[i, cap_len : cap_len + img_len, 1] = row_ids position_ids[i, cap_len : cap_len + img_len, 2] = col_ids - freqs_cis = self.get_freqs_cis(position_ids) + freqs_cis = self._get_freqs_cis(position_ids) cap_freqs_cis_shape = list(freqs_cis.shape) - cap_freqs_cis_shape[1] = encoder_mask.shape[1] + cap_freqs_cis_shape[1] = attention_mask.shape[1] cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) img_freqs_cis_shape = list(freqs_cis.shape) img_freqs_cis_shape[1] = max_img_len img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - for i in range(bsz): + for i in range(batch_size): cap_len = l_effective_cap_len[i] img_len = l_effective_img_len[i] cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len] flat_hidden_states = [] - for i in range(bsz): + for i in range(batch_size): img = hidden_states[i] C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) + img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) flat_hidden_states.append(img) hidden_states = flat_hidden_states padded_img_embed = torch.zeros( - bsz, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype + batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype ) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=torch.bool, device=device) - for i in range(bsz): + padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device) + for i in range(batch_size): padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i] padded_img_mask[i, : l_effective_img_len[i]] = True @@ -371,7 +332,7 @@ def forward( class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): - """ + r""" Lumina2NextDiT: Diffusion model with a Transformer backbone. Parameters: @@ -407,6 +368,8 @@ class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): """ _supports_gradient_checkpointing = True + _no_split_modules = ["Lumina2TransformerBlock"] + _skip_layerwise_casting_patterns = ["x_embedder", "norm"] @register_to_config def __init__( @@ -431,24 +394,18 @@ def __init__( super().__init__() self.out_channels = out_channels or in_channels + # 1. Positional, patch & conditional embeddings self.rope_embedder = Lumina2RotaryPosEmbed( - theta=10000, - axes_dim=axes_dim_rope, - axes_lens=axes_lens, - patch_size=patch_size, - ) - self.x_embedder = nn.Linear( - in_features=patch_size * patch_size * in_channels, - out_features=hidden_size, - bias=True, + theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size ) + self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size) + self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding( - hidden_size=hidden_size, - cap_feat_dim=cap_feat_dim, - norm_eps=norm_eps, + hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps ) + # 2. 
Noise and context refinement blocks self.noise_refiner = nn.ModuleList( [ Lumina2TransformerBlock( @@ -479,6 +436,7 @@ def __init__( ] ) + # 3. Transformer blocks self.layers = nn.ModuleList( [ Lumina2TransformerBlock( @@ -493,6 +451,8 @@ def __init__( for _ in range(num_layers) ] ) + + # 4. Output norm & projection self.norm_out = LuminaLayerNormContinuous( embedding_dim=hidden_size, conditioning_embedding_dim=min(hidden_size, 1024), @@ -504,30 +464,19 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - def forward( self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, - encoder_mask: torch.Tensor, - return_dict=True, - ) -> torch.Tensor: - """ - Forward pass of LuminaNextDiT. - - Parameters: - hidden_states (torch.Tensor): Input tensor of shape (N, C, H, W). - timestep (torch.Tensor): Tensor of diffusion timesteps of shape (N,). - encoder_hidden_states (torch.Tensor): Tensor of caption features of shape (N, D). - encoder_mask (torch.Tensor): Tensor of caption masks of shape (N, L). - """ - bsz = hidden_states.size(0) - device = hidden_states.device - temb, encoder_hidden_states = self.time_caption_embed(timestep, encoder_hidden_states) + attention_mask: torch.Tensor, + return_dict: bool = True, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + batch_size = hidden_states.size(0) + + # 1. Condition, positional & patch embedding + temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states) + ( hidden_states, hidden_mask, @@ -538,20 +487,21 @@ def forward( encoder_rotary_emb, hidden_rotary_emb, max_seq_len, - ) = self.rope_embedder(hidden_states, encoder_mask) + ) = self.rope_embedder(hidden_states, attention_mask) hidden_states = self.x_embedder(hidden_states) + + # 2. Context & noise refinement for layer in self.context_refiner: - encoder_hidden_states = layer(encoder_hidden_states, encoder_mask, encoder_rotary_emb) + encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb) for layer in self.noise_refiner: hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb) - mask = torch.zeros(bsz, max_seq_len, dtype=torch.bool, device=device) - padded_hidden_states = torch.zeros( - bsz, max_seq_len, self.config.hidden_size, device=device, dtype=hidden_states.dtype - ) - for i in range(bsz): + # 3. 
Attention mask preparation + mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) + padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size) + for i in range(batch_size): cap_len = encoder_hidden_len[i] img_len = hidden_len[i] mask[i, : cap_len + img_len] = True @@ -559,37 +509,16 @@ def forward( padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len] hidden_states = padded_hidden_states - if torch.is_grad_enabled() and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - for layer in self.layers: - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer), - hidden_states, - mask, - joint_rotary_emb, - temb, - ) - else: - for layer in self.layers: - hidden_states = layer( - hidden_states, - mask, - joint_rotary_emb, - temb, - ) + # 4. Transformer blocks + for layer in self.layers: + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb) + else: + hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb) + # 5. Output norm & projection & unpatchify hidden_states = self.norm_out(hidden_states, temb) - # uspatchify height_tokens = width_tokens = self.config.patch_size output = [] for i in range(len(hidden_sizes)): @@ -607,5 +536,4 @@ def custom_forward(*inputs): if not return_dict: return (output,) - return Transformer2DModelOutput(sample=output) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 78b4206555b8..78152b602717 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -import math import re from typing import Callable, Dict, List, Optional, Tuple, Union @@ -203,8 +202,6 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 8 - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) - self.max_sequence_length = 256 self.default_sample_size = ( self.transformer.config.sample_size if hasattr(self, "transformer") and self.transformer is not None @@ -213,42 +210,38 @@ def __init__( self.default_image_size = self.default_sample_size * self.vae_scale_factor self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts." 
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, device: Optional[torch.device] = None, - max_length: Optional[int] = None, - ): + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor]: device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - if max_length is None: - text_inputs = self.tokenizer( - prompt, - pad_to_multiple_of=8, - max_length=self.max_sequence_length, - truncation=True, - padding=True, - return_tensors="pt", - ) - else: - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.to(device) untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device) if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1 : -1]) + removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1]) logger.warning( "The following part of your input was truncated because Gemma can only handle sequences up to" - f" {self.max_sequence_length} tokens: {removed_text}" + f" {max_sequence_length} tokens: {removed_text}" ) prompt_attention_mask = text_inputs.attention_mask.to(device) @@ -288,8 +281,8 @@ def encode_prompt( prompt_attention_mask: Optional[torch.Tensor] = None, negative_prompt_attention_mask: Optional[torch.Tensor] = None, system_prompt: Optional[str] = None, - **kwargs, - ): + max_sequence_length: int = 256, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: r""" Encodes the prompt into text encoder hidden states. @@ -311,7 +304,8 @@ def encode_prompt( provided, text embeddings will be generated from `prompt` input argument. negative_prompt_embeds (`torch.Tensor`, *optional*): Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string. - max_sequence_length (`int`, defaults to 256): Maximum sequence length to use for the prompt. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use for the prompt. 
""" if device is None: device = self._execution_device @@ -332,6 +326,7 @@ def encode_prompt( prompt=prompt, num_images_per_prompt=num_images_per_prompt, device=device, + max_sequence_length=max_sequence_length, ) # Get negative embeddings for classifier free guidance @@ -358,7 +353,7 @@ def encode_prompt( prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, device=device, - max_length=prompt_embeds.shape[1], + max_sequence_length=max_sequence_length, ) return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask @@ -501,9 +496,6 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents - def enable_sequential_cpu_offload(self, *args, **kwargs): - super().enable_sequential_cpu_offload(*args, **kwargs) - @property def guidance_scale(self): return self._guidance_scale @@ -539,12 +531,12 @@ def __call__( negative_prompt_attention_mask: Optional[torch.Tensor] = None, output_type: Optional[str] = "pil", return_dict: bool = True, - max_sequence_length: int = 256, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, callback_on_step_end_tensor_inputs: List[str] = ["latents"], system_prompt: Optional[str] = None, cfg_trunc_ratio: float = 1.0, cfg_normalization: bool = True, + max_sequence_length: int = 256, ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -600,8 +592,6 @@ def __call__( [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. - max_sequence_length (`int` defaults to 120): - Maximum sequence length to use with the `prompt`. callback_on_step_end (`Callable`, *optional*): A function that calls at the end of each denoising steps during the inference. The function is called with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, @@ -613,10 +603,12 @@ def __call__( `._callback_tensor_inputs` attribute of your pipeline class. system_prompt (`str`, *optional*): The system prompt to use for the image generation. - cfg_trunc_ratio (`float`, *optional*, defaults to 1.0): + cfg_trunc_ratio (`float`, *optional*, defaults to `1.0`): The ratio of the timestep interval to apply normalization-based guidance scale. - cfg_normalization (`bool`, *optional*, defaults to True): + cfg_normalization (`bool`, *optional*, defaults to `True`): Whether to apply normalization-based guidance scale. + max_sequence_length (`int`, defaults to `256`): + Maximum sequence length to use with the `prompt`. Examples: @@ -651,8 +643,6 @@ def __call__( else: batch_size = prompt_embeds.shape[0] - scaling_factor = math.sqrt(width * height / self.default_image_size**2) - device = self._execution_device # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) @@ -660,8 +650,6 @@ def __call__( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - self.tokenizer.padding_side = "right" - # 3. 
Encode input prompt ( prompt_embeds, @@ -744,9 +732,9 @@ def __call__( ) elif len(current_timestep.shape) == 0: current_timestep = current_timestep[None].to(latent_model_input.device) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.expand(latent_model_input.shape[0]) - # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps @@ -754,7 +742,7 @@ def __call__( hidden_states=latent_model_input, timestep=current_timestep, encoder_hidden_states=prompt_embeds, - encoder_mask=prompt_attention_mask, + attention_mask=prompt_attention_mask, return_dict=False, )[0] @@ -778,8 +766,6 @@ def __call__( # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 latents = latents.to(latents_dtype) - progress_bar.update() - if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: From aeef34045308f7fa0729a5824f9a69839825d95c Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 07:28:13 +0100 Subject: [PATCH 12/25] fix import; remove lumina2 integration test --- src/diffusers/models/transformers/__init__.py | 1 + .../lumina2/test_pipeline_lumina2.py | 68 +------------------ 2 files changed, 2 insertions(+), 67 deletions(-) diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index 77e1698b8fc2..e36e929e1522 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -21,6 +21,7 @@ from .transformer_flux import FluxTransformer2DModel from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel from .transformer_ltx import LTXVideoTransformer3DModel + from .transformer_lumina2 import Lumina2Transformer2DModel from .transformer_mochi import MochiTransformer3DModel from .transformer_sd3 import SD3Transformer2DModel from .transformer_temporal import TransformerTemporalModel diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py index b046686b6a00..74dabc0642e9 100644 --- a/tests/pipelines/lumina2/test_pipeline_lumina2.py +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -1,8 +1,6 @@ -import gc import unittest import numpy as np -import pytest import torch from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM @@ -12,12 +10,7 @@ Lumina2Text2ImgPipeline, Lumina2Transformer2DModel, ) -from diffusers.utils.testing_utils import ( - nightly, - numpy_cosine_similarity_distance, - require_big_gpu_with_torch_cuda, - torch_device, -) +from diffusers.utils.testing_utils import torch_device from ..test_pipelines_common import PipelineTesterMixin @@ -151,62 +144,3 @@ def test_lumina_prompt_embeds(self): max_diff = np.abs(output_with_prompt - output_with_embeds).max() assert max_diff < 1e-4 - - -@nightly -@require_big_gpu_with_torch_cuda -@pytest.mark.big_gpu_with_torch_cuda -class Lumina2Text2ImgPipelineSlowTests(unittest.TestCase): - pipeline_class = Lumina2Text2ImgPipeline - repo_id = "Alpha-VLLM/Lumina-Image-2.0" - - def setUp(self): - super().setUp() - gc.collect() - torch.cuda.empty_cache() - - def tearDown(self): - super().tearDown() - gc.collect() - torch.cuda.empty_cache() - - def get_inputs(self, device, seed=0): - if str(device).startswith("mps"): - generator = torch.manual_seed(seed) - else: - generator = 
torch.Generator(device="cpu").manual_seed(seed) - - return { - "prompt": "A photo of a cat", - "num_inference_steps": 2, - "guidance_scale": 5.0, - "output_type": "np", - "generator": generator, - } - - def test_lumina_inference(self): - pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16) - pipe.enable_model_cpu_offload() - - inputs = self.get_inputs(torch_device) - image = pipe(**inputs).images[0] - image_slice = image[0, :10, :10] - expected_slice = np.array( - [ - [0.17773438, 0.18554688, 0.22070312], - [0.046875, 0.06640625, 0.10351562], - [0.0, 0.0, 0.02148438], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ], - dtype=np.float32, - ) - - max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) - - assert max_diff < 1e-4 From e6f6aae69062e95be623fca735bc48d19efd7008 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 07:34:35 +0100 Subject: [PATCH 13/25] update docs --- docs/source/en/_toctree.yml | 2 ++ .../en/api/models/lumina2_transformer2d.md | 30 +++++++++++++++++++ docs/source/en/api/pipelines/lumina2.md | 29 ++++++++++++++++++ 3 files changed, 61 insertions(+) create mode 100644 docs/source/en/api/models/lumina2_transformer2d.md create mode 100644 docs/source/en/api/pipelines/lumina2.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 752219b4abd1..13f7cd412542 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -288,6 +288,8 @@ title: LatteTransformer3DModel - local: api/models/lumina_nextdit2d title: LuminaNextDiT2DModel + - local: api/models/lumina2_transformer2d + title: Lumina2Transformer2DModel - local: api/models/ltx_video_transformer3d title: LTXVideoTransformer3DModel - local: api/models/mochi_transformer3d diff --git a/docs/source/en/api/models/lumina2_transformer2d.md b/docs/source/en/api/models/lumina2_transformer2d.md new file mode 100644 index 000000000000..0d7c0585dcd5 --- /dev/null +++ b/docs/source/en/api/models/lumina2_transformer2d.md @@ -0,0 +1,30 @@ + + +# Lumina2Transformer2DModel + +A Diffusion Transformer model for 3D video-like data was introduced in [Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM. + +The model can be loaded with the following code snippet. + +```python +from diffusers import Lumina2Transformer2DModel + +transformer = Lumina2Transformer2DModel.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16) +``` + +## Lumina2Transformer2DModel + +[[autodoc]] Lumina2Transformer2DModel + +## Transformer2DModelOutput + +[[autodoc]] models.modeling_outputs.Transformer2DModelOutput diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md new file mode 100644 index 000000000000..7800283b0bfd --- /dev/null +++ b/docs/source/en/api/pipelines/lumina2.md @@ -0,0 +1,29 @@ + + +# HunyuanVideo + +[Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM. + + + +Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines. 
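The example below is a minimal text-to-image sketch for this pipeline; the prompt and the `guidance_scale`/`num_inference_steps` values are illustrative choices, not tuned recommendations.

```python
import torch

from diffusers import Lumina2Text2ImgPipeline

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)
# Optional: offload model components to CPU to reduce peak GPU memory.
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="A photo of a red fox standing in fresh snow, soft morning light",
    guidance_scale=4.0,      # values > 1 enable classifier-free guidance
    num_inference_steps=30,  # number of flow-matching denoising steps
).images[0]
image.save("lumina2.png")
```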
+ + + +## Lumina2Text2ImgPipeline + +[[autodoc]] Lumina2Text2ImgPipeline + - all + - __call__ From e79b7160fba71884881e140efea65c0cc37c2e74 Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 6 Feb 2025 07:37:51 +0100 Subject: [PATCH 14/25] update tests --- docs/source/en/api/pipelines/lumina2.md | 2 +- .../test_models_transformer_lumina2.py | 29 ++----------------- 2 files changed, 3 insertions(+), 28 deletions(-) diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md index 7800283b0bfd..e9d346302cb4 100644 --- a/docs/source/en/api/pipelines/lumina2.md +++ b/docs/source/en/api/pipelines/lumina2.md @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. --> -# HunyuanVideo +# Lumina2 [Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM. diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py index 57af6b537506..e89f160433bd 100644 --- a/tests/models/transformers/test_models_transformer_lumina2.py +++ b/tests/models/transformers/test_models_transformer_lumina2.py @@ -36,12 +36,6 @@ class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestC @property def dummy_input(self): - """ - Args: - None - Returns: - Dict: Dictionary of dummy input tensors - """ batch_size = 2 # N num_channels = 4 # C height = width = 16 # H, W @@ -51,43 +45,24 @@ def dummy_input(self): hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) timestep = torch.rand(size=(batch_size,)).to(torch_device) - encoder_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device) + attention_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device) return { "hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep, - "encoder_mask": encoder_mask, + "attention_mask": attention_mask, } @property def input_shape(self): - """ - Args: - None - Returns: - Tuple: (int, int, int) - """ return (4, 16, 16) @property def output_shape(self): - """ - Args: - None - Returns: - Tuple: (int, int, int) - """ return (4, 16, 16) def prepare_init_args_and_inputs_for_common(self): - """ - Args: - None - - Returns: - Tuple: (Dict, Dict) - """ init_dict = { "sample_size": 16, "patch_size": 2, From cd32a8c3a514321844d4373f8b372cc782547f5e Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 7 Feb 2025 09:13:39 +0100 Subject: [PATCH 15/25] update --- docs/source/en/_toctree.yml | 2 ++ .../transformers/transformer_lumina2.py | 26 +++++++++---------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 13f7cd412542..e954534a4172 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -440,6 +440,8 @@ title: LEDITS++ - local: api/pipelines/ltx_video title: LTXVideo + - local: api/pipelines/lumina2 + title: Lumina 2.0 - local: api/pipelines/lumina title: Lumina-T2X - local: api/pipelines/marigold diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 3e5f12ed828f..582fcee77e44 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -375,21 +375,21 @@ 
class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin): def __init__( self, sample_size: int = 128, - patch_size: Optional[int] = 2, - in_channels: Optional[int] = 16, + patch_size: int = 2, + in_channels: int = 16, out_channels: Optional[int] = None, - hidden_size: Optional[int] = 2304, - num_layers: Optional[int] = 26, - num_refiner_layers: Optional[int] = 2, - num_attention_heads: Optional[int] = 24, - num_kv_heads: Optional[int] = 8, - multiple_of: Optional[int] = 256, + hidden_size: int = 2304, + num_layers: int = 26, + num_refiner_layers: int = 2, + num_attention_heads: int = 24, + num_kv_heads: int = 8, + multiple_of: int = 256, ffn_dim_multiplier: Optional[float] = None, - norm_eps: Optional[float] = 1e-5, - scaling_factor: Optional[float] = 1.0, - axes_dim_rope: Optional[tuple[int, int, int]] = (32, 32, 32), - axes_lens: Optional[tuple[int, int, int]] = (300, 512, 512), - cap_feat_dim: Optional[int] = 1024, + norm_eps: float = 1e-5, + scaling_factor: float = 1.0, + axes_dim_rope: Tuple[int, int, int] = (32, 32, 32), + axes_lens: Tuple[int, int, int] = (300, 512, 512), + cap_feat_dim: int = 1024, ) -> None: super().__init__() self.out_channels = out_channels or in_channels From 67dabb025c4926aa3e17ede6f01a1f0d6ed68414 Mon Sep 17 00:00:00 2001 From: Aryan Date: Fri, 7 Feb 2025 09:14:02 +0100 Subject: [PATCH 16/25] make fix-copies --- src/diffusers/utils/dummy_pt_objects.py | 15 +++++++++++++++ .../utils/dummy_torch_and_transformers_objects.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 6a1978944c9f..654c78539f07 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -531,6 +531,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class Lumina2Transformer2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LuminaNextDiT2DModel(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index b899915c3046..19017c86eb93 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1142,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Lumina2Text2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class LuminaText2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] From 3f81149743f0ceeed651729a2488bc6506fd74cb Mon Sep 17 00:00:00 2001 From: csuhan Date: Sat, 8 Feb 2025 14:29:39 +0800 Subject: [PATCH 17/25] final small fix --- docs/source/en/api/pipelines/lumina2.md | 6 ++- .../pipelines/lumina2/pipeline_lumina2.py | 40 +++++++------------ 2 files changed, 20 insertions(+), 26 
deletions(-) diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md index e9d346302cb4..fbd822af783e 100644 --- a/docs/source/en/api/pipelines/lumina2.md +++ b/docs/source/en/api/pipelines/lumina2.md @@ -14,7 +14,11 @@ # Lumina2 -[Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM. +[Lumina Image 2.0: A Unified and Efficient Image Generative Model](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) is a 2 billion parameter flow-based diffusion transformer capable of generating diverse images from text descriptions. + +The abstract from the paper is: + +*We introduce Lumina-Image 2.0, an advanced text-to-image model that surpasses previous state-of-the-art methods across multiple benchmarks, while also shedding light on its potential to evolve into a generalist vision intelligence model. Lumina-Image 2.0 exhibits three key properties: (1) Unification – it adopts a unified architecture that treats text and image tokens as a joint sequence, enabling natural cross-modal interactions and facilitating task expansion. Besides, since high-quality captioners can provide semantically better-aligned text-image training pairs, we introduce a unified captioning system, UniCaptioner, which generates comprehensive and precise captions for the model. This not only accelerates model convergence but also enhances prompt adherence, variable-length prompt handling, and task generalization via prompt templates. (2) Efficiency – to improve the efficiency of the unified architecture, we develop a set of optimization techniques that improve semantic learning and fine-grained texture generation during training while incorporating inference-time acceleration strategies without compromising image quality. (3) Transparency – we open-source all training details, code, and models to ensure full reproducibility, aiming to bridge the gap between well-resourced closed-source research teams and independent developers.* diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 78152b602717..55b59d2d1fd8 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -164,22 +164,6 @@ class Lumina2Text2ImgPipeline(DiffusionPipeline): A scheduler to be used in combination with `transformer` to denoise the encoded image latents. 
""" - bad_punct_regex = re.compile( - r"[" - + "#®•©™&@·º½¾¿¡§~" - + r"\)" - + r"\(" - + r"\]" - + r"\[" - + r"\}" - + r"\{" - + r"\|" - + "\\" - + r"\/" - + r"\*" - + r"]{1,}" - ) # noqa - _optional_components = [] _callback_tensor_inputs = ["latents", "prompt_embeds"] model_cpu_offload_seq = "text_encoder->transformer->vae" @@ -218,7 +202,6 @@ def __init__( def _get_gemma_prompt_embeds( self, prompt: Union[str, List[str]], - num_images_per_prompt: int = 1, device: Optional[torch.device] = None, max_sequence_length: int = 256, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -260,11 +243,6 @@ def _get_gemma_prompt_embeds( prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) - prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, prompt_attention_mask @@ -324,10 +302,16 @@ def encode_prompt( if prompt_embeds is None: prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=prompt, - num_images_per_prompt=num_images_per_prompt, device=device, max_sequence_length=max_sequence_length, ) + + batch_size, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) # Get negative embeddings for classifier free guidance if do_classifier_free_guidance and negative_prompt_embeds is None: @@ -351,10 +335,16 @@ def encode_prompt( ) negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds( prompt=negative_prompt, - num_images_per_prompt=num_images_per_prompt, device=device, max_sequence_length=max_sequence_length, ) + + batch_size, seq_len, _ = negative_prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask @@ -694,7 +684,7 @@ def __call__( self.scheduler.config.get("base_image_seq_len", 256), self.scheduler.config.get("max_image_seq_len", 4096), self.scheduler.config.get("base_shift", 0.5), - self.scheduler.config.get("max_shift", 1.16), + self.scheduler.config.get("max_shift", 1.15), ) timesteps, num_inference_steps = retrieve_timesteps( self.scheduler, From 4cbc06581c150c4db2f0ebe30137cd0b2df8275a Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 8 Feb 2025 08:36:04 +0100 Subject: [PATCH 18/25] 'make style' to fix quality test --- 
src/diffusers/pipelines/lumina2/pipeline_lumina2.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 55b59d2d1fd8..1d02c746aba2 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -13,7 +13,6 @@ # limitations under the License. import inspect -import re from typing import Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -207,8 +206,6 @@ def _get_gemma_prompt_embeds( ) -> Tuple[torch.Tensor, torch.Tensor]: device = device or self._execution_device prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - text_inputs = self.tokenizer( prompt, padding="max_length", @@ -305,7 +302,7 @@ def encode_prompt( device=device, max_sequence_length=max_sequence_length, ) - + batch_size, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -338,13 +335,15 @@ def encode_prompt( device=device, max_sequence_length=max_sequence_length, ) - + batch_size, seq_len, _ = negative_prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1) - negative_prompt_attention_mask = negative_prompt_attention_mask.view(batch_size * num_images_per_prompt, -1) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + batch_size * num_images_per_prompt, -1 + ) return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask From b11cc609bd320d5d379aa03f3ac10c8fdb1b409c Mon Sep 17 00:00:00 2001 From: Aryan Date: Sat, 8 Feb 2025 10:41:28 +0100 Subject: [PATCH 19/25] supports_dduf=False --- tests/pipelines/lumina2/test_pipeline_lumina2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py index 74dabc0642e9..f8e0667ce1d2 100644 --- a/tests/pipelines/lumina2/test_pipeline_lumina2.py +++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py @@ -40,6 +40,7 @@ class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTester ] ) + supports_dduf = False test_xformers_attention = False test_layerwise_casting = True From 565fa0cc5caca2f52ed081e45f48506929ae427a Mon Sep 17 00:00:00 2001 From: csuhan Date: Sun, 9 Feb 2025 12:17:57 +0800 Subject: [PATCH 20/25] fix cfg cpu gpu delay --- .../pipelines/lumina2/pipeline_lumina2.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 1d02c746aba2..397141f7d6c5 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -698,30 +698,17 @@ def __call__( # 6. 
Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): + # compute whether apply classifier-free truncation on this timestep + do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio + # expand the latents if we are doing classifier free guidance latent_model_input = ( torch.cat([latents] * 2) - if do_classifier_free_guidance - and 1 - t.item() / self.scheduler.config.num_train_timesteps < cfg_trunc_ratio + if do_classifier_free_guidance and not do_classifier_free_truncation else latents ) + current_timestep = t - if not torch.is_tensor(current_timestep): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = latent_model_input.device.type == "mps" - if isinstance(current_timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - current_timestep = torch.tensor( - [current_timestep], - dtype=dtype, - device=latent_model_input.device, - ) - elif len(current_timestep.shape) == 0: - current_timestep = current_timestep[None].to(latent_model_input.device) - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.expand(latent_model_input.shape[0]) # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image @@ -736,7 +723,7 @@ def __call__( )[0] # perform normalization-based guidance scale on a truncated timestep interval - if do_classifier_free_guidance and current_timestep[0] < cfg_trunc_ratio: + if self.do_classifier_free_guidance and not do_classifier_free_truncation: noise_pred_cond, noise_pred_uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # apply normalization after classifier-free guidance From c1b0f3d12b7d296fe09312ce82dce9807f30e19e Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Feb 2025 06:05:20 +0000 Subject: [PATCH 21/25] no expand attn mask, two forward pass --- src/diffusers/models/normalization.py | 14 ++++----- .../transformers/transformer_lumina2.py | 12 ++++---- .../pipelines/lumina2/pipeline_lumina2.py | 30 ++++++++----------- 3 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 43db18233cc3..4c5bfbda3e33 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -515,19 +515,15 @@ def forward(self, hidden_states): if self.bias is not None: hidden_states = hidden_states + self.bias else: - input_dtype = hidden_states.dtype - variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.eps) - if self.weight is not None: # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: hidden_states = hidden_states.to(self.weight.dtype) - hidden_states = hidden_states * self.weight - if self.bias is not None: - hidden_states = hidden_states + self.bias - else: - hidden_states = hidden_states.to(input_dtype) + hidden_states = nn.functional.rms_norm( + hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps + ) + if self.bias is not None: + hidden_states = hidden_states + self.bias return hidden_states diff --git 
a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 582fcee77e44..292b3725f115 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -130,7 +130,6 @@ def __call__( # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) - attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1) query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -493,10 +492,12 @@ def forward( # 2. Context & noise refinement for layer in self.context_refiner: - encoder_hidden_states = layer(encoder_hidden_states, attention_mask, encoder_rotary_emb) + # NOTE: mask not used for performance + encoder_hidden_states = layer(encoder_hidden_states, None, encoder_rotary_emb) for layer in self.noise_refiner: - hidden_states = layer(hidden_states, hidden_mask, hidden_rotary_emb, temb) + # NOTE: mask not used for performance + hidden_states = layer(hidden_states, None, hidden_rotary_emb, temb) # 3. Attention mask preparation mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) @@ -511,10 +512,11 @@ def forward( # 4. Transformer blocks for layer in self.layers: + # NOTE: mask not used for performance if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(layer, hidden_states, mask, joint_rotary_emb, temb) + hidden_states = self._gradient_checkpointing_func(layer, hidden_states, None, joint_rotary_emb, temb) else: - hidden_states = layer(hidden_states, mask, joint_rotary_emb, temb) + hidden_states = layer(hidden_states, None, joint_rotary_emb, temb) # 5. Output norm & projection & unpatchify hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 397141f7d6c5..c82d45c5eba0 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -658,9 +658,6 @@ def __call__( max_sequence_length=max_sequence_length, system_prompt=system_prompt, ) - if do_classifier_free_guidance: - prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0) - prompt_attention_mask = torch.cat([prompt_attention_mask, negative_prompt_attention_mask], dim=0) # 4. Prepare latents. 
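
Editor's note on the "no expand attn mask" part of this commit: the attention-processor hunk above drops the explicit `attention_mask.expand(-1, attn.heads, sequence_length, -1)` and keeps only the `(batch, 1, 1, key_len)` view. A quick standalone check (toy shapes, not library code) that the boolean mask broadcasts over the head and query dimensions inside `F.scaled_dot_product_attention`, so the expanded and un-expanded masks give the same result:

```python
import torch
import torch.nn.functional as F

# A (batch, 1, 1, key_len) boolean mask broadcasts against (batch, heads, q_len, key_len)
# inside scaled_dot_product_attention, so the removed .expand(...) call was redundant.
batch, heads, q_len, kv_len, head_dim = 2, 4, 8, 8, 16
q = torch.randn(batch, heads, q_len, head_dim)
k = torch.randn(batch, heads, kv_len, head_dim)
v = torch.randn(batch, heads, kv_len, head_dim)

mask = torch.ones(batch, kv_len, dtype=torch.bool)
mask[1, 5:] = False  # mask out the padded tail of the second sequence

broadcast_mask = mask.view(batch, 1, 1, kv_len)
expanded_mask = broadcast_mask.expand(-1, heads, q_len, -1)

out_broadcast = F.scaled_dot_product_attention(q, k, v, attn_mask=broadcast_mask)
out_expanded = F.scaled_dot_product_attention(q, k, v, attn_mask=expanded_mask)
print(torch.allclose(out_broadcast, out_expanded))  # expected: True
```
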
latent_channels = self.transformer.config.in_channels @@ -700,22 +697,13 @@ def __call__( for i, t in enumerate(timesteps): # compute whether apply classifier-free truncation on this timestep do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio - - # expand the latents if we are doing classifier free guidance - latent_model_input = ( - torch.cat([latents] * 2) - if do_classifier_free_guidance and not do_classifier_free_truncation - else latents - ) - - current_timestep = t - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - current_timestep = current_timestep.expand(latent_model_input.shape[0]) # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image - current_timestep = 1 - current_timestep / self.scheduler.config.num_train_timesteps + current_timestep = 1 - t / self.scheduler.config.num_train_timesteps + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, + noise_pred_cond = self.transformer( + hidden_states=latents, timestep=current_timestep, encoder_hidden_states=prompt_embeds, attention_mask=prompt_attention_mask, @@ -724,7 +712,13 @@ def __call__( # perform normalization-based guidance scale on a truncated timestep interval if self.do_classifier_free_guidance and not do_classifier_free_truncation: - noise_pred_cond, noise_pred_uncond = torch.split(noise_pred, len(noise_pred) // 2, dim=0) + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=current_timestep, + encoder_hidden_states=negative_prompt_embeds, + attention_mask=negative_prompt_attention_mask, + return_dict=False, + )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # apply normalization after classifier-free guidance if cfg_normalization: From 7c5a46f46f68b7df1c1d4aa51837a5903a70f28d Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Feb 2025 06:46:43 +0000 Subject: [PATCH 22/25] use_mask_in_transformer, is_torch_version --- src/diffusers/models/normalization.py | 16 +++++++++++++++- .../models/transformers/transformer_lumina2.py | 15 +++++++++++---- .../pipelines/lumina2/pipeline_lumina2.py | 5 +++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 4c5bfbda3e33..da8951ebafea 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -514,7 +514,7 @@ def forward(self, hidden_states): hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0] if self.bias is not None: hidden_states = hidden_states + self.bias - else: + elif is_torch_version(">=", "2.4"): if self.weight is not None: # convert into half-precision if necessary if self.weight.dtype in [torch.float16, torch.bfloat16]: @@ -524,6 +524,20 @@ def forward(self, hidden_states): ) if self.bias is not None: hidden_states = hidden_states + self.bias + else: + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + if self.bias is not None: + hidden_states = 
hidden_states + self.bias + else: + hidden_states = hidden_states.to(input_dtype) return hidden_states diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 292b3725f115..be4439933c67 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -469,6 +469,7 @@ def forward( timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, + use_mask_in_transformer: bool = True, return_dict: bool = True, ) -> Union[torch.Tensor, Transformer2DModelOutput]: batch_size = hidden_states.size(0) @@ -493,11 +494,15 @@ def forward( # 2. Context & noise refinement for layer in self.context_refiner: # NOTE: mask not used for performance - encoder_hidden_states = layer(encoder_hidden_states, None, encoder_rotary_emb) + encoder_hidden_states = layer( + encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb + ) for layer in self.noise_refiner: # NOTE: mask not used for performance - hidden_states = layer(hidden_states, None, hidden_rotary_emb, temb) + hidden_states = layer( + hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb + ) # 3. Attention mask preparation mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool) @@ -514,9 +519,11 @@ def forward( for layer in self.layers: # NOTE: mask not used for performance if torch.is_grad_enabled() and self.gradient_checkpointing: - hidden_states = self._gradient_checkpointing_func(layer, hidden_states, None, joint_rotary_emb, temb) + hidden_states = self._gradient_checkpointing_func( + layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb + ) else: - hidden_states = layer(hidden_states, None, joint_rotary_emb, temb) + hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb) # 5. Output norm & projection & unpatchify hidden_states = self.norm_out(hidden_states, temb) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index c82d45c5eba0..cbd17a7db359 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -525,6 +525,7 @@ def __call__( system_prompt: Optional[str] = None, cfg_trunc_ratio: float = 1.0, cfg_normalization: bool = True, + use_mask_in_transformer: bool = True, max_sequence_length: int = 256, ) -> Union[ImagePipelineOutput, Tuple]: """ @@ -596,6 +597,8 @@ def __call__( The ratio of the timestep interval to apply normalization-based guidance scale. cfg_normalization (`bool`, *optional*, defaults to `True`): Whether to apply normalization-based guidance scale. + use_mask_in_transformer (`bool`, *optional*, defaults to `True`): + Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain. max_sequence_length (`int`, defaults to `256`): Maximum sequence length to use with the `prompt`. 
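
Editor's note on the normalization hunk above: it keeps two equivalent RMSNorm code paths, `nn.functional.rms_norm` on PyTorch 2.4+ and the manual float32-variance computation as a fallback for older versions. A minimal standalone sketch of the equivalence (float32 toy tensors only; the half-precision casting the actual module performs around these paths is elided):

```python
import torch
import torch.nn.functional as F

# Requires PyTorch >= 2.4 for F.rms_norm. Checks that the fused call matches the
# manual root-mean-square normalization used on the fallback path.
torch.manual_seed(0)
hidden_states = torch.randn(2, 8, 64)
weight = torch.randn(64)
eps = 1e-5

# Fallback path for older PyTorch: normalize by the RMS, then apply the learnable scale.
variance = hidden_states.pow(2).mean(-1, keepdim=True)
manual = hidden_states * torch.rsqrt(variance + eps) * weight

# Fast path on PyTorch >= 2.4.
fused = F.rms_norm(hidden_states, normalized_shape=(64,), weight=weight, eps=eps)

print(torch.allclose(manual, fused, atol=1e-5))  # expected: True
```
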
@@ -707,6 +710,7 @@ def __call__( timestep=current_timestep, encoder_hidden_states=prompt_embeds, attention_mask=prompt_attention_mask, + use_mask_in_transformer=use_mask_in_transformer, return_dict=False, )[0] @@ -717,6 +721,7 @@ def __call__( timestep=current_timestep, encoder_hidden_states=negative_prompt_embeds, attention_mask=negative_prompt_attention_mask, + use_mask_in_transformer=use_mask_in_transformer, return_dict=False, )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) From ff4bff9a3fcb1b5d7adfdc29f825d7f4f48bcaf3 Mon Sep 17 00:00:00 2001 From: hlky Date: Mon, 10 Feb 2025 23:18:45 +0000 Subject: [PATCH 23/25] fix use_mask_in_transformer=False --- src/diffusers/models/transformers/transformer_lumina2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index be4439933c67..50a6ebb224a5 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -129,7 +129,8 @@ def __call__( # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) - attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) + if attention_mask is not None: + attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1) query = query.transpose(1, 2) key = key.transpose(1, 2) From da6fd60ad06cff1c6f59e8615eb6d4f42c024caf Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 11 Feb 2025 19:49:13 +0100 Subject: [PATCH 24/25] fix --- src/diffusers/pipelines/lumina2/pipeline_lumina2.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index cbd17a7db359..801ed25093a3 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -637,11 +637,6 @@ def __call__( device = self._execution_device - # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) - # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` - # corresponds to doing no classifier free guidance. - do_classifier_free_guidance = guidance_scale > 1.0 - # 3. 
Encode input prompt ( prompt_embeds, @@ -650,7 +645,7 @@ def __call__( negative_prompt_attention_mask, ) = self.encode_prompt( prompt, - do_classifier_free_guidance, + self.do_classifier_free_guidance, negative_prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, device=device, @@ -730,6 +725,8 @@ def __call__( cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True) noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True) noise_pred = noise_pred * (cond_norm / noise_norm) + else: + noise_pred = noise_pred_cond # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From 9ea9bbebf6c6045b0ce5b3faf35ec23b3b6c8135 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 11 Feb 2025 20:28:46 +0100 Subject: [PATCH 25/25] fix mps --- src/diffusers/models/transformers/transformer_lumina2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index 50a6ebb224a5..bd0848a2d63f 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -241,8 +241,10 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]: freqs_cis = [] + # Use float32 for MPS compatibility + dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 for i, (d, e) in enumerate(zip(axes_dim, axes_lens)): - emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=torch.float64) + emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype) freqs_cis.append(emb) return freqs_cis
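
Editor's note, pulling the series together: a hypothetical end-to-end invocation of the finished pipeline could look like the sketch below. The repo id is taken from the model card linked in the docs and is an assumption, as are the concrete parameter values; the keyword arguments themselves (`cfg_trunc_ratio`, `cfg_normalization`, `use_mask_in_transformer`, `max_sequence_length`) come from the `__call__` signature added in these commits.

```python
import torch
from diffusers import Lumina2Text2ImgPipeline

# Illustrative usage only; the checkpoint id and values below are assumptions.
pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    prompt="a photo of a corgi wearing a red scarf, studio lighting",
    guidance_scale=4.0,
    num_images_per_prompt=1,
    cfg_trunc_ratio=0.25,           # per the truncation check: CFG only on the first 25% of steps
    cfg_normalization=True,         # rescale the guided prediction to the conditional norm
    use_mask_in_transformer=False,  # skip the attention mask in the transformer for a small speedup
    max_sequence_length=256,
).images[0]
image.save("lumina2_corgi.png")
```

With `cfg_trunc_ratio < 1.0`, later steps run a single conditional forward pass instead of the two passes used while guidance is active, which is where the two-forward-pass refactor in the earlier commits pays off.
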