
Commit c916ae5

make style
1 parent ba9f13f commit c916ae5

2 files changed (+14, -4 lines)


scripts/convert_mochi_to_diffusers.py

Lines changed: 3 additions & 1 deletion

@@ -130,7 +130,9 @@ def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
     )
 
     # Output layers
-    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.weight"), dim=0)
+    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
+        original_state_dict.pop("final_layer.mod.weight"), dim=0
+    )
     new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
     new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
     new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 11 additions & 3 deletions

@@ -88,11 +88,17 @@ def __init__(
         self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
         self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
-        self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True)
+        self.ff = FeedForward(
+            dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True
+        )
         self.ff_context = None
         if not context_pre_only:
             self.ff_context = FeedForward(
-                pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True
+                pooled_projection_dim,
+                inner_dim=self.ff_context_inner_dim,
+                activation_fn=activation_fn,
+                bias=False,
+                flip_gate=True,
             )
 
         self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
@@ -131,7 +137,9 @@ def forward(
         ) * torch.tanh(enc_gate_msa).unsqueeze(1)
         norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
         context_ff_output = self.ff_context(norm_encoder_hidden_states)
-        encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(enc_gate_mlp).unsqueeze(1)
+        encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
+            enc_gate_mlp
+        ).unsqueeze(1)
 
         return hidden_states, encoder_hidden_states
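The last re-wrapped line is the gated residual update for the context (text token) stream. A minimal, self-contained sketch of the broadcasting it relies on, using made-up tensor shapes and omitting the norm4_context normalization for brevity:

import torch

batch, seq_len, dim = 2, 16, 64

encoder_hidden_states = torch.randn(batch, seq_len, dim)  # context token stream
context_ff_output = torch.randn(batch, seq_len, dim)      # feed-forward output
enc_gate_mlp = torch.randn(batch, dim)                     # per-sample gate (assumed shape)

# tanh squashes the gate into (-1, 1); unsqueeze(1) inserts a sequence axis so the
# (batch, 1, dim) gate broadcasts over every token of the (batch, seq_len, dim)
# feed-forward output before it is added back onto the residual stream.
encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(
    enc_gate_mlp
).unsqueeze(1)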
