Commit 98a4554

update
1 parent 1e9bc91 commit 98a4554

3 files changed: +26 -234 lines

scripts/convert_mochi_to_diffusers.py

Lines changed: 7 additions & 7 deletions

@@ -99,20 +99,20 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
         qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
         q, k, v = qkv_weight.chunk(3, dim=0)
 
-        new_state_dict[block_prefix + "attn1.to_context_q.weight"] = q
-        new_state_dict[block_prefix + "attn1.to_context_k.weight"] = k
-        new_state_dict[block_prefix + "attn1.to_context_v.weight"] = v
-        new_state_dict[block_prefix + "attn1.norm_context_q.weight"] = original_state_dict.pop(
+        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
+        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
+        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
+        new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
             old_prefix + "attn.q_norm_y.weight"
         )
-        new_state_dict[block_prefix + "attn1.norm_context_k.weight"] = original_state_dict.pop(
+        new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
             old_prefix + "attn.k_norm_y.weight"
         )
         if i < num_layers - 1:
-            new_state_dict[block_prefix + "attn1.to_context_out.0.weight"] = original_state_dict.pop(
+            new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
                 old_prefix + "attn.proj_y.weight"
            )
-            new_state_dict[block_prefix + "attn1.to_context_out.0.bias"] = original_state_dict.pop(
+            new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
                 old_prefix + "attn.proj_y.bias"
            )
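
For reference, this hunk only renames checkpoint keys so the converted weights line up with the shared joint `Attention` module used below (`add_*_proj`, `norm_added_*`, `to_add_out`). A minimal standalone sketch of the same mapping; the helper and its name are illustrative, not part of the script:

# Hypothetical helper mirroring the rename performed in the hunk above.
CONTEXT_KEY_MAP = {
    "to_context_q": "add_q_proj",
    "to_context_k": "add_k_proj",
    "to_context_v": "add_v_proj",
    "norm_context_q": "norm_added_q",
    "norm_context_k": "norm_added_k",
    "to_context_out.0": "to_add_out",
}


def rename_context_keys(state_dict):
    """Return a copy of state_dict with the old attn1 context keys renamed."""
    renamed = {}
    for key, value in state_dict.items():
        for old, new in CONTEXT_KEY_MAP.items():
            key = key.replace(f"attn1.{old}.", f"attn1.{new}.")
        renamed[key] = value
    return renamed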

src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 216 deletions

@@ -120,6 +120,7 @@ def __init__(
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        out_context_dim: int = None,
         context_pre_only=None,
         pre_only=False,
         elementwise_affine: bool = True,
@@ -142,6 +143,7 @@ def __init__(
         self.dropout = dropout
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
+        self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
         self.context_pre_only = context_pre_only
         self.pre_only = pre_only
@@ -241,7 +243,7 @@ def __init__(
             self.to_out.append(nn.Dropout(dropout))
 
         if self.context_pre_only is not None and not self.context_pre_only:
-            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
 
         if qk_norm is not None and added_kv_proj_dim is not None:
             if qk_norm == "fp32_layer_norm":
@@ -717,221 +719,6 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse
 
 
-class AsymmetricAttention(nn.Module):
-    def __init__(
-        self,
-        query_dim: int,
-        query_context_dim: int,
-        num_attention_heads: int = 8,
-        attention_head_dim: int = 64,
-        bias: bool = False,
-        context_bias: bool = False,
-        out_dim: Optional[int] = None,
-        out_context_dim: Optional[int] = None,
-        qk_norm: Optional[str] = None,
-        eps: float = 1e-5,
-        elementwise_affine: bool = True,
-        processor: Optional["AttnProcessor"] = None,
-    ) -> None:
-        super().__init__()
-
-        from .normalization import RMSNorm
-
-        self.query_dim = query_dim
-        self.query_context_dim = query_context_dim
-        self.inner_dim = out_dim if out_dim is not None else num_attention_heads * attention_head_dim
-        self.out_dim = out_dim if out_dim is not None else query_dim
-
-        self.scale = attention_head_dim ** -0.5
-        self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attention_heads
-
-        if qk_norm is None:
-            self.norm_q = None
-            self.norm_k = None
-            self.norm_context_q = None
-            self.norm_context_k = None
-        elif qk_norm == "rms_norm":
-            self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-        else:
-            raise ValueError((f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`."))
-
-        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
-
-        self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-        self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-        self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-
-        # TODO(aryan): Take care of dropouts for training purpose in future
-        self.to_out = nn.ModuleList([
-            nn.Linear(self.inner_dim, self.out_dim)
-        ])
-
-        if out_context_dim is not None:
-            self.to_context_out = nn.ModuleList([
-                nn.Linear(self.inner_dim, out_context_dim)
-            ])
-        else:
-            self.to_context_out = nn.ModuleList([
-                nn.Identity()
-            ])
-
-        if processor is None:
-            processor = AsymmetricAttnProcessor2_0()
-
-        self.set_processor(processor)
-
-    def set_processor(self, processor: "AttnProcessor") -> None:
-        r"""
-        Set the attention processor to use.
-
-        Args:
-            processor (`AttnProcessor`):
-                The attention processor to use.
-        """
-        # if current processor is in `self._modules` and if passed `processor` is not, we need to
-        # pop `processor` from `self._modules`
-        if (
-            hasattr(self, "processor")
-            and isinstance(self.processor, torch.nn.Module)
-            and not isinstance(processor, torch.nn.Module)
-        ):
-            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
-            self._modules.pop("processor")
-
-        self.processor = processor
-
-    def get_processor(self) -> "AttentionProcessor":
-        r"""
-        Get the attention processor in use.
-
-        Returns:
-            "AttentionProcessor": The attention processor in use.
-        """
-        return self.processor
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        **cross_attention_kwargs,
-    ) -> torch.Tensor:
-        r"""
-        The forward method of the `Attention` class.
-
-        Args:
-            hidden_states (`torch.Tensor`):
-                The hidden states of the query.
-            encoder_hidden_states (`torch.Tensor`, *optional*):
-                The hidden states of the encoder.
-            attention_mask (`torch.Tensor`, *optional*):
-                The attention mask to use. If `None`, no mask is applied.
-            **cross_attention_kwargs:
-                Additional keyword arguments to pass along to the cross attention.
-
-        Returns:
-            `torch.Tensor`: The output of the attention layer.
-        """
-        # The `Attention` class can call different attention processors / attention functions
-        # here we simply pass along all tensors to the selected processor class
-        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
-
-        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        quiet_attn_parameters = {"ip_adapter_masks"}
-        unused_kwargs = [
-            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
-        ]
-        if len(unused_kwargs) > 0:
-            logger.warning(
-                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
-            )
-        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
-
-        return self.processor(
-            self,
-            hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            attention_mask=attention_mask,
-            **cross_attention_kwargs,
-        )
-
-
-class AsymmetricAttnProcessor2_0:
-    r"""
-    Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link).
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: AsymmetricAttention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> torch.Tensor:
-        batch_size = hidden_states.size(0)
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-
-        query_context = attn.to_context_q(encoder_hidden_states)
-        key_context = attn.to_context_k(encoder_hidden_states)
-        value_context = attn.to_context_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim / attn.num_attention_heads
-
-        query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-
-        query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-        value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        if attn.norm_context_q is not None:
-            query_context = attn.norm_context_q(query_context)
-        if attn.norm_context_k is not None:
-            key_context = attn.norm_context_k(key_context)
-
-        if image_rotary_emb is not None:
-            from .embeddings import apply_rotary_emb
-            query = apply_rotary_emb(query, image_rotary_emb)
-            key = apply_rotary_emb(key, image_rotary_emb)
-
-        sequence_length = query.size(1)
-        context_sequence_length = query_context.size(1)
-
-        query = torch.cat([query, query_context], dim=1)
-        key = torch.cat([key, key_context], dim=1)
-        value = torch.cat([value, value_context], dim=1)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
-        )
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes([sequence_length, context_sequence_length], dim=1)
-
-        hidden_states = attn.to_out[0](hidden_states)
-        encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states)
-
-        return hidden_states, encoder_hidden_states
-
-
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
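
Aside from deleting the Mochi-specific `AsymmetricAttention` / `AsymmetricAttnProcessor2_0` pair, the only functional change in this file is the optional `out_context_dim` argument: when it is set, `to_add_out` projects the context stream to `out_context_dim` instead of `out_dim`, and when it is left as `None` it falls back to `query_dim`, so existing callers are unaffected. A minimal sketch of how this could be exercised, assuming this commit is applied; all concrete dimensions are illustrative only:

from diffusers.models.attention_processor import Attention

# Illustrative sizes: a wide visual stream (query_dim) and a narrower text stream
# (added_kv_proj_dim) that is projected back to its own width via out_context_dim.
attn = Attention(
    query_dim=3072,
    cross_attention_dim=None,
    heads=24,
    dim_head=128,
    added_kv_proj_dim=1536,
    out_dim=3072,
    out_context_dim=1536,  # new in this commit; previously to_add_out always used out_dim
    context_pre_only=False,
)
assert attn.to_add_out.in_features == 3072
assert attn.to_add_out.out_features == 1536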

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 16 additions & 11 deletions

@@ -22,7 +22,7 @@
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
-from ..attention_processor import AsymmetricAttention, AsymmetricAttnProcessor2_0
+from ..attention_processor import Attention, FluxAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -57,17 +57,21 @@ def __init__(
         else:
             self.norm1_context = nn.Linear(dim, pooled_projection_dim)
 
-        self.attn1 = AsymmetricAttention(
+        self.attn1 = Attention(
             query_dim=dim,
-            query_context_dim=pooled_projection_dim,
-            num_attention_heads=num_attention_heads,
-            attention_head_dim=attention_head_dim,
-            out_dim=dim,
-            out_context_dim=None if context_pre_only else pooled_projection_dim,
+            cross_attention_dim=None,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            bias=False,
             qk_norm=qk_norm,
+            added_kv_proj_dim=pooled_projection_dim,
+            added_proj_bias=False,
+            out_dim=dim,
+            out_context_dim=pooled_projection_dim,
+            context_pre_only=context_pre_only,
+            processor=FluxAttnProcessor2_0(),
             eps=1e-6,
             elementwise_affine=True,
-            processor=AsymmetricAttnProcessor2_0(),
         )
 
         self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
@@ -93,7 +97,7 @@ def forward(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
 
-        if self.context_pre_only:
+        if not self.context_pre_only:
             norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
                 encoder_hidden_states, temb
             )
@@ -203,9 +207,11 @@ def forward(
         post_patch_height = height // p
         post_patch_width = width // p
 
-        temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask)
+        temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype)
 
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.patch_embed(hidden_states)
+        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
 
         for i, block in enumerate(self.transformer_blocks):
             hidden_states, encoder_hidden_states = block(
@@ -216,7 +222,6 @@ def forward(
             )
 
         # TODO(aryan): do something with self.pos_frequencies
-
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
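
The new permute/unflatten pair around `patch_embed` applies the 2D patch embedding to each frame independently and then folds the frame axis back into one token sequence. A self-contained sketch of just that reshaping with made-up shapes; the conv-based `proj` here is a stand-in for diffusers' `PatchEmbed`, and every dimension is illustrative:

import torch
import torch.nn as nn

batch_size, channels, num_frames, height, width = 1, 12, 4, 60, 106
patch_size, inner_dim = 2, 3072

# Stand-in for self.patch_embed: conv patchify, then flatten to (tokens, dim) per frame.
proj = nn.Conv2d(channels, inner_dim, kernel_size=patch_size, stride=patch_size)

hidden_states = torch.randn(batch_size, channels, num_frames, height, width)

# (B, C, T, H, W) -> (B*T, C, H, W): run the 2D patch embedding once per frame.
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = proj(hidden_states).flatten(2).transpose(1, 2)  # (B*T, H/p * W/p, D)

# (B*T, L, D) -> (B, T*L, D): fold frames back into a single token sequence.
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

print(hidden_states.shape)  # torch.Size([1, 6360, 3072]) with these example sizes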
