Commit 46f95d5

make style
1 parent 2fd2ec4 commit 46f95d5

3 files changed: +54 additions, -29 deletions

src/diffusers/models/attention_processor.py

Lines changed: 11 additions & 7 deletions
@@ -3096,9 +3096,6 @@ def __call__(
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        breakpoint()
-        batch_size = hidden_states.size(0)
-
         query = attn.to_q(hidden_states)
         key = attn.to_k(hidden_states)
         value = attn.to_v(hidden_states)
@@ -3124,8 +3121,9 @@ def __call__(
             encoder_query = attn.norm_added_q(encoder_query)
         if attn.norm_added_k is not None:
             encoder_key = attn.norm_added_k(encoder_key)
-
+
         if image_rotary_emb is not None:
+
             def apply_rotary_emb(x, freqs_cos, freqs_sin):
                 x_even = x[..., 0::2].float()
                 x_odd = x[..., 1::2].float()
@@ -3137,9 +3135,13 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):

             query = apply_rotary_emb(query, *image_rotary_emb)
             key = apply_rotary_emb(key, *image_rotary_emb)
-
+
         query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
-        encoder_query, encoder_key, encoder_value = encoder_query.transpose(1, 2), encoder_key.transpose(1, 2), encoder_value.transpose(1, 2)
+        encoder_query, encoder_key, encoder_value = (
+            encoder_query.transpose(1, 2),
+            encoder_key.transpose(1, 2),
+            encoder_value.transpose(1, 2),
+        )

         sequence_length = query.size(2)
         encoder_sequence_length = encoder_query.size(2)
@@ -3152,7 +3154,9 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.to(query.dtype)

-        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length), dim=1)
+        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
+            (sequence_length, encoder_sequence_length), dim=1
+        )

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
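Note: the diff above only shows the first two lines of the nested `apply_rotary_emb` helper (the even/odd channel split). For reference, here is a minimal standalone sketch of a standard interleaved rotary application with that signature; the cos/sin recombination, the re-interleaving, and the toy shapes are assumptions for illustration, not the exact diffusers implementation.

```python
import torch


def apply_rotary_emb(x, freqs_cos, freqs_sin):
    # Split the last dimension into interleaved even/odd channels, as in the hunk above.
    x_even = x[..., 0::2].float()
    x_odd = x[..., 1::2].float()
    # Rotate each (even, odd) pair by the per-position angle (standard RoPE rotation).
    out_even = x_even * freqs_cos.float() - x_odd * freqs_sin.float()
    out_odd = x_even * freqs_sin.float() + x_odd * freqs_cos.float()
    # Re-interleave the rotated halves back into the original channel layout.
    return torch.stack([out_even, out_odd], dim=-1).flatten(-2).type_as(x)


# Toy shapes: 2 tokens, 1 head, head_dim=4 -> 2 rotary frequencies per token.
x = torch.randn(2, 1, 4)
freqs_cos = torch.ones(2, 1, 2)   # angle 0 -> identity rotation
freqs_sin = torch.zeros(2, 1, 2)
assert torch.allclose(apply_rotary_emb(x, freqs_cos, freqs_sin), x)
```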

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 29 additions & 7 deletions
@@ -145,16 +145,23 @@ def forward(
 class MochiRoPE(nn.Module):
     def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
         super().__init__()
-
+
         self.target_area = base_height * base_width
-
+
     def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
         edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
         return (edges[:-1] + edges[1:]) / 2
-
-    def _get_positions(self, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+
+    def _get_positions(
+        self,
+        num_frames: int,
+        height: int,
+        width: int,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> torch.Tensor:
         scale = (self.target_area / (height * width)) ** 0.5
-
+
         t = torch.arange(num_frames, device=device, dtype=dtype)
         h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
         w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
@@ -170,7 +177,15 @@ def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
         freqs_sin = torch.sin(freqs)
         return freqs_cos, freqs_sin

-    def forward(self, pos_frequencies: torch.Tensor, num_frames: int, height: int, width: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self,
+        pos_frequencies: torch.Tensor,
+        num_frames: int,
+        height: int,
+        width: int,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         pos = self._get_positions(num_frames, height, width, device, dtype)
         rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
         return rope_cos, rope_sin
@@ -261,7 +276,14 @@ def forward(
         hidden_states = self.patch_embed(hidden_states)
         hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

-        image_rotary_emb = self.rope(self.pos_frequencies, num_frames, post_patch_height, post_patch_width, device=hidden_states.device, dtype=torch.float32)
+        image_rotary_emb = self.rope(
+            self.pos_frequencies,
+            num_frames,
+            post_patch_height,
+            post_patch_width,
+            device=hidden_states.device,
+            dtype=torch.float32,
+        )

         for i, block in enumerate(self.transformer_blocks):
             hidden_states, encoder_hidden_states = block(
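The `_centers` helper and the `scale` term in the hunks above build a centered coordinate grid whose covered area stays fixed at `base_height * base_width` regardless of the latent aspect ratio. Below is a small standalone sketch of that computation; the variable names, sizes, and prints are illustrative and not taken from the module.

```python
import torch


def centers(start: float, stop: float, num: int) -> torch.Tensor:
    # Midpoints of `num` equal bins spanning [start, stop], mirroring MochiRoPE._centers.
    edges = torch.linspace(start, stop, num + 1)
    return (edges[:-1] + edges[1:]) / 2


# Hypothetical latent sizes; target_area mirrors base_height * base_width = 192 * 192.
target_area = 192 * 192
height, width = 30, 45
scale = (target_area / (height * width)) ** 0.5

# Symmetric, centered grid positions for the spatial axes.
h = centers(-height * scale / 2, height * scale / 2, height)
w = centers(-width * scale / 2, width * scale / 2, width)

print(h.shape, w.shape)  # torch.Size([30]) torch.Size([45])
# The grid spans the same area as the 192x192 base resolution, whatever the aspect ratio.
print(float((h[-1] - h[0] + scale) * (w[-1] - w[0] + scale)))  # ~36864.0 == 192 * 192
```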

src/diffusers/models/transformers/transformer_mochi_original.py

Lines changed: 14 additions & 15 deletions
@@ -96,8 +96,7 @@ def compute_mixed_rotation(
         num_heads: int

     Returns:
-        freqs_cos: [N, num_heads, num_freqs] - cosine components
-        freqs_sin: [N, num_heads, num_freqs] - sine components
+        freqs_cos: [N, num_heads, num_freqs] - cosine components freqs_sin: [N, num_heads, num_freqs] - sine components
     """
     with torch.autocast("cuda", enabled=False):
         assert freqs.ndim == 3
@@ -470,8 +469,7 @@ def forward(
             num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

         Returns:
-            x: (B, N, dim) tensor of visual tokens after block
-            y: (B, L, dim) tensor of text tokens after block
+            x: (B, N, dim) tensor of visual tokens after block y: (B, L, dim) tensor of text tokens after block
         """
         breakpoint()
         N = x.size(1)
@@ -651,7 +649,7 @@ def run_attention(
         breakpoint()
         N = M
         local_heads = self.num_heads
-        local_dim = local_heads * self.head_dim
+        # local_dim = local_heads * self.head_dim
         # with torch.autocast("cuda", enabled=False):
         #     out: torch.Tensor = flash_attn_varlen_qkvpacked_func(
         #         qkv,
@@ -696,8 +694,8 @@ def forward(
             num_frames: Number of frames in the video. N = num_frames * num_spatial_tokens

         Returns:
-            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention
-            y: (B, L, dim_y) tensor of text token features after multi-modal attention
+            x: (B, N, dim_x) tensor of visual tokens after multi-modal attention y: (B, L, dim_y) tensor of text token
+            features after multi-modal attention
         """
         B, L, _ = y.shape
         _, M, _ = x.shape
@@ -725,6 +723,7 @@ def forward(
         )
         return x, y

+
 def apply_rotary_emb_qk_real(
     xqk: torch.Tensor,
     freqs_cos: torch.Tensor,
@@ -756,10 +755,10 @@ def apply_rotary_emb_qk_real(
     # assert out.dtype == torch.bfloat16
     return out

+
 class PadSplitXY(torch.autograd.Function):
     """
-    Merge heads, pad and extract visual and text tokens,
-    and split along the sequence length.
+    Merge heads, pad and extract visual and text tokens, and split along the sequence length.
     """

     @staticmethod
@@ -778,8 +777,7 @@ def forward(
             indices: Valid token indices out of unpacked tensor. Shape: (total,)

         Returns:
-            x: Visual tokens. Shape: (B, N, num_heads * head_dim).
-            y: Text tokens. Shape: (B, L, num_heads * head_dim).
+            x: Visual tokens. Shape: (B, N, num_heads * head_dim). y: Text tokens. Shape: (B, L, num_heads * head_dim).
         """
         ctx.save_for_backward(indices)
         ctx.B, ctx.N, ctx.L = B, N, L
@@ -788,9 +786,7 @@ def forward(
         # Pad sequences to (B, N + L, dim).
         assert indices.ndim == 1
         output = torch.zeros(B * (N + L), D, device=xy.device, dtype=dtype)
-        indices = indices.unsqueeze(1).expand(
-            -1, D
-        )  # (total,) -> (total, num_heads * head_dim)
+        indices = indices.unsqueeze(1).expand(-1, D)  # (total,) -> (total, num_heads * head_dim)
         output.scatter_(0, indices, xy)
         xy = output.view(B, N + L, D)

@@ -801,6 +797,7 @@ def forward(
 def pad_and_split_xy(xy, indices, B, N, L, dtype) -> Tuple[torch.Tensor, torch.Tensor]:
     return PadSplitXY.apply(xy, indices, B, N, L, dtype)

+
 class UnifyStreams(torch.autograd.Function):
     """Unify visual and text streams."""

@@ -1034,7 +1031,9 @@ def forward(
         Args:
            x: (B, C, T, H, W) tensor of spatial inputs (images or latent representations of images)
            sigma: (B,) tensor of noise standard deviations
-            y_feat: List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77, y_feat_dim=2048)
+            y_feat:
+                List((B, L, y_feat_dim) tensor of caption token features. For SDXL text encoders: L=77,
+                y_feat_dim=2048)
            y_mask: List((B, L) boolean tensor indicating which tokens are not padding)
            packed_indices: Dict with keys for Flash Attention. Result of compute_packed_indices.
        """
