
Commit e488d09

update
1 parent 9366c8f commit e488d09

File tree

3 files changed: +1163 -0 lines changed


src/diffusers/models/embeddings.py

Lines changed: 84 additions & 0 deletions
@@ -1430,6 +1430,90 @@ def shape(x):
        return a[:, 0, :]  # cls_token


class MochiAttentionPool(nn.Module):
    def __init__(
        self,
        num_attention_heads: int,
        embed_dim: int,
        output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.output_dim = output_dim or embed_dim
        self.num_attention_heads = num_attention_heads

        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
        self.to_q = nn.Linear(embed_dim, embed_dim)
        self.to_out = nn.Linear(embed_dim, self.output_dim)

    @staticmethod
    def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
        """
        Pool tokens in x using mask.

        NOTE: We assume x does not require gradients.

        Args:
            x: (B, L, D) tensor of tokens.
            mask: (B, L) boolean tensor indicating which tokens are not padding.

        Returns:
            pooled: (B, D) tensor of pooled tokens.
        """
        assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
        assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.
        mask = mask[:, :, None].to(dtype=x.dtype)
        mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
        pooled = (x * mask).sum(dim=1, keepdim=keepdim)
        return pooled

    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        r"""
        Args:
            x (`torch.Tensor`):
                Tensor of shape `(B, S, D)` of input tokens.
            mask (`torch.Tensor`):
                Boolean tensor of shape `(B, S)` indicating which tokens are not padding.

        Returns:
            `torch.Tensor`:
                `(B, D)` tensor of pooled tokens.
        """
        D = x.size(2)

        # Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L).
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L).

        # Average non-padding token features. These will be used as the query.
        x_pool = self.pool_tokens(x, mask, keepdim=True)  # (B, 1, D)

        # Concat pooled features to input sequence.
        x = torch.cat([x_pool, x], dim=1)  # (B, L+1, D)

        # Compute queries, keys, values. Only the mean token is used to create a query.
        kv = self.to_kv(x)  # (B, L+1, 2 * D)
        q = self.to_q(x[:, 0])  # (B, D)

        # Extract heads.
        head_dim = D // self.num_attention_heads
        kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
        kv = kv.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        k, v = kv.unbind(2)  # (B, H, 1+L, head_dim)
        q = q.unflatten(1, (self.num_attention_heads, head_dim))  # (B, H, head_dim)
        q = q.unsqueeze(2)  # (B, H, 1, head_dim)

        # Compute attention.
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=0.0
        )  # (B, H, 1, head_dim)

        # Concatenate heads and run output.
        x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        x = self.to_out(x)
        return x


def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Args:
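
For reference, a minimal usage sketch of the new pooling layer (not part of the diff): it assumes this commit is installed so that `MochiAttentionPool` can be imported from `diffusers.models.embeddings`, and the batch size, sequence length, and feature sizes below are purely illustrative.

import torch

from diffusers.models.embeddings import MochiAttentionPool

# Illustrative shapes: 2 caption sequences of 256 tokens with 4096-dim features,
# pooled down to a single 3072-dim vector per sequence.
pool = MochiAttentionPool(num_attention_heads=8, embed_dim=4096, output_dim=3072)

x = torch.randn(2, 256, 4096)
mask = torch.zeros(2, 256, dtype=torch.bool)
mask[0, :100] = True  # first sequence has 100 real (non-padding) tokens
mask[1, :37] = True   # second sequence has 37 real tokens

pooled = pool(x, mask)
print(pooled.shape)  # torch.Size([2, 3072])

# The static helper on its own is just a masked mean over non-padding tokens.
mean_pooled = MochiAttentionPool.pool_tokens(x, mask)
print(mean_pooled.shape)  # torch.Size([2, 4096])
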
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
# Copyright 2024 The Genmo team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention, FeedForward
from ..embeddings import PatchEmbed, MochiAttentionPool, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        caption_dim: int,
        update_captions: bool = True,
    ) -> None:
        super().__init__()

        # TODO: Replace this with norm
        self.mod_x = nn.Linear(dim, 4 * dim)
        if update_captions:
            self.mod_y = nn.Linear(dim, 4 * caption_dim)
        else:
            self.mod_y = nn.Linear(dim, caption_dim)

        # TODO(aryan): attention class does not look compatible
        self.attn1 = Attention(...)
        # norms go in attention
        # self.q_norm_x = RMSNorm(attention_head_dim)
        # self.k_norm_x = RMSNorm(attention_head_dim)
        # self.q_norm_y = RMSNorm(attention_head_dim)
        # self.k_norm_y = RMSNorm(attention_head_dim)

        self.proj_x = nn.Linear(dim, dim)

        self.proj_y = nn.Linear(dim, caption_dim) if update_captions else None

    def forward(self):
        pass


@maybe_allow_in_graph
class MochiTransformer3D(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 48,
        caption_dim=1536,
        mlp_ratio_x=4.0,
        mlp_ratio_y=4.0,
        in_channels=12,
        qk_norm=True,
        qkv_bias=False,
        out_bias=True,
        timestep_mlp_bias=True,
        timestep_scale=1000.0,
        text_embed_dim=4096,
        max_sequence_length=256,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
        )

        self.caption_embedder = MochiAttentionPool(num_attention_heads=8, embed_dim=text_embed_dim, output_dim=inner_dim)
        self.caption_proj = nn.Linear(text_embed_dim, caption_dim)

        self.pos_frequencies = nn.Parameter(
            torch.empty(3, num_attention_heads, attention_head_dim // 2)
        )

        self.transformer_blocks = nn.ModuleList([
            MochiTransformerBlock(
                dim=inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                caption_dim=caption_dim,
                update_captions=i < num_layers - 1,
            )
            for i in range(num_layers)
        ])
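
As a quick sanity check on the defaults registered above (a sketch only; the values are copied verbatim from the `__init__` signature, and nothing below is part of the commit):

# Shape arithmetic for the default MochiTransformer3D configuration.
num_attention_heads = 24
attention_head_dim = 128
num_layers = 48
text_embed_dim = 4096
caption_dim = 1536

inner_dim = num_attention_heads * attention_head_dim
print(inner_dim)  # 3072 -> width of the patch tokens and of the pooled caption embedding

# MochiAttentionPool condenses (B, S, text_embed_dim) caption features into a single
# (B, inner_dim) vector, while caption_proj keeps a per-token (B, S, caption_dim)
# stream for the transformer blocks.

# Every block except the last one also updates the caption stream.
update_flags = [i < num_layers - 1 for i in range(num_layers)]
print(sum(update_flags))  # 47 of 48 blocks are built with update_captions=True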
