Skip to content

Commit 319176e

Browse files
yoshphys and claude committed
Fix black formatting in irodori_tts model files
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e41ba05 commit 319176e

File tree

3 files changed

+88
-43
lines changed

3 files changed

+88
-43
lines changed

mlx_audio/tts/models/irodori_tts/config.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,11 +49,19 @@ def speaker_patched_latent_dim(self) -> int:
4949

5050
@property
5151
def text_mlp_ratio_resolved(self) -> float:
52-
return self.mlp_ratio if self.text_mlp_ratio is None else float(self.text_mlp_ratio)
52+
return (
53+
self.mlp_ratio
54+
if self.text_mlp_ratio is None
55+
else float(self.text_mlp_ratio)
56+
)
5357

5458
@property
5559
def speaker_mlp_ratio_resolved(self) -> float:
56-
return self.mlp_ratio if self.speaker_mlp_ratio is None else float(self.speaker_mlp_ratio)
60+
return (
61+
self.mlp_ratio
62+
if self.speaker_mlp_ratio is None
63+
else float(self.speaker_mlp_ratio)
64+
)
5765

5866

5967
@dataclass

mlx_audio/tts/models/irodori_tts/irodori_tts.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -99,12 +99,14 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
9999
dac = DACVAE(cfg)
100100
dac.load_weights(str(local_dacvae / "model.safetensors"))
101101
import mlx.core as _mx
102+
102103
_mx.eval(dac.parameters())
103104
model.dacvae = dac
104105
else:
105106
model.dacvae = DACVAE.from_pretrained(model.config.dacvae_repo)
106107
except Exception as e:
107108
import warnings
109+
108110
warnings.warn(
109111
f"Could not load DACVAE: {e}\n"
110112
"Set model.dacvae manually before calling generate()."
@@ -119,6 +121,7 @@ def post_load_hook(cls, model: "Model", model_path: Path) -> "Model":
119121
def _get_tokenizer(self):
120122
if self._tokenizer is None:
121123
from transformers import AutoTokenizer
124+
122125
self._tokenizer = AutoTokenizer.from_pretrained(
123126
self.config.dit.text_tokenizer_repo
124127
)
@@ -147,23 +150,23 @@ def _prepare_text(
147150
# Reference audio encoding
148151
# ------------------------------------------------------------------
149152

150-
def _encode_ref_audio(
151-
self, audio: mx.array
152-
) -> tuple[mx.array, mx.array]:
153+
def _encode_ref_audio(self, audio: mx.array) -> tuple[mx.array, mx.array]:
153154
"""
154155
Encode reference waveform with DACVAE.
155156
audio: (1, samples) at config.sample_rate
156157
Returns (latent, mask): latent (1, T, 128), mask (1, T) bool
157158
"""
158159
assert self.dacvae is not None, "DACVAE not loaded"
159160

160-
max_samples = self.config.max_speaker_latent_length * self.config.audio_downsample_factor
161+
max_samples = (
162+
self.config.max_speaker_latent_length * self.config.audio_downsample_factor
163+
)
161164
audio = audio[:, :max_samples]
162165

163166
# DACVAE encode expects (B, L, 1)
164-
audio_in = audio[:, :, None] # (1, L, 1)
165-
latent = self.dacvae.encode(audio_in) # (1, 128, T) channels-first
166-
latent = mx.transpose(latent, (0, 2, 1)) # (1, T, 128) sequence-first
167+
audio_in = audio[:, :, None] # (1, L, 1)
168+
latent = self.dacvae.encode(audio_in) # (1, 128, T) channels-first
169+
latent = mx.transpose(latent, (0, 2, 1)) # (1, T, 128) sequence-first
167170

168171
actual_t = int(audio.shape[1]) // self.config.audio_downsample_factor
169172
actual_t = min(actual_t, latent.shape[1])
@@ -266,8 +269,8 @@ def generate(
266269
# Decode latent → waveform
267270
# latent_out: (1, T, 128)
268271
latent_for_decode = mx.transpose(latent_out, (0, 2, 1)) # (1, 128, T)
269-
audio_out = self.dacvae.decode(latent_for_decode) # (1, L, 1)
270-
audio_out = audio_out[:, :, 0] # (1, L)
272+
audio_out = self.dacvae.decode(latent_for_decode) # (1, L, 1)
273+
audio_out = audio_out[:, :, 0] # (1, L)
271274

272275
# Trim trailing silence
273276
silence_t = _find_silence_point(latent_out[0])
@@ -277,7 +280,9 @@ def generate(
277280
audio = audio_out[0] # (L,)
278281
samples = int(audio.shape[0])
279282
elapsed = max(time.perf_counter() - start_time, 1e-6)
280-
audio_duration_seconds = samples / self.sample_rate if self.sample_rate > 0 else 0.0
283+
audio_duration_seconds = (
284+
samples / self.sample_rate if self.sample_rate > 0 else 0.0
285+
)
281286

282287
h = int(audio_duration_seconds // 3600)
283288
m = int((audio_duration_seconds % 3600) // 60)

mlx_audio/tts/models/irodori_tts/sampling.py

Lines changed: 63 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,8 @@ def _temporal_score_rescale(
5858
if t >= 1.0:
5959
return v_pred
6060
one_minus_t = 1.0 - t
61-
snr = (one_minus_t ** 2) / (t ** 2)
62-
sigma_sq = rescale_sigma ** 2
61+
snr = (one_minus_t**2) / (t**2)
62+
sigma_sq = rescale_sigma**2
6363
ratio = (snr * sigma_sq + 1.0) / (snr * sigma_sq / rescale_k + 1.0)
6464
return (ratio * (one_minus_t * v_pred + x_t) - x_t) / one_minus_t
6565

@@ -237,13 +237,19 @@ def sample_euler_cfg(
237237
axis=0,
238238
)
239239
v_out = model.forward_with_conditions(
240-
x_t=x_cfg, t=t_cfg,
240+
x_t=x_cfg,
241+
t=t_cfg,
241242
text_state=mx.concatenate(
242-
[text_state_cond, text_state_uncond, text_state_cond], axis=0
243+
[text_state_cond, text_state_uncond, text_state_cond],
244+
axis=0,
243245
),
244246
text_mask=text_mask_cfg,
245247
speaker_state=mx.concatenate(
246-
[speaker_state_cond, speaker_state_cond, speaker_state_uncond],
248+
[
249+
speaker_state_cond,
250+
speaker_state_cond,
251+
speaker_state_uncond,
252+
],
247253
axis=0,
248254
),
249255
speaker_mask=speaker_mask_cfg,
@@ -261,7 +267,8 @@ def sample_euler_cfg(
261267
x_cfg = mx.concatenate([x_t, x_t], axis=0)
262268
t_cfg = mx.full((batch_size * 2,), t, dtype=mx.float32)
263269
v_out = model.forward_with_conditions(
264-
x_t=x_cfg, t=t_cfg,
270+
x_t=x_cfg,
271+
t=t_cfg,
265272
text_state=mx.concatenate(
266273
[text_state_cond, text_state_uncond], axis=0
267274
),
@@ -284,7 +291,8 @@ def sample_euler_cfg(
284291
x_cfg = mx.concatenate([x_t, x_t], axis=0)
285292
t_cfg = mx.full((batch_size * 2,), t, dtype=mx.float32)
286293
v_out = model.forward_with_conditions(
287-
x_t=x_cfg, t=t_cfg,
294+
x_t=x_cfg,
295+
t=t_cfg,
288296
text_state=mx.concatenate(
289297
[text_state_cond, text_state_cond], axis=0
290298
),
@@ -315,53 +323,77 @@ def sample_euler_cfg(
315323
joint_scale = cfg_scale_text if has_text_cfg else cfg_scale_speaker
316324

317325
v_cond = model.forward_with_conditions(
318-
x_t=x_t, t=t_arr,
319-
text_state=text_state_cond, text_mask=text_mask_cond,
320-
speaker_state=speaker_state_cond, speaker_mask=speaker_mask_cond,
321-
kv_text=kv_text_cond, kv_speaker=kv_speaker_cond,
326+
x_t=x_t,
327+
t=t_arr,
328+
text_state=text_state_cond,
329+
text_mask=text_mask_cond,
330+
speaker_state=speaker_state_cond,
331+
speaker_mask=speaker_mask_cond,
332+
kv_text=kv_text_cond,
333+
kv_speaker=kv_speaker_cond,
322334
)
323335
v_uncond = model.forward_with_conditions(
324-
x_t=x_t, t=t_arr,
325-
text_state=text_state_uncond, text_mask=text_mask_uncond,
326-
speaker_state=speaker_state_uncond, speaker_mask=speaker_mask_uncond,
327-
kv_text=kv_text_uncond_joint, kv_speaker=kv_speaker_uncond_joint,
336+
x_t=x_t,
337+
t=t_arr,
338+
text_state=text_state_uncond,
339+
text_mask=text_mask_uncond,
340+
speaker_state=speaker_state_uncond,
341+
speaker_mask=speaker_mask_uncond,
342+
kv_text=kv_text_uncond_joint,
343+
kv_speaker=kv_speaker_uncond_joint,
328344
)
329345
v_pred = v_cond + joint_scale * (v_cond - v_uncond)
330346

331347
else: # alternating
332348
v_cond = model.forward_with_conditions(
333-
x_t=x_t, t=t_arr,
334-
text_state=text_state_cond, text_mask=text_mask_cond,
335-
speaker_state=speaker_state_cond, speaker_mask=speaker_mask_cond,
336-
kv_text=kv_text_cond, kv_speaker=kv_speaker_cond,
349+
x_t=x_t,
350+
t=t_arr,
351+
text_state=text_state_cond,
352+
text_mask=text_mask_cond,
353+
speaker_state=speaker_state_cond,
354+
speaker_mask=speaker_mask_cond,
355+
kv_text=kv_text_cond,
356+
kv_speaker=kv_speaker_cond,
337357
)
338358
use_text_uncond = (has_text_cfg and has_speaker_cfg and i % 2 == 0) or (
339359
has_text_cfg and not has_speaker_cfg
340360
)
341361
if use_text_uncond:
342362
v_uncond = model.forward_with_conditions(
343-
x_t=x_t, t=t_arr,
344-
text_state=text_state_uncond, text_mask=text_mask_uncond,
345-
speaker_state=speaker_state_cond, speaker_mask=speaker_mask_cond,
346-
kv_text=kv_text_uncond_alt, kv_speaker=kv_speaker_cond,
363+
x_t=x_t,
364+
t=t_arr,
365+
text_state=text_state_uncond,
366+
text_mask=text_mask_uncond,
367+
speaker_state=speaker_state_cond,
368+
speaker_mask=speaker_mask_cond,
369+
kv_text=kv_text_uncond_alt,
370+
kv_speaker=kv_speaker_cond,
347371
)
348372
v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond)
349373
else:
350374
v_uncond = model.forward_with_conditions(
351-
x_t=x_t, t=t_arr,
352-
text_state=text_state_cond, text_mask=text_mask_cond,
353-
speaker_state=speaker_state_uncond, speaker_mask=speaker_mask_uncond,
354-
kv_text=kv_text_cond, kv_speaker=kv_speaker_uncond_alt,
375+
x_t=x_t,
376+
t=t_arr,
377+
text_state=text_state_cond,
378+
text_mask=text_mask_cond,
379+
speaker_state=speaker_state_uncond,
380+
speaker_mask=speaker_mask_uncond,
381+
kv_text=kv_text_cond,
382+
kv_speaker=kv_speaker_uncond_alt,
355383
)
356384
v_pred = v_cond + cfg_scale_speaker * (v_cond - v_uncond)
357385

358386
else:
359387
# no CFG this step
360388
v_pred = model.forward_with_conditions(
361-
x_t=x_t, t=t_arr,
362-
text_state=text_state_cond, text_mask=text_mask_cond,
363-
speaker_state=speaker_state_cond, speaker_mask=speaker_mask_cond,
364-
kv_text=kv_text_cond, kv_speaker=kv_speaker_cond,
389+
x_t=x_t,
390+
t=t_arr,
391+
text_state=text_state_cond,
392+
text_mask=text_mask_cond,
393+
speaker_state=speaker_state_cond,
394+
speaker_mask=speaker_mask_cond,
395+
kv_text=kv_text_cond,
396+
kv_speaker=kv_speaker_cond,
365397
)
366398

367399
# optional temporal score rescaling

0 commit comments

Comments (0)