
Commit e92ee28

update
1 parent 40cf52f commit e92ee28

3 files changed: 5 additions, 5 deletions

src/diffusers/models/activations.py

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ class SwiGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, bias: bool = True, flip_gate: bool = False):
         super().__init__()
         self.flip_gate = flip_gate
-
+
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
         self.activation = nn.SiLU()

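For context, the touched SwiGLU block projects to 2 * dim_out and gates one half of the projection with SiLU applied to the other. The forward pass is not part of this hunk, so the following minimal sketch (including the assumed meaning of flip_gate, i.e. swapping which half acts as the gate) is illustrative only:

import torch
import torch.nn as nn

class SwiGLUSketch(nn.Module):
    # Sketch of a SwiGLU feed-forward layer mirroring the __init__ shown above.
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, flip_gate: bool = False):
        super().__init__()
        self.flip_gate = flip_gate
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj(hidden_states)
        value, gate = hidden_states.chunk(2, dim=-1)
        if self.flip_gate:
            value, gate = gate, value  # assumed semantics of flip_gate
        return value * self.activation(gate)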
src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 2 additions & 2 deletions
@@ -249,7 +249,7 @@ def __init__(
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

         self.gradient_checkpointing = False
-
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

@@ -287,7 +287,7 @@ def forward(

         for i, block in enumerate(self.transformer_blocks):
             if self.gradient_checkpointing:
-
+
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         return module(*inputs)

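The second hunk sits inside the gradient-checkpointing branch of the forward loop, which wraps each transformer block in a closure before handing it to torch.utils.checkpoint. A minimal sketch of that pattern follows; the helper name run_blocks_with_checkpointing and the use_reentrant=False choice are assumptions, not the file's actual code:

import torch
from torch.utils.checkpoint import checkpoint

def run_blocks_with_checkpointing(blocks, hidden_states, gradient_checkpointing=True):
    # Wrap each block in a closure so torch.utils.checkpoint can re-run the
    # forward pass during backward instead of storing its activations.
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    for block in blocks:
        if gradient_checkpointing and torch.is_grad_enabled():
            hidden_states = checkpoint(
                create_custom_forward(block), hidden_states, use_reentrant=False
            )
        else:
            hidden_states = block(hidden_states)
    return hidden_states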
src/diffusers/models/transformers/transformer_mochi_original.py

Lines changed: 2 additions & 2 deletions
@@ -619,7 +619,7 @@ def prepare_qkv(
         q_x, k_x, v_x = qkv_x.unbind(0) # (B, N, local_h, head_dim)
         q_x = self.q_norm_x(q_x)
         k_x = self.k_norm_x(k_x)
-
+
         q_x = apply_rotary_emb_qk_real(q_x, rope_cos, rope_sin)
         k_x = apply_rotary_emb_qk_real(k_x, rope_cos, rope_sin)

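apply_rotary_emb_qk_real is applied to the query/key tensors right after the norms. One common real-valued, interleaved formulation of that rotation is sketched below; the actual helper in this file may pair channels or broadcast rope_cos / rope_sin differently, so treat this as an assumption:

import torch

def apply_rotary_emb_qk_real_sketch(x, rope_cos, rope_sin):
    # x: (..., head_dim) with an even head_dim; rope_cos / rope_sin broadcast
    # against the even-indexed half of the last dimension.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out_even = x_even * rope_cos - x_odd * rope_sin
    out_odd = x_even * rope_sin + x_odd * rope_cos
    # Re-interleave the rotated pairs back into the original channel layout.
    return torch.stack([out_even, out_odd], dim=-1).flatten(-2)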
@@ -665,7 +665,7 @@ def run_attention(
         q = q.permute(1, 0, 2).unsqueeze(0)
         k = k.permute(1, 0, 2).unsqueeze(0)
         v = v.permute(1, 0, 2).unsqueeze(0)
-
+
         out = F.scaled_dot_product_attention(q, k, v)

         out = out.transpose(1, 2).flatten(2, 3)

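The run_attention hunk reshapes head-major tensors before the fused attention call and flattens the heads back afterwards. A small self-contained sketch of that reshape round trip (the input shapes are assumptions inferred from the permute pattern shown):

import torch
import torch.nn.functional as F

def run_attention_sketch(q, k, v):
    # q, k, v: (seq_len, num_heads, head_dim) -> add a batch dim and move heads
    # ahead of the sequence axis, as F.scaled_dot_product_attention expects.
    q = q.permute(1, 0, 2).unsqueeze(0)  # (1, num_heads, seq_len, head_dim)
    k = k.permute(1, 0, 2).unsqueeze(0)
    v = v.permute(1, 0, 2).unsqueeze(0)

    out = F.scaled_dot_product_attention(q, k, v)  # (1, num_heads, seq_len, head_dim)

    # Back to (1, seq_len, num_heads * head_dim) for the output projection.
    return out.transpose(1, 2).flatten(2, 3)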