4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -110,6 +110,7 @@
"LatteTransformer3DModel",
"LTXVideoTransformer3DModel",
"LuminaNextDiT2DModel",
"Lumina2Transformer2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
@@ -329,6 +330,7 @@
"LTXImageToVideoPipeline",
"LTXPipeline",
"LuminaText2ImgPipeline",
"Lumina2Text2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
@@ -622,6 +624,7 @@
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
Lumina2Transformer2DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
@@ -820,6 +823,7 @@
LTXImageToVideoPipeline,
LTXPipeline,
LuminaText2ImgPipeline,
Lumina2Text2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
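For reference, these export additions are what make the new Lumina2 classes reachable from the top-level diffusers namespace. Below is a minimal usage sketch, assuming a build that includes this PR; the checkpoint id is a placeholder and is not something this diff specifies.

# Sketch only: exercises the exports registered above. "<lumina2-repo-id>" is a
# placeholder checkpoint id, not something this PR defines.
import torch

from diffusers import Lumina2Text2ImgPipeline, Lumina2Transformer2DModel

pipe = Lumina2Text2ImgPipeline.from_pretrained("<lumina2-repo-id>", torch_dtype=torch.bfloat16)
pipe.to("cuda")
image = pipe(prompt="a watercolor painting of a lighthouse at dusk").images[0]

# The transformer backbone is also exported on its own, e.g. for custom pipelines.
transformer = Lumina2Transformer2DModel.from_pretrained("<lumina2-repo-id>", subfolder="transformer")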
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -60,6 +60,7 @@
_import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
_import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
_import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
@@ -139,6 +140,7 @@
LatteTransformer3DModel,
LTXVideoTransformer3DModel,
LuminaNextDiT2DModel,
Lumina2Transformer2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
97 changes: 97 additions & 0 deletions src/diffusers/models/attention_processor.py
@@ -4192,6 +4192,102 @@ def __call__(
return hidden_states


class Lumina2AttnProcessor2_0:
r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
    used in the Lumina2Transformer2DModel model. It applies normalization and rotary embeddings to the query and key vectors.
"""

def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("Lumina2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0 or later.")

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
query_rotary_emb: Optional[torch.Tensor] = None,
key_rotary_emb: Optional[torch.Tensor] = None,
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
        from .embeddings import apply_rotary_emb

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = hidden_states.shape

# Get Query-Key-Value Pair
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query_dim = query.shape[-1]
inner_dim = key.shape[-1]
head_dim = query_dim // attn.heads
dtype = query.dtype

# Get key-value heads
kv_heads = inner_dim // head_dim

query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, kv_heads, head_dim)
value = value.view(batch_size, -1, kv_heads, head_dim)

# Apply Query-Key Norm if needed
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# Apply RoPE if needed
if query_rotary_emb is not None:
query = apply_rotary_emb(query, query_rotary_emb, use_real=False)
if key_rotary_emb is not None:
key = apply_rotary_emb(key, key_rotary_emb, use_real=False)

query, key = query.to(dtype), key.to(dtype)

# Apply proportional attention if true
if base_sequence_length is not None:
softmax_scale = math.sqrt(math.log(sequence_length, base_sequence_length)) * attn.scale
else:
softmax_scale = attn.scale

        # perform Grouped-Query Attention (GQA)
n_rep = attn.heads // kv_heads
if n_rep >= 1:
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)

        # scaled_dot_product_attention expects attention_mask shape to be
        # (batch, heads, source_length, target_length)
        if attention_mask is not None:
            attention_mask = attention_mask.bool().view(batch_size, 1, 1, -1)
            attention_mask = attention_mask.expand(-1, attn.heads, sequence_length, -1)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # note: passing `scale` to scaled_dot_product_attention requires PyTorch >= 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, scale=softmax_scale
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)

return hidden_states


class LuminaAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
@@ -6183,6 +6279,7 @@ def __call__(
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
Lumina2AttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,
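Two details in Lumina2AttnProcessor2_0 above are easy to miss: the key/value heads are repeated to match the number of query heads (grouped-query attention), and the softmax scale is stretched by sqrt(log_base(seq_len)) whenever base_sequence_length is given (proportional attention). A self-contained sketch of just those two steps follows; all shapes are made up, and it requires a PyTorch version that accepts the scale argument to scaled_dot_product_attention.

# Standalone illustration of the GQA repeat and the proportional softmax scale
# used in Lumina2AttnProcessor2_0. All shapes are illustrative.
import math

import torch
import torch.nn.functional as F

batch, seq_len, heads, kv_heads, head_dim = 2, 16, 8, 2, 64
base_sequence_length = 8

query = torch.randn(batch, seq_len, heads, head_dim)
key = torch.randn(batch, seq_len, kv_heads, head_dim)
value = torch.randn(batch, seq_len, kv_heads, head_dim)

# Repeat each key/value head so every group of query heads has a matching pair.
n_rep = heads // kv_heads
key = key.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)      # (2, 16, 8, 64)
value = value.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)  # (2, 16, 8, 64)

# Proportional attention: grow the logit scale with sequence length relative to
# the base length, on top of the usual 1/sqrt(head_dim) factor (attn.scale).
softmax_scale = math.sqrt(math.log(seq_len, base_sequence_length)) / math.sqrt(head_dim)

out = F.scaled_dot_product_attention(
    query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2), scale=softmax_scale
)
print(out.shape)  # torch.Size([2, 8, 16, 64])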
61 changes: 61 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -569,6 +569,35 @@ def forward(self, latent):
return (latent + pos_embed).to(latent.dtype)


class Lumina2PosEmbed(nn.Module):
def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300, 512, 512)):
super().__init__()
self.theta = theta
self.axes_dim = axes_dim
self.axes_lens = axes_lens
self.freqs_cis = self.precompute_freqs_cis(axes_dim, axes_lens, theta)

def precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
freqs_cis = []
        for d, e in zip(axes_dim, axes_lens):
emb = get_1d_rotary_pos_embed(
d,
e,
                theta=theta,
freqs_dtype=torch.float64,
)
freqs_cis.append(emb)
return freqs_cis

def forward(self, ids: torch.Tensor) -> torch.Tensor:
result = []
for i in range(len(self.axes_dim)):
freqs = self.freqs_cis[i].to(ids.device)
index = ids[:, :, i:i+1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
return torch.cat(result, dim=-1)


class LuminaPatchEmbed(nn.Module):
"""
2D Image to Patch Embedding with support for Lumina-T2X
@@ -1766,6 +1795,38 @@ def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidde
return conditioning


class Lumina2CombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size=4096, cap_feat_dim=2048, frequency_embedding_size=256, norm_eps=1e-5):
super().__init__()

        from .normalization import RMSNorm

self.time_proj = Timesteps(
num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
)

self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=min(hidden_size, 1024))

self.caption_embedder = nn.Sequential(
RMSNorm(cap_feat_dim, eps=norm_eps),
nn.Linear(
cap_feat_dim,
hidden_size,
bias=True,
),
)

def forward(self, timestep, caption_feat):
# timestep embedding:
time_freq = self.time_proj(timestep)
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))

# caption condition embedding:
caption_embed = self.caption_embedder(caption_feat)

return time_embed, caption_embed


class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
super().__init__()
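The forward pass of the new Lumina2PosEmbed above boils down to a per-axis table lookup: each axis keeps its own precomputed frequency table, integer position ids gather rows from it, and the per-axis results are concatenated. A toy sketch of that indexing follows, with random stand-in tables (the real tables come from get_1d_rotary_pos_embed and are complex-valued; plain floats are enough to show the gather).

# Toy sketch of the per-axis gather in Lumina2PosEmbed.forward. The table
# contents are random stand-ins; only the indexing pattern matters here.
import torch

axes_dim = [8, 4, 4]           # illustrative per-axis rotary dims
axes_lens = [300, 512, 512]    # max position per axis, as in the new class

# get_1d_rotary_pos_embed returns one (axis_len, axis_dim // 2) table per axis.
freqs_cis = [torch.randn(length, dim // 2) for dim, length in zip(axes_dim, axes_lens)]

ids = torch.randint(0, 300, (2, 16, 3))  # (batch, tokens, num_axes)

result = []
for i, freqs in enumerate(freqs_cis):
    # Broadcast each id across the axis's frequency dimension, then gather rows.
    index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
    table = freqs.unsqueeze(0).repeat(ids.shape[0], 1, 1)
    result.append(torch.gather(table, dim=1, index=index))

rotary = torch.cat(result, dim=-1)
print(rotary.shape)  # torch.Size([2, 16, 8]) == (batch, tokens, sum(axes_dim) // 2)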
28 changes: 17 additions & 11 deletions src/diffusers/models/normalization.py
@@ -211,25 +211,31 @@ class LuminaRMSNormZero(nn.Module):
embedding_dim (`int`): The size of each embedding vector.
"""

-    def __init__(self, embedding_dim: int, norm_eps: float, norm_elementwise_affine: bool):
+    def __init__(self, embedding_dim: int, norm_eps: float, modulation: bool):
         super().__init__()
-        self.silu = nn.SiLU()
-        self.linear = nn.Linear(
-            min(embedding_dim, 1024),
-            4 * embedding_dim,
-            bias=True,
-        )
-        self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
+        self.modulation = modulation
+        if modulation:
+            self.silu = nn.SiLU()
+            self.linear = nn.Linear(
+                min(embedding_dim, 1024),
+                4 * embedding_dim,
+                bias=True,
+            )
+        self.norm = RMSNorm(embedding_dim, eps=norm_eps)

     def forward(
         self,
         x: torch.Tensor,
         emb: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # emb = self.emb(timestep, encoder_hidden_states, encoder_mask)
-        emb = self.linear(self.silu(emb))
-        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
-        x = self.norm(x) * (1 + scale_msa[:, None])
+        if self.modulation:
+            emb = self.linear(self.silu(emb))
+            scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
+            x = self.norm(x) * (1 + scale_msa[:, None])
+        else:
+            gate_msa, scale_mlp, gate_mlp = None, None, None
+            x = self.norm(x)

         return x, gate_msa, scale_mlp, gate_mlp

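The refactor above turns the time-embedding modulation of LuminaRMSNormZero into an opt-in path: with modulation=True the conditioning embedding is projected into scale/gate terms, and with modulation=False the layer reduces to a plain RMSNorm and returns None gates. A minimal sketch of the two call paths, assuming a diffusers build that includes this change; shapes are illustrative.

# Minimal sketch of both call paths of the reworked LuminaRMSNormZero, assuming
# this branch of diffusers is installed. Shapes are illustrative.
import torch

from diffusers.models.normalization import LuminaRMSNormZero

x = torch.randn(2, 16, 256)   # (batch, tokens, embedding_dim)
emb = torch.randn(2, 256)     # pooled conditioning embedding

# modulation=True: the embedding is projected into scale/gate terms.
norm_mod = LuminaRMSNormZero(embedding_dim=256, norm_eps=1e-5, modulation=True)
x_mod, gate_msa, scale_mlp, gate_mlp = norm_mod(x, emb)

# modulation=False: plain RMSNorm; the gates come back as None.
norm_plain = LuminaRMSNormZero(embedding_dim=256, norm_eps=1e-5, modulation=False)
x_plain, none_gate, _, _ = norm_plain(x)
assert none_gate is None and x_plain.shape == x.shape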