Commit 85c8734
Commit message: fix
1 parent 98a4554 commit 85c8734

File tree: 5 files changed, +39 −18 lines


scripts/convert_mochi_to_diffusers.py

Lines changed: 6 additions & 4 deletions
@@ -4,8 +4,8 @@
 import torch
 from accelerate import init_empty_weights
 from safetensors.torch import load_file
-# from transformers import T5EncoderModel, T5Tokenizer
 
+# from transformers import T5EncoderModel, T5Tokenizer
 from diffusers import MochiTransformer3DModel
 from diffusers.utils.import_utils import is_accelerate_available
 
@@ -72,10 +72,12 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
                 old_prefix + "mod_y.bias"
             )
         else:
-            new_state_dict[block_prefix + "norm1_context.weight"] = original_state_dict.pop(
+            new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
                 old_prefix + "mod_y.weight"
            )
-            new_state_dict[block_prefix + "norm1_context.bias"] = original_state_dict.pop(old_prefix + "mod_y.bias")
+            new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
+                old_prefix + "mod_y.bias"
+            )
 
         # Visual attention
         qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
@@ -158,7 +160,7 @@ def main(args):
         raise ValueError(f"Unsupported dtype: {args.dtype}")
 
     transformer = None
-    vae = None
+    # vae = None
 
    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(

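The key rename tracks the module change in transformer_mochi.py below: for `context_pre_only` blocks, `norm1_context` is no longer a bare `nn.Linear`, so the original `mod_y` projection now has to land under its `linear_1` submodule. A minimal sketch of the same pop-and-rename pattern (the prefixes and tensor shapes here are placeholders, not the real Mochi checkpoint layout):

```python
# Illustrative only: a toy state dict showing how the renamed target keys line up
# with the new norm1_context, which now owns a linear_1 submodule.
import torch

original_state_dict = {
    "blocks.0.mod_y.weight": torch.randn(1536, 3072),
    "blocks.0.mod_y.bias": torch.randn(1536),
}
old_prefix = "blocks.0."
block_prefix = "transformer_blocks.0."

new_state_dict = {}
# Same pop-and-rename pattern as the script: the original mod_y projection becomes
# the linear_1 layer inside norm1_context.
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
    old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
    old_prefix + "mod_y.bias"
)
assert not original_state_dict  # every original key has been consumed
```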
src/diffusers/models/attention_processor.py

Lines changed: 3 additions & 1 deletion
@@ -1794,7 +1794,9 @@ def __call__(
             hidden_states = attn.to_out[0](hidden_states)
             # dropout
             hidden_states = attn.to_out[1](hidden_states)
-            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            if hasattr(attn, "to_add_out"):
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
 
             return hidden_states, encoder_hidden_states
         else:

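The new guard covers attention blocks whose context stream is pre-only and therefore never register a `to_add_out` projection; calling it unconditionally would raise an `AttributeError`. A minimal sketch of that situation, using a stand-in module rather than the real `Attention` class:

```python
# Stand-in sketch (not the real diffusers Attention class): when the context
# stream is pre-only, no to_add_out projection is registered, so the processor's
# hasattr() guard is what keeps the unconditional call from raising.
import torch.nn as nn

class TinyAttn(nn.Module):
    def __init__(self, dim: int, context_pre_only: bool):
        super().__init__()
        self.to_out = nn.ModuleList([nn.Linear(dim, dim), nn.Dropout(0.0)])
        if not context_pre_only:
            self.to_add_out = nn.Linear(dim, dim)

attn = TinyAttn(8, context_pre_only=True)
assert not hasattr(attn, "to_add_out")  # the guarded branch is simply skipped
```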
src/diffusers/models/embeddings.py

Lines changed: 6 additions & 1 deletion
@@ -1304,7 +1304,12 @@ def forward(self, timestep, caption_feat, caption_mask):
 
 class MochiCombinedTimestepCaptionEmbedding(nn.Module):
     def __init__(
-        self, embedding_dim: int, pooled_projection_dim: int, text_embed_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8
+        self,
+        embedding_dim: int,
+        pooled_projection_dim: int,
+        text_embed_dim: int,
+        time_embed_dim: int = 256,
+        num_attention_heads: int = 8,
     ) -> None:
         super().__init__()

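This hunk only reflows the constructor signature. For reference, a hypothetical instantiation with placeholder dimensions (not necessarily Mochi's actual configuration):

```python
# Hypothetical instantiation, only to show the reflowed keyword arguments; the
# dimensions are placeholders rather than Mochi's published configuration.
from diffusers.models.embeddings import MochiCombinedTimestepCaptionEmbedding

time_caption_embed = MochiCombinedTimestepCaptionEmbedding(
    embedding_dim=3072,
    pooled_projection_dim=1536,
    text_embed_dim=4096,
    time_embed_dim=256,
    num_attention_heads=8,
)
```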
src/diffusers/models/normalization.py

Lines changed: 7 additions & 6 deletions
@@ -385,20 +385,21 @@ def __init__(
         out_dim: Optional[int] = None,
     ):
         super().__init__()
+
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        if norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
         else:
             raise ValueError(f"unknown norm_type {norm_type}")
-        # linear_2
+
+        self.linear_2 = None
         if out_dim is not None:
-            self.linear_2 = nn.Linear(
-                embedding_dim,
-                out_dim,
-                bias=bias,
-            )
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
 
     def forward(
         self,

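The `forward` body sits outside this hunk; based on the modules wired up here, the conditioning path is SiLU, then `linear_1`, then normalize the input and scale it, with `linear_2` applied only when `out_dim` is set. A rough, self-contained sketch under those assumptions (stand-in layers, not the diffusers classes):

```python
import torch
import torch.nn as nn

# Rough sketch of the forward path this __init__ sets up (the forward body is not
# part of the hunk): project the conditioning embedding through SiLU and linear_1,
# normalize x, scale it, then optionally apply linear_2.
def continuous_norm_forward(x, conditioning_embedding, silu, linear_1, norm, linear_2=None):
    scale = linear_1(silu(conditioning_embedding).to(x.dtype))
    x = norm(x) * (1 + scale)[:, None, :]
    if linear_2 is not None:
        x = linear_2(x)
    return x

# Toy sizes; LayerNorm stands in for the RMSNorm branch constructed in the diff.
silu = nn.SiLU()
linear_1 = nn.Linear(16, 8)  # conditioning_embedding_dim -> embedding_dim
norm = nn.LayerNorm(8, eps=1e-6, elementwise_affine=False)
x = torch.randn(2, 5, 8)     # (batch, seq_len, embedding_dim)
cond = torch.randn(2, 16)    # (batch, conditioning_embedding_dim)
out = continuous_norm_forward(x, cond, silu, linear_1, norm)  # shape (2, 5, 8)
```

Note that, as committed, the two separate `if` statements mean `norm_type="layer_norm"` still falls through to the final `else` and raises; the Mochi block below only passes `norm_type="rms_norm"`, so that path is unaffected.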
src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 17 additions & 6 deletions
@@ -26,7 +26,7 @@
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm
+from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -55,7 +55,14 @@ def __init__(
         if not context_pre_only:
             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
         else:
-            self.norm1_context = nn.Linear(dim, pooled_projection_dim)
+            self.norm1_context = LuminaLayerNormContinuous(
+                embedding_dim=pooled_projection_dim,
+                conditioning_embedding_dim=dim,
+                eps=1e-6,
+                elementwise_affine=False,
+                norm_type="rms_norm",
+                out_dim=None,
+            )
 
         self.attn1 = Attention(
             query_dim=dim,
@@ -83,7 +90,9 @@ def __init__(
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
         if not context_pre_only:
-            self.ff_context = FeedForward(pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False)
+            self.ff_context = FeedForward(
+                pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False
+            )
 
         self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
@@ -102,7 +111,7 @@ def forward(
                 encoder_hidden_states, temb
             )
         else:
-            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
 
         attn_hidden_states, context_attn_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
@@ -112,7 +121,7 @@ def forward(
 
         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
         norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
-
+
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
                 context_attn_hidden_states
@@ -207,7 +216,9 @@ def forward(
         post_patch_height = height // p
         post_patch_width = width // p
 
-        temb, encoder_hidden_states = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype)
+        temb, encoder_hidden_states = self.time_embed(
+            timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
+        )
 
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
         hidden_states = self.patch_embed(hidden_states)

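The pieces fit together as follows: the new `LuminaLayerNormContinuous` owns a `linear_1` with the same shape as the `nn.Linear(dim, pooled_projection_dim)` it replaces (hence the converter's `mod_y.*` to `norm1_context.linear_1.*` rename), and because it conditions on the timestep embedding, the forward call gains the `temb` argument. A small check of that shape correspondence (the sizes below are placeholders, not Mochi's config):

```python
import torch.nn as nn
from diffusers.models.normalization import LuminaLayerNormContinuous

dim, pooled_projection_dim = 3072, 1536  # placeholder sizes

old_norm1_context = nn.Linear(dim, pooled_projection_dim)
new_norm1_context = LuminaLayerNormContinuous(
    embedding_dim=pooled_projection_dim,
    conditioning_embedding_dim=dim,
    eps=1e-6,
    elementwise_affine=False,
    norm_type="rms_norm",
    out_dim=None,
)
# The conditioning projection inside the new module matches the old bare Linear,
# which is exactly why the converter renames mod_y.* to norm1_context.linear_1.*.
assert new_norm1_context.linear_1.weight.shape == old_norm1_context.weight.shape
```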