diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 752219b4abd1..e954534a4172 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -288,6 +288,8 @@
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
+ - local: api/models/lumina2_transformer2d
+ title: Lumina2Transformer2DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/mochi_transformer3d
@@ -438,6 +440,8 @@
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTXVideo
+ - local: api/pipelines/lumina2
+ title: Lumina 2.0
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
diff --git a/docs/source/en/api/models/lumina2_transformer2d.md b/docs/source/en/api/models/lumina2_transformer2d.md
new file mode 100644
index 000000000000..0d7c0585dcd5
--- /dev/null
+++ b/docs/source/en/api/models/lumina2_transformer2d.md
@@ -0,0 +1,30 @@
+
+
+# Lumina2Transformer2DModel
+
+A Diffusion Transformer model for 3D video-like data was introduced in [Lumina Image 2.0](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) by Alpha-VLLM.
+
+The model can be loaded with the following code snippet.
+
+```python
+from diffusers import Lumina2Transformer2DModel
+
+transformer = Lumina2Transformer2DModel.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", subfolder="transformer", torch_dtype=torch.bfloat16)
+```
+
+## Lumina2Transformer2DModel
+
+[[autodoc]] Lumina2Transformer2DModel
+
+## Transformer2DModelOutput
+
+[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
diff --git a/docs/source/en/api/pipelines/lumina2.md b/docs/source/en/api/pipelines/lumina2.md
new file mode 100644
index 000000000000..fbd822af783e
--- /dev/null
+++ b/docs/source/en/api/pipelines/lumina2.md
@@ -0,0 +1,33 @@
+
+
+# Lumina2
+
+[Lumina Image 2.0: A Unified and Efficient Image Generative Model](https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0) is a 2 billion parameter flow-based diffusion transformer capable of generating diverse images from text descriptions.
+
+The abstract from the paper is:
+
+*We introduce Lumina-Image 2.0, an advanced text-to-image model that surpasses previous state-of-the-art methods across multiple benchmarks, while also shedding light on its potential to evolve into a generalist vision intelligence model. Lumina-Image 2.0 exhibits three key properties: (1) Unification – it adopts a unified architecture that treats text and image tokens as a joint sequence, enabling natural cross-modal interactions and facilitating task expansion. Besides, since high-quality captioners can provide semantically better-aligned text-image training pairs, we introduce a unified captioning system, UniCaptioner, which generates comprehensive and precise captions for the model. This not only accelerates model convergence but also enhances prompt adherence, variable-length prompt handling, and task generalization via prompt templates. (2) Efficiency – to improve the efficiency of the unified architecture, we develop a set of optimization techniques that improve semantic learning and fine-grained texture generation during training while incorporating inference-time acceleration strategies without compromising image quality. (3) Transparency – we open-source all training details, code, and models to ensure full reproducibility, aiming to bridge the gap between well-resourced closed-source research teams and independent developers.*
+
+
+
+Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
+
+
+
+## Lumina2Text2ImgPipeline
+
+[[autodoc]] Lumina2Text2ImgPipeline
+ - all
+ - __call__
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index c36226225ad4..4bf8d0345f1f 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -118,6 +118,7 @@
"Kandinsky3UNet",
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
+ "Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
@@ -337,6 +338,7 @@
"LEditsPPPipelineStableDiffusionXL",
"LTXImageToVideoPipeline",
"LTXPipeline",
+ "Lumina2Text2ImgPipeline",
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
@@ -632,6 +634,7 @@
Kandinsky3UNet,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
+ Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
@@ -830,6 +833,7 @@
LEditsPPPipelineStableDiffusionXL,
LTXImageToVideoPipeline,
LTXPipeline,
+ Lumina2Text2ImgPipeline,
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 57a34609d28e..d20b5bc82025 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -72,6 +72,7 @@
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
+ _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -140,6 +141,7 @@
HunyuanVideoTransformer3DModel,
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
+ Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 4d1dae879f11..93b11c2b43f0 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -612,7 +612,6 @@ def __init__(
ffn_dim_multiplier: Optional[float] = None,
):
super().__init__()
- inner_dim = int(2 * inner_dim / 3)
# custom hidden_size factor multiplier
if ffn_dim_multiplier is not None:
inner_dim = int(ffn_dim_multiplier * inner_dim)
diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 7db4d3d17d2f..da8951ebafea 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -219,14 +219,13 @@ def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine:
4 * embedding_dim,
bias=True,
)
- self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+ self.norm = RMSNorm(embedding_dim, eps=norm_eps)
def forward(
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
- # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None])
@@ -515,6 +514,16 @@ def forward(self, hidden_states):
hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
if self.bias is not None:
hidden_states = hidden_states + self.bias
+ elif is_torch_version(">=", "2.4"):
+ if self.weight is not None:
+ # convert into half-precision if necessary
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.to(self.weight.dtype)
+ hidden_states = nn.functional.rms_norm(
+ hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
+ )
+ if self.bias is not None:
+ hidden_states = hidden_states + self.bias
else:
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py
index 77e1698b8fc2..e36e929e1522 100644
--- a/src/diffusers/models/transformers/__init__.py
+++ b/src/diffusers/models/transformers/__init__.py
@@ -21,6 +21,7 @@
from .transformer_flux import FluxTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_ltx import LTXVideoTransformer3DModel
+ from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
diff --git a/src/diffusers/models/transformers/lumina_nextdit2d.py b/src/diffusers/models/transformers/lumina_nextdit2d.py
index fb2b3815bcd5..320950866c4a 100644
--- a/src/diffusers/models/transformers/lumina_nextdit2d.py
+++ b/src/diffusers/models/transformers/lumina_nextdit2d.py
@@ -98,7 +98,7 @@ def __init__(
self.feed_forward = LuminaFeedForward(
dim=dim,
- inner_dim=4 * dim,
+ inner_dim=int(4 * 2 * dim / 3),
multiple_of=multiple_of,
ffn_dim_multiplier=ffn_dim_multiplier,
)
diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py
new file mode 100644
index 000000000000..bd0848a2d63f
--- /dev/null
+++ b/src/diffusers/models/transformers/transformer_lumina2.py
@@ -0,0 +1,551 @@
+# Copyright 2024 Alpha-VLLM Authors and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import PeftAdapterMixin
+from ...utils import logging
+from ..attention import LuminaFeedForward
+from ..attention_processor import Attention
+from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import LuminaLayerNormContinuous, LuminaRMSNormZero, RMSNorm
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
+ def __init__(
+ self,
+ hidden_size: int = 4096,
+ cap_feat_dim: int = 2048,
+ frequency_embedding_size: int = 256,
+ norm_eps: float = 1e-5,
+ ) -> None:
+ super().__init__()
+
+ self.time_proj = Timesteps(
+ num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+ )
+
+ self.timestep_embedder = TimestepEmbedding(
+ in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024)
+ )
+
+ self.caption_embedder = nn.Sequential(
+ RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, hidden_size, bias=True)
+ )
+
+ def forward(
+ self, hidden_states: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ timestep_proj = self.time_proj(timestep).type_as(hidden_states)
+ time_embed = self.timestep_embedder(timestep_proj)
+ caption_embed = self.caption_embedder(encoder_hidden_states)
+ return time_embed, caption_embed
+
+
+class Lumina2AttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
+ used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors.
+ """
+
+ def __init__(self):
+ if not hasattr(F, "scaled_dot_product_attention"):
+ raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ image_rotary_emb: Optional[torch.Tensor] = None,
+ base_sequence_length: Optional[int] = None,
+ ) -> torch.Tensor:
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ # Get Query-Key-Value Pair
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ query_dim = query.shape[-1]
+ inner_dim = key.shape[-1]
+ head_dim = query_dim // attn.heads
+ dtype = query.dtype
+
+ # Get key-value heads
+ kv_heads = inner_dim // head_dim
+
+ query = query.view(batch_size, -1, attn.heads, head_dim)
+ key = key.view(batch_size, -1, kv_heads, head_dim)
+ value = value.view(batch_size, -1, kv_heads, head_dim)
+
+ # Apply Query-Key Norm if needed
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if image_rotary_emb is not None:
+ query = apply_rotary_emb(query, image_rotary_emb, use_real=False)
+ key = apply_rotary_emb(key, image_rotary_emb, use_real=False)
+
+ query, key = query.to(dtype), key.to(dtype)
+
+ # Apply proportional attention if true
+ if base_sequence_length is not None:
+ softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
+ else:
+ softmax_scale = attn.scale
+
+ # perform Grouped-qurey Attention (GQA)
+ n_rep = attn.heads // kv_heads
+ if n_rep >= 1:
+ key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ if attention_mask is not None:
+ attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
+
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ value = value.transpose(1, 2)
+
+ hidden_states = F.scaled_dot_product_attention(
+ query, key, value, attn_mask=attention_mask, scale=softmax_scale
+ )
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+ hidden_states = hidden_states.type_as(query)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[1](hidden_states)
+ return hidden_states
+
+
+class Lumina2TransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ num_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ modulation: bool = True,
+ ) -> None:
+ super().__init__()
+ self.head_dim = dim // num_attention_heads
+ self.modulation = modulation
+
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=None,
+ dim_head=dim // num_attention_heads,
+ qk_norm="rms_norm",
+ heads=num_attention_heads,
+ kv_heads=num_kv_heads,
+ eps=1e-5,
+ bias=False,
+ out_bias=False,
+ processor=Lumina2AttnProcessor2_0(),
+ )
+
+ self.feed_forward = LuminaFeedForward(
+ dim=dim,
+ inner_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ )
+
+ if modulation:
+ self.norm1 = LuminaRMSNormZero(
+ embedding_dim=dim,
+ norm_eps=norm_eps,
+ norm_elementwise_affine=True,
+ )
+ else:
+ self.norm1 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
+
+ self.norm2 = RMSNorm(dim, eps=norm_eps)
+ self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ image_rotary_emb: torch.Tensor,
+ temb: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ if self.modulation:
+ norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + gate_msa.unsqueeze(1).tanh() * self.norm2(attn_output)
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states) * (1 + scale_mlp.unsqueeze(1)))
+ hidden_states = hidden_states + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(mlp_output)
+ else:
+ norm_hidden_states = self.norm1(hidden_states)
+ attn_output = self.attn(
+ hidden_states=norm_hidden_states,
+ encoder_hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ image_rotary_emb=image_rotary_emb,
+ )
+ hidden_states = hidden_states + self.norm2(attn_output)
+ mlp_output = self.feed_forward(self.ffn_norm1(hidden_states))
+ hidden_states = hidden_states + self.ffn_norm2(mlp_output)
+
+ return hidden_states
+
+
+class Lumina2RotaryPosEmbed(nn.Module):
+ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512), patch_size: int = 2):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+ self.axes_lens = axes_lens
+ self.patch_size = patch_size
+
+ self.freqs_cis = self._precompute_freqs_cis(axes_dim, axes_lens, theta)
+
+ def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
+ freqs_cis = []
+ # Use float32 for MPS compatibility
+ dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+ for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
+ emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
+ freqs_cis.append(emb)
+ return freqs_cis
+
+ def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+ result = []
+ for i in range(len(self.axes_dim)):
+ freqs = self.freqs_cis[i].to(ids.device)
+ index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
+ result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
+ return torch.cat(result, dim=-1)
+
+ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
+ batch_size = len(hidden_states)
+ p_h = p_w = self.patch_size
+ device = hidden_states[0].device
+
+ l_effective_cap_len = attention_mask.sum(dim=1).tolist()
+ # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
+ img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
+ l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]
+
+ max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
+ max_img_len = max(l_effective_img_len)
+
+ position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
+
+ for i in range(batch_size):
+ cap_len = l_effective_cap_len[i]
+ img_len = l_effective_img_len[i]
+ H, W = img_sizes[i]
+ H_tokens, W_tokens = H // p_h, W // p_w
+ assert H_tokens * W_tokens == img_len
+
+ position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
+ position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
+ row_ids = (
+ torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+ )
+ col_ids = (
+ torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+ )
+ position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
+ position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
+
+ freqs_cis = self._get_freqs_cis(position_ids)
+
+ cap_freqs_cis_shape = list(freqs_cis.shape)
+ cap_freqs_cis_shape[1] = attention_mask.shape[1]
+ cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
+
+ img_freqs_cis_shape = list(freqs_cis.shape)
+ img_freqs_cis_shape[1] = max_img_len
+ img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
+
+ for i in range(batch_size):
+ cap_len = l_effective_cap_len[i]
+ img_len = l_effective_img_len[i]
+ cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
+
+ flat_hidden_states = []
+ for i in range(batch_size):
+ img = hidden_states[i]
+ C, H, W = img.size()
+ img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
+ flat_hidden_states.append(img)
+ hidden_states = flat_hidden_states
+ padded_img_embed = torch.zeros(
+ batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
+ )
+ padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
+ for i in range(batch_size):
+ padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
+ padded_img_mask[i, : l_effective_img_len[i]] = True
+
+ return (
+ padded_img_embed,
+ padded_img_mask,
+ img_sizes,
+ l_effective_cap_len,
+ l_effective_img_len,
+ freqs_cis,
+ cap_freqs_cis,
+ img_freqs_cis,
+ max_seq_len,
+ )
+
+
+class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+ r"""
+ Lumina2NextDiT: Diffusion model with a Transformer backbone.
+
+ Parameters:
+ sample_size (`int`): The width of the latent images. This is fixed during training since
+ it is used to learn a number of position embeddings.
+ patch_size (`int`, *optional*, (`int`, *optional*, defaults to 2):
+ The size of each patch in the image. This parameter defines the resolution of patches fed into the model.
+ in_channels (`int`, *optional*, defaults to 4):
+ The number of input channels for the model. Typically, this matches the number of channels in the input
+ images.
+ hidden_size (`int`, *optional*, defaults to 4096):
+ The dimensionality of the hidden layers in the model. This parameter determines the width of the model's
+ hidden representations.
+ num_layers (`int`, *optional*, default to 32):
+ The number of layers in the model. This defines the depth of the neural network.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ The number of attention heads in each attention layer. This parameter specifies how many separate attention
+ mechanisms are used.
+ num_kv_heads (`int`, *optional*, defaults to 8):
+ The number of key-value heads in the attention mechanism, if different from the number of attention heads.
+ If None, it defaults to num_attention_heads.
+ multiple_of (`int`, *optional*, defaults to 256):
+ A factor that the hidden size should be a multiple of. This can help optimize certain hardware
+ configurations.
+ ffn_dim_multiplier (`float`, *optional*):
+ A multiplier for the dimensionality of the feed-forward network. If None, it uses a default value based on
+ the model configuration.
+ norm_eps (`float`, *optional*, defaults to 1e-5):
+ A small value added to the denominator for numerical stability in normalization layers.
+ scaling_factor (`float`, *optional*, defaults to 1.0):
+ A scaling factor applied to certain parameters or layers in the model. This can be used for adjusting the
+ overall scale of the model's operations.
+ """
+
+ _supports_gradient_checkpointing = True
+ _no_split_modules = ["Lumina2TransformerBlock"]
+ _skip_layerwise_casting_patterns = ["x_embedder", "norm"]
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ out_channels: Optional[int] = None,
+ hidden_size: int = 2304,
+ num_layers: int = 26,
+ num_refiner_layers: int = 2,
+ num_attention_heads: int = 24,
+ num_kv_heads: int = 8,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ norm_eps: float = 1e-5,
+ scaling_factor: float = 1.0,
+ axes_dim_rope: Tuple[int, int, int] = (32, 32, 32),
+ axes_lens: Tuple[int, int, int] = (300, 512, 512),
+ cap_feat_dim: int = 1024,
+ ) -> None:
+ super().__init__()
+ self.out_channels = out_channels or in_channels
+
+ # 1. Positional, patch & conditional embeddings
+ self.rope_embedder = Lumina2RotaryPosEmbed(
+ theta=10000, axes_dim=axes_dim_rope, axes_lens=axes_lens, patch_size=patch_size
+ )
+
+ self.x_embedder = nn.Linear(in_features=patch_size * patch_size * in_channels, out_features=hidden_size)
+
+ self.time_caption_embed = Lumina2CombinedTimestepCaptionEmbedding(
+ hidden_size=hidden_size, cap_feat_dim=cap_feat_dim, norm_eps=norm_eps
+ )
+
+ # 2. Noise and context refinement blocks
+ self.noise_refiner = nn.ModuleList(
+ [
+ Lumina2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=True,
+ )
+ for _ in range(num_refiner_layers)
+ ]
+ )
+
+ self.context_refiner = nn.ModuleList(
+ [
+ Lumina2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=False,
+ )
+ for _ in range(num_refiner_layers)
+ ]
+ )
+
+ # 3. Transformer blocks
+ self.layers = nn.ModuleList(
+ [
+ Lumina2TransformerBlock(
+ hidden_size,
+ num_attention_heads,
+ num_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ modulation=True,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ # 4. Output norm & projection
+ self.norm_out = LuminaLayerNormContinuous(
+ embedding_dim=hidden_size,
+ conditioning_embedding_dim=min(hidden_size, 1024),
+ elementwise_affine=False,
+ eps=1e-6,
+ bias=True,
+ out_dim=patch_size * patch_size * self.out_channels,
+ )
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ timestep: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ use_mask_in_transformer: bool = True,
+ return_dict: bool = True,
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+ batch_size = hidden_states.size(0)
+
+ # 1. Condition, positional & patch embedding
+ temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
+
+ (
+ hidden_states,
+ hidden_mask,
+ hidden_sizes,
+ encoder_hidden_len,
+ hidden_len,
+ joint_rotary_emb,
+ encoder_rotary_emb,
+ hidden_rotary_emb,
+ max_seq_len,
+ ) = self.rope_embedder(hidden_states, attention_mask)
+
+ hidden_states = self.x_embedder(hidden_states)
+
+ # 2. Context & noise refinement
+ for layer in self.context_refiner:
+ # NOTE: mask not used for performance
+ encoder_hidden_states = layer(
+ encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
+ )
+
+ for layer in self.noise_refiner:
+ # NOTE: mask not used for performance
+ hidden_states = layer(
+ hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb
+ )
+
+ # 3. Attention mask preparation
+ mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+ padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+ for i in range(batch_size):
+ cap_len = encoder_hidden_len[i]
+ img_len = hidden_len[i]
+ mask[i, : cap_len + img_len] = True
+ padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
+ padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
+ hidden_states = padded_hidden_states
+
+ # 4. Transformer blocks
+ for layer in self.layers:
+ # NOTE: mask not used for performance
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
+ hidden_states = self._gradient_checkpointing_func(
+ layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
+ )
+ else:
+ hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)
+
+ # 5. Output norm & projection & unpatchify
+ hidden_states = self.norm_out(hidden_states, temb)
+
+ height_tokens = width_tokens = self.config.patch_size
+ output = []
+ for i in range(len(hidden_sizes)):
+ height, width = hidden_sizes[i]
+ begin = encoder_hidden_len[i]
+ end = begin + (height // height_tokens) * (width // width_tokens)
+ output.append(
+ hidden_states[i][begin:end]
+ .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
+ .permute(4, 0, 2, 1, 3)
+ .flatten(3, 4)
+ .flatten(1, 2)
+ )
+ output = torch.stack(output, dim=0)
+
+ if not return_dict:
+ return (output,)
+ return Transformer2DModelOutput(sample=output)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 5829cf495dcc..e82d92e70e8b 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -256,6 +256,7 @@
_import_structure["latte"] = ["LattePipeline"]
_import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
_import_structure["lumina"] = ["LuminaText2ImgPipeline"]
+ _import_structure["lumina2"] = ["Lumina2Text2ImgPipeline"]
_import_structure["marigold"].extend(
[
"MarigoldDepthPipeline",
@@ -596,6 +597,7 @@
)
from .ltx import LTXImageToVideoPipeline, LTXPipeline
from .lumina import LuminaText2ImgPipeline
+ from .lumina2 import Lumina2Text2ImgPipeline
from .marigold import (
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index a19329431b05..6066836e7a05 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -65,6 +65,7 @@
from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
from .lumina import LuminaText2ImgPipeline
+from .lumina2 import Lumina2Text2ImgPipeline
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
@@ -135,6 +136,7 @@
("flux-control", FluxControlPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
+ ("lumina2", Lumina2Text2ImgPipeline),
("cogview3", CogView3PlusPipeline),
]
)
diff --git a/src/diffusers/pipelines/lumina2/__init__.py b/src/diffusers/pipelines/lumina2/__init__.py
new file mode 100644
index 000000000000..0e51a768a785
--- /dev/null
+++ b/src/diffusers/pipelines/lumina2/__init__.py
@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING
+
+from ...utils import (
+ DIFFUSERS_SLOW_IMPORT,
+ OptionalDependencyNotAvailable,
+ _LazyModule,
+ get_objects_from_module,
+ is_torch_available,
+ is_transformers_available,
+)
+
+
+_dummy_objects = {}
+_import_structure = {}
+
+
+try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ from ...utils import dummy_torch_and_transformers_objects # noqa F403
+
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+ _import_structure["pipeline_lumina2"] = ["Lumina2Text2ImgPipeline"]
+
+if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
+ try:
+ if not (is_transformers_available() and is_torch_available()):
+ raise OptionalDependencyNotAvailable()
+
+ except OptionalDependencyNotAvailable:
+ from ...utils.dummy_torch_and_transformers_objects import *
+ else:
+ from .pipeline_lumina2 import Lumina2Text2ImgPipeline
+
+else:
+ import sys
+
+ sys.modules[__name__] = _LazyModule(
+ __name__,
+ globals()["__file__"],
+ _import_structure,
+ module_spec=__spec__,
+ )
+
+ for name, value in _dummy_objects.items():
+ setattr(sys.modules[__name__], name, value)
diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
new file mode 100644
index 000000000000..801ed25093a3
--- /dev/null
+++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py
@@ -0,0 +1,770 @@
+# Copyright 2024 Alpha-VLLM and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from ...image_processor import VaeImageProcessor
+from ...models import AutoencoderKL
+from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
+from ...schedulers import FlowMatchEulerDiscreteScheduler
+from ...utils import (
+ is_bs4_available,
+ is_ftfy_available,
+ is_torch_xla_available,
+ logging,
+ replace_example_docstring,
+)
+from ...utils.torch_utils import randn_tensor
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+
+
+if is_torch_xla_available():
+ import torch_xla.core.xla_model as xm
+
+ XLA_AVAILABLE = True
+else:
+ XLA_AVAILABLE = False
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+if is_bs4_available():
+ pass
+
+if is_ftfy_available():
+ pass
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import Lumina2Text2ImgPipeline
+
+ >>> pipe = Lumina2Text2ImgPipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
+ >>> # Enable memory optimizations.
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> prompt = "Upper body of a young woman in a Victorian-era outfit with brass goggles and leather straps. Background shows an industrial revolution cityscape with smoky skies and tall, metal structures"
+ >>> image = pipe(prompt).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+ image_seq_len,
+ base_seq_len: int = 256,
+ max_seq_len: int = 4096,
+ base_shift: float = 0.5,
+ max_shift: float = 1.16,
+):
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+ b = base_shift - m * base_seq_len
+ mu = image_seq_len * m + b
+ return mu
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ sigmas: Optional[List[float]] = None,
+ **kwargs,
+):
+ r"""
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
+ must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+ `num_inference_steps` and `sigmas` must be `None`.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+ `num_inference_steps` and `timesteps` must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None and sigmas is not None:
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ elif sigmas is not None:
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accept_sigmas:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+
+class Lumina2Text2ImgPipeline(DiffusionPipeline):
+ r"""
+ Pipeline for text-to-image generation using Lumina-T2I.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`AutoModel`]):
+ Frozen text-encoder. Lumina-T2I uses
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel), specifically the
+ [t5-v1_1-xxl](https://huggingface.co/Alpha-VLLM/tree/main/t5-v1_1-xxl) variant.
+ tokenizer (`AutoModel`):
+ Tokenizer of class
+ [AutoModel](https://huggingface.co/docs/transformers/model_doc/t5#transformers.AutoModel).
+ transformer ([`Transformer2DModel`]):
+ A text conditioned `Transformer2DModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
+ """
+
+ _optional_components = []
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
+
+ def __init__(
+ self,
+ transformer: Lumina2Transformer2DModel,
+ scheduler: FlowMatchEulerDiscreteScheduler,
+ vae: AutoencoderKL,
+ text_encoder: AutoModel,
+ tokenizer: AutoTokenizer,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ transformer=transformer,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 8
+ self.default_sample_size = (
+ self.transformer.config.sample_size
+ if hasattr(self, "transformer") and self.transformer is not None
+ else 128
+ )
+ self.default_image_size = self.default_sample_size * self.vae_scale_factor
+ self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts."
+
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
+
+ if getattr(self, "tokenizer", None) is not None:
+ self.tokenizer.padding_side = "right"
+
+ def _get_gemma_prompt_embeds(
+ self,
+ prompt: Union[str, List[str]],
+ device: Optional[torch.device] = None,
+ max_sequence_length: int = 256,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ device = device or self._execution_device
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=max_sequence_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ text_input_ids = text_inputs.input_ids.to(device)
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids.to(device)
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because Gemma can only handle sequences up to"
+ f" {max_sequence_length} tokens: {removed_text}"
+ )
+
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
+ prompt_embeds = self.text_encoder(
+ text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
+ )
+ prompt_embeds = prompt_embeds.hidden_states[-2]
+
+ if self.text_encoder is not None:
+ dtype = self.text_encoder.dtype
+ elif self.transformer is not None:
+ dtype = self.transformer.dtype
+ else:
+ dtype = None
+
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+ _, seq_len, _ = prompt_embeds.shape
+
+ return prompt_embeds, prompt_attention_mask
+
+ # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
+ def encode_prompt(
+ self,
+ prompt: Union[str, List[str]],
+ do_classifier_free_guidance: bool = True,
+ negative_prompt: Union[str, List[str]] = None,
+ num_images_per_prompt: int = 1,
+ device: Optional[torch.device] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ system_prompt: Optional[str] = None,
+ max_sequence_length: int = 256,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
+ instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
+ Lumina-T2I, this should be "".
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
+ whether to use classifier free guidance or not
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ number of images that should be generated per prompt
+ device: (`torch.device`, *optional*):
+ torch device to place the resulting embeddings on
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Lumina-T2I, it's should be the embeddings of the "" string.
+ max_sequence_length (`int`, defaults to `256`):
+ Maximum sequence length to use for the prompt.
+ """
+ if device is None:
+ device = self._execution_device
+
+ prompt = [prompt] if isinstance(prompt, str) else prompt
+ if prompt is not None:
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if system_prompt is None:
+ system_prompt = self.system_prompt
+ if prompt is not None:
+ prompt = [system_prompt + " " + p for p in prompt]
+
+ if prompt_embeds is None:
+ prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ batch_size, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ prompt_attention_mask = prompt_attention_mask.view(batch_size * num_images_per_prompt, -1)
+
+ # Get negative embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ negative_prompt = negative_prompt if negative_prompt is not None else ""
+
+ # Normalize str to list
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+ if prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+ prompt=negative_prompt,
+ device=device,
+ max_sequence_length=max_sequence_length,
+ )
+
+ batch_size, seq_len, _ = negative_prompt_embeds.shape
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(
+ batch_size * num_images_per_prompt, -1
+ )
+
+ return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ prompt_attention_mask=None,
+ negative_prompt_attention_mask=None,
+ callback_on_step_end_tensor_inputs=None,
+ max_sequence_length=None,
+ ):
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and prompt_attention_mask is None:
+ raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
+
+ if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+ if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
+ raise ValueError(
+ "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
+ f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
+ f" {negative_prompt_attention_mask.shape}."
+ )
+
+ if max_sequence_length is not None and max_sequence_length > 512:
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+ processing larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+ # VAE applies 8x compression on images but we must also account for packing which requires
+ # latent height and width to be divisible by 2.
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
+
+ shape = (batch_size, num_channels_latents, height, width)
+
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ return latents
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+ num_inference_steps: int = 30,
+ guidance_scale: float = 4.0,
+ negative_prompt: Union[str, List[str]] = None,
+ sigmas: List[float] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.Tensor] = None,
+ prompt_embeds: Optional[torch.Tensor] = None,
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
+ prompt_attention_mask: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ system_prompt: Optional[str] = None,
+ cfg_trunc_ratio: float = 1.0,
+ cfg_normalization: bool = True,
+ use_mask_in_transformer: bool = True,
+ max_sequence_length: int = 256,
+ ) -> Union[ImagePipelineOutput, Tuple]:
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
+ instead.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ sigmas (`List[float]`, *optional*):
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+ will be used.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size):
+ The width in pixels of the generated image.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.Tensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
+ Pre-generated negative text embeddings. For Lumina-T2I this negative prompt should be "". If not
+ provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
+ negative_prompt_attention_mask (`torch.Tensor`, *optional*):
+ Pre-generated attention mask for negative text embeddings.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+ `._callback_tensor_inputs` attribute of your pipeline class.
+ system_prompt (`str`, *optional*):
+ The system prompt to use for the image generation.
+ cfg_trunc_ratio (`float`, *optional*, defaults to `1.0`):
+ The ratio of the timestep interval to apply normalization-based guidance scale.
+ cfg_normalization (`bool`, *optional*, defaults to `True`):
+ Whether to apply normalization-based guidance scale.
+ use_mask_in_transformer (`bool`, *optional*, defaults to `True`):
+ Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain.
+ max_sequence_length (`int`, defaults to `256`):
+ Maximum sequence length to use with the `prompt`.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
+ returned where the first element is a list with the generated images
+ """
+ height = height or self.default_sample_size * self.vae_scale_factor
+ width = width or self.default_sample_size * self.vae_scale_factor
+ self._guidance_scale = guidance_scale
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ # 3. Encode input prompt
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = self.encode_prompt(
+ prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ max_sequence_length=max_sequence_length,
+ system_prompt=system_prompt,
+ )
+
+ # 4. Prepare latents.
+ latent_channels = self.transformer.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ latent_channels,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 5. Prepare timesteps
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+ image_seq_len = latents.shape[1]
+ mu = calculate_shift(
+ image_seq_len,
+ self.scheduler.config.get("base_image_seq_len", 256),
+ self.scheduler.config.get("max_image_seq_len", 4096),
+ self.scheduler.config.get("base_shift", 0.5),
+ self.scheduler.config.get("max_shift", 1.15),
+ )
+ timesteps, num_inference_steps = retrieve_timesteps(
+ self.scheduler,
+ num_inference_steps,
+ device,
+ sigmas=sigmas,
+ mu=mu,
+ )
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Denoising loop
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # compute whether apply classifier-free truncation on this timestep
+ do_classifier_free_truncation = (i + 1) / num_inference_steps > cfg_trunc_ratio
+ # reverse the timestep since Lumina uses t=0 as the noise and t=1 as the image
+ current_timestep = 1 - t / self.scheduler.config.num_train_timesteps
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ current_timestep = current_timestep.expand(latents.shape[0])
+
+ noise_pred_cond = self.transformer(
+ hidden_states=latents,
+ timestep=current_timestep,
+ encoder_hidden_states=prompt_embeds,
+ attention_mask=prompt_attention_mask,
+ use_mask_in_transformer=use_mask_in_transformer,
+ return_dict=False,
+ )[0]
+
+ # perform normalization-based guidance scale on a truncated timestep interval
+ if self.do_classifier_free_guidance and not do_classifier_free_truncation:
+ noise_pred_uncond = self.transformer(
+ hidden_states=latents,
+ timestep=current_timestep,
+ encoder_hidden_states=negative_prompt_embeds,
+ attention_mask=negative_prompt_attention_mask,
+ use_mask_in_transformer=use_mask_in_transformer,
+ return_dict=False,
+ )[0]
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
+ # apply normalization after classifier-free guidance
+ if cfg_normalization:
+ cond_norm = torch.norm(noise_pred_cond, dim=-1, keepdim=True)
+ noise_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+ noise_pred = noise_pred * (cond_norm / noise_norm)
+ else:
+ noise_pred = noise_pred_cond
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_dtype = latents.dtype
+ noise_pred = -noise_pred
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+ if latents.dtype != latents_dtype:
+ if torch.backends.mps.is_available():
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+ latents = latents.to(latents_dtype)
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+
+ if XLA_AVAILABLE:
+ xm.mark_step()
+
+ if not output_type == "latent":
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
+ image = self.vae.decode(latents, return_dict=False)[0]
+ image = self.image_processor.postprocess(image, output_type=output_type)
+ else:
+ image = latents
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+ if not return_dict:
+ return (image,)
+
+ return ImagePipelineOutput(images=image)
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 6a1978944c9f..654c78539f07 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -531,6 +531,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class Lumina2Transformer2DModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class LuminaNextDiT2DModel(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index b899915c3046..19017c86eb93 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1142,6 +1142,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+class Lumina2Text2ImgPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
class LuminaText2ImgPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py
new file mode 100644
index 000000000000..e89f160433bd
--- /dev/null
+++ b/tests/models/transformers/test_models_transformer_lumina2.py
@@ -0,0 +1,89 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import torch
+
+from diffusers import Lumina2Transformer2DModel
+from diffusers.utils.testing_utils import (
+ enable_full_determinism,
+ torch_device,
+)
+
+from ..test_modeling_common import ModelTesterMixin
+
+
+enable_full_determinism()
+
+
+class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
+ model_class = Lumina2Transformer2DModel
+ main_input_name = "hidden_states"
+ uses_custom_attn_processor = True
+
+ @property
+ def dummy_input(self):
+ batch_size = 2 # N
+ num_channels = 4 # C
+ height = width = 16 # H, W
+ embedding_dim = 32 # D
+ sequence_length = 16 # L
+
+ hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
+ encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+ timestep = torch.rand(size=(batch_size,)).to(torch_device)
+ attention_mask = torch.ones(size=(batch_size, sequence_length), dtype=torch.bool).to(torch_device)
+
+ return {
+ "hidden_states": hidden_states,
+ "encoder_hidden_states": encoder_hidden_states,
+ "timestep": timestep,
+ "attention_mask": attention_mask,
+ }
+
+ @property
+ def input_shape(self):
+ return (4, 16, 16)
+
+ @property
+ def output_shape(self):
+ return (4, 16, 16)
+
+ def prepare_init_args_and_inputs_for_common(self):
+ init_dict = {
+ "sample_size": 16,
+ "patch_size": 2,
+ "in_channels": 4,
+ "hidden_size": 24,
+ "num_layers": 2,
+ "num_refiner_layers": 1,
+ "num_attention_heads": 3,
+ "num_kv_heads": 1,
+ "multiple_of": 2,
+ "ffn_dim_multiplier": None,
+ "norm_eps": 1e-5,
+ "scaling_factor": 1.0,
+ "axes_dim_rope": (4, 2, 2),
+ "axes_lens": (128, 128, 128),
+ "cap_feat_dim": 32,
+ }
+
+ inputs_dict = self.dummy_input
+ return init_dict, inputs_dict
+
+ def test_gradient_checkpointing_is_applied(self):
+ expected_set = {"Lumina2Transformer2DModel"}
+ super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
diff --git a/tests/pipelines/lumina2/__init__.py b/tests/pipelines/lumina2/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/lumina2/test_pipeline_lumina2.py b/tests/pipelines/lumina2/test_pipeline_lumina2.py
new file mode 100644
index 000000000000..f8e0667ce1d2
--- /dev/null
+++ b/tests/pipelines/lumina2/test_pipeline_lumina2.py
@@ -0,0 +1,147 @@
+import unittest
+
+import numpy as np
+import torch
+from transformers import AutoTokenizer, GemmaConfig, GemmaForCausalLM
+
+from diffusers import (
+ AutoencoderKL,
+ FlowMatchEulerDiscreteScheduler,
+ Lumina2Text2ImgPipeline,
+ Lumina2Transformer2DModel,
+)
+from diffusers.utils.testing_utils import torch_device
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+class Lumina2Text2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterMixin):
+ pipeline_class = Lumina2Text2ImgPipeline
+ params = frozenset(
+ [
+ "prompt",
+ "height",
+ "width",
+ "guidance_scale",
+ "negative_prompt",
+ "prompt_embeds",
+ "negative_prompt_embeds",
+ ]
+ )
+ batch_params = frozenset(["prompt", "negative_prompt"])
+ required_optional_params = frozenset(
+ [
+ "num_inference_steps",
+ "generator",
+ "latents",
+ "return_dict",
+ "callback_on_step_end",
+ "callback_on_step_end_tensor_inputs",
+ ]
+ )
+
+ supports_dduf = False
+ test_xformers_attention = False
+ test_layerwise_casting = True
+
+ def get_dummy_components(self):
+ torch.manual_seed(0)
+ transformer = Lumina2Transformer2DModel(
+ sample_size=4,
+ patch_size=2,
+ in_channels=4,
+ hidden_size=8,
+ num_layers=2,
+ num_attention_heads=1,
+ num_kv_heads=1,
+ multiple_of=16,
+ ffn_dim_multiplier=None,
+ norm_eps=1e-5,
+ scaling_factor=1.0,
+ axes_dim_rope=[4, 2, 2],
+ cap_feat_dim=8,
+ )
+
+ torch.manual_seed(0)
+ vae = AutoencoderKL(
+ sample_size=32,
+ in_channels=3,
+ out_channels=3,
+ block_out_channels=(4,),
+ layers_per_block=1,
+ latent_channels=4,
+ norm_num_groups=1,
+ use_quant_conv=False,
+ use_post_quant_conv=False,
+ shift_factor=0.0609,
+ scaling_factor=1.5035,
+ )
+
+ scheduler = FlowMatchEulerDiscreteScheduler()
+ tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
+
+ torch.manual_seed(0)
+ config = GemmaConfig(
+ head_dim=2,
+ hidden_size=8,
+ intermediate_size=37,
+ num_attention_heads=4,
+ num_hidden_layers=2,
+ num_key_value_heads=4,
+ )
+ text_encoder = GemmaForCausalLM(config)
+
+ components = {
+ "transformer": transformer.eval(),
+ "vae": vae.eval(),
+ "scheduler": scheduler,
+ "text_encoder": text_encoder.eval(),
+ "tokenizer": tokenizer,
+ }
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device="cpu").manual_seed(seed)
+
+ inputs = {
+ "prompt": "A painting of a squirrel eating a burger",
+ "generator": generator,
+ "num_inference_steps": 2,
+ "guidance_scale": 5.0,
+ "height": 32,
+ "width": 32,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_lumina_prompt_embeds(self):
+ pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
+ inputs = self.get_dummy_inputs(torch_device)
+
+ output_with_prompt = pipe(**inputs).images[0]
+
+ inputs = self.get_dummy_inputs(torch_device)
+ prompt = inputs.pop("prompt")
+
+ do_classifier_free_guidance = inputs["guidance_scale"] > 1
+ (
+ prompt_embeds,
+ prompt_attention_mask,
+ negative_prompt_embeds,
+ negative_prompt_attention_mask,
+ ) = pipe.encode_prompt(
+ prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ device=torch_device,
+ )
+ output_with_embeds = pipe(
+ prompt_embeds=prompt_embeds,
+ prompt_attention_mask=prompt_attention_mask,
+ **inputs,
+ ).images[0]
+
+ max_diff = np.abs(output_with_prompt - output_with_embeds).max()
+ assert max_diff < 1e-4