Skip to content

Commit 7d8698f

Browse files
yoshphys and claude committed
Fix black formatting in model.py and text.py
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 319176e commit 7d8698f

File tree

2 files changed

+37
-17
lines changed

2 files changed

+37
-17
lines changed

mlx_audio/tts/models/irodori_tts/model.py

Lines changed: 26 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -113,9 +113,7 @@ def __init__(self, model_dim: int, rank: int, eps: float):
113113
self.scale_up = nn.Linear(rank, model_dim, bias=True)
114114
self.gate_up = nn.Linear(rank, model_dim, bias=True)
115115

116-
def __call__(
117-
self, x: mx.array, cond_embed: mx.array
118-
) -> Tuple[mx.array, mx.array]:
116+
def __call__(self, x: mx.array, cond_embed: mx.array) -> Tuple[mx.array, mx.array]:
119117
shift, scale, gate = mx.split(cond_embed, 3, axis=-1)
120118
shift = self.shift_up(self.shift_down(nn.silu(shift))) + shift
121119
scale = self.scale_up(self.scale_down(nn.silu(scale))) + scale
@@ -324,7 +322,9 @@ def __init__(self, dim: int, heads: int, mlp_hidden_dim: int, norm_eps: float):
324322
def __call__(
325323
self, x: mx.array, mask: Optional[mx.array], freqs_cis: RotaryCache
326324
) -> mx.array:
327-
x = x + self.attention(self.attention_norm(x), key_mask=mask, freqs_cis=freqs_cis)
325+
x = x + self.attention(
326+
self.attention_norm(x), key_mask=mask, freqs_cis=freqs_cis
327+
)
328328
x = x + self.mlp(self.mlp_norm(x))
329329
return x
330330

@@ -394,9 +394,7 @@ def __init__(
394394
TextBlock(dim, heads, mlp_hidden, norm_eps) for _ in range(num_layers)
395395
]
396396

397-
def __call__(
398-
self, latent: mx.array, mask: Optional[mx.array] = None
399-
) -> mx.array:
397+
def __call__(self, latent: mx.array, mask: Optional[mx.array] = None) -> mx.array:
400398
x = self.in_proj(latent) / 6.0
401399
freqs_cis = precompute_freqs_cis(self.head_dim, x.shape[1])
402400
if mask is not None:
@@ -453,8 +451,13 @@ def __call__(
453451
) -> mx.array:
454452
x_norm, attn_gate = self.attention_adaln(x, cond_embed)
455453
x = x + attn_gate * self.attention(
456-
x_norm, text_mask, speaker_mask, freqs_cis,
457-
kv_cache_text, kv_cache_speaker, start_pos,
454+
x_norm,
455+
text_mask,
456+
speaker_mask,
457+
freqs_cis,
458+
kv_cache_text,
459+
kv_cache_speaker,
460+
start_pos,
458461
)
459462
x_norm, mlp_gate = self.mlp_adaln(x, cond_embed)
460463
x = x + mlp_gate * self.mlp(x_norm)
@@ -554,7 +557,9 @@ def build_kv_cache(
554557
speaker_state: mx.array,
555558
) -> Tuple[List[KVCache], List[KVCache]]:
556559
"""Pre-compute per-layer text/speaker KV projections for fast sampling."""
557-
kv_text = [block.attention.get_kv_cache_text(text_state) for block in self.blocks]
560+
kv_text = [
561+
block.attention.get_kv_cache_text(text_state) for block in self.blocks
562+
]
558563
kv_speaker = [
559564
block.attention.get_kv_cache_speaker(speaker_state) for block in self.blocks
560565
]
@@ -576,7 +581,9 @@ def forward_with_conditions(
576581
kv_speaker: Optional[List[KVCache]] = None,
577582
start_pos: int = 0,
578583
) -> mx.array:
579-
t_embed = get_timestep_embedding(t, self.cfg.timestep_embed_dim).astype(x_t.dtype)
584+
t_embed = get_timestep_embedding(t, self.cfg.timestep_embed_dim).astype(
585+
x_t.dtype
586+
)
580587
cond_embed = self.cond_module(t_embed)[:, None, :] # (B, 1, 3*model_dim)
581588

582589
x = self.in_proj(x_t)
@@ -594,8 +601,14 @@ def forward_with_conditions(
594601
else block.attention.get_kv_cache_speaker(speaker_state)
595602
)
596603
x = block(
597-
x, cond_embed, text_mask, speaker_mask,
598-
freqs_cis, kv_t, kv_s, start_pos,
604+
x,
605+
cond_embed,
606+
text_mask,
607+
speaker_mask,
608+
freqs_cis,
609+
kv_t,
610+
kv_s,
611+
start_pos,
599612
)
600613

601614
x = self.out_norm(x)

mlx_audio/tts/models/irodori_tts/text.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@
1414
_REPLACE_MAP: dict[str, str] = {
1515
r"\t": "",
1616
r"\[n\]": "",
17-
r" ": "", # narrow no-break space (U+202F) / ideographic space handled below
18-
r" ": "", # ideographic space
17+
r" ": "", # narrow no-break space (U+202F) / ideographic space handled below
18+
r" ": "", # ideographic space
1919
r"[;▼♀♂《》≪≫①②③④⑤⑥]": "",
2020
r"[\u02d7\u2010-\u2015\u2043\u2212\u23af\u23e4\u2500\u2501\u2e3a\u2e3b]": "",
2121
r"[\uff5e\u301C]": "ー",
@@ -38,7 +38,10 @@
3838

3939
# Fullwidth 0-9 → halfwidth
4040
_FULLWIDTH_DIGITS_TO_HALFWIDTH = str.maketrans(
41-
{chr(full): chr(half) for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A))}
41+
{
42+
chr(full): chr(half)
43+
for full, half in zip(range(0xFF10, 0xFF1A), range(0x30, 0x3A))
44+
}
4245
)
4346

4447
# Halfwidth katakana → fullwidth katakana
@@ -67,7 +70,11 @@ def normalize_text(text: str) -> str:
6770

6871
# Strip surrounding bracket pairs
6972
for open_br, close_br in [
70-
("「", "」"), ("『", "』"), ("(", ")"), ("【", "】"), ("(", ")")
73+
("「", "」"),
74+
("『", "』"),
75+
("(", ")"),
76+
("【", "】"),
77+
("(", ")"),
7178
]:
7279
if text.startswith(open_br) and text.endswith(close_br):
7380
text = text[1:-1]

0 commit comments

Comments
 (0)