@@ -58,8 +58,8 @@ def _temporal_score_rescale(
5858 if t >= 1.0 :
5959 return v_pred
6060 one_minus_t = 1.0 - t
61- snr = (one_minus_t ** 2 ) / (t ** 2 )
62- sigma_sq = rescale_sigma ** 2
61+ snr = (one_minus_t ** 2 ) / (t ** 2 )
62+ sigma_sq = rescale_sigma ** 2
6363 ratio = (snr * sigma_sq + 1.0 ) / (snr * sigma_sq / rescale_k + 1.0 )
6464 return (ratio * (one_minus_t * v_pred + x_t ) - x_t ) / one_minus_t
6565
@@ -237,13 +237,19 @@ def sample_euler_cfg(
237237 axis = 0 ,
238238 )
239239 v_out = model .forward_with_conditions (
240- x_t = x_cfg , t = t_cfg ,
240+ x_t = x_cfg ,
241+ t = t_cfg ,
241242 text_state = mx .concatenate (
242- [text_state_cond , text_state_uncond , text_state_cond ], axis = 0
243+ [text_state_cond , text_state_uncond , text_state_cond ],
244+ axis = 0 ,
243245 ),
244246 text_mask = text_mask_cfg ,
245247 speaker_state = mx .concatenate (
246- [speaker_state_cond , speaker_state_cond , speaker_state_uncond ],
248+ [
249+ speaker_state_cond ,
250+ speaker_state_cond ,
251+ speaker_state_uncond ,
252+ ],
247253 axis = 0 ,
248254 ),
249255 speaker_mask = speaker_mask_cfg ,
@@ -261,7 +267,8 @@ def sample_euler_cfg(
261267 x_cfg = mx .concatenate ([x_t , x_t ], axis = 0 )
262268 t_cfg = mx .full ((batch_size * 2 ,), t , dtype = mx .float32 )
263269 v_out = model .forward_with_conditions (
264- x_t = x_cfg , t = t_cfg ,
270+ x_t = x_cfg ,
271+ t = t_cfg ,
265272 text_state = mx .concatenate (
266273 [text_state_cond , text_state_uncond ], axis = 0
267274 ),
@@ -284,7 +291,8 @@ def sample_euler_cfg(
284291 x_cfg = mx .concatenate ([x_t , x_t ], axis = 0 )
285292 t_cfg = mx .full ((batch_size * 2 ,), t , dtype = mx .float32 )
286293 v_out = model .forward_with_conditions (
287- x_t = x_cfg , t = t_cfg ,
294+ x_t = x_cfg ,
295+ t = t_cfg ,
288296 text_state = mx .concatenate (
289297 [text_state_cond , text_state_cond ], axis = 0
290298 ),
@@ -315,53 +323,77 @@ def sample_euler_cfg(
315323 joint_scale = cfg_scale_text if has_text_cfg else cfg_scale_speaker
316324
317325 v_cond = model .forward_with_conditions (
318- x_t = x_t , t = t_arr ,
319- text_state = text_state_cond , text_mask = text_mask_cond ,
320- speaker_state = speaker_state_cond , speaker_mask = speaker_mask_cond ,
321- kv_text = kv_text_cond , kv_speaker = kv_speaker_cond ,
326+ x_t = x_t ,
327+ t = t_arr ,
328+ text_state = text_state_cond ,
329+ text_mask = text_mask_cond ,
330+ speaker_state = speaker_state_cond ,
331+ speaker_mask = speaker_mask_cond ,
332+ kv_text = kv_text_cond ,
333+ kv_speaker = kv_speaker_cond ,
322334 )
323335 v_uncond = model .forward_with_conditions (
324- x_t = x_t , t = t_arr ,
325- text_state = text_state_uncond , text_mask = text_mask_uncond ,
326- speaker_state = speaker_state_uncond , speaker_mask = speaker_mask_uncond ,
327- kv_text = kv_text_uncond_joint , kv_speaker = kv_speaker_uncond_joint ,
336+ x_t = x_t ,
337+ t = t_arr ,
338+ text_state = text_state_uncond ,
339+ text_mask = text_mask_uncond ,
340+ speaker_state = speaker_state_uncond ,
341+ speaker_mask = speaker_mask_uncond ,
342+ kv_text = kv_text_uncond_joint ,
343+ kv_speaker = kv_speaker_uncond_joint ,
328344 )
329345 v_pred = v_cond + joint_scale * (v_cond - v_uncond )
330346
331347 else : # alternating
332348 v_cond = model .forward_with_conditions (
333- x_t = x_t , t = t_arr ,
334- text_state = text_state_cond , text_mask = text_mask_cond ,
335- speaker_state = speaker_state_cond , speaker_mask = speaker_mask_cond ,
336- kv_text = kv_text_cond , kv_speaker = kv_speaker_cond ,
349+ x_t = x_t ,
350+ t = t_arr ,
351+ text_state = text_state_cond ,
352+ text_mask = text_mask_cond ,
353+ speaker_state = speaker_state_cond ,
354+ speaker_mask = speaker_mask_cond ,
355+ kv_text = kv_text_cond ,
356+ kv_speaker = kv_speaker_cond ,
337357 )
338358 use_text_uncond = (has_text_cfg and has_speaker_cfg and i % 2 == 0 ) or (
339359 has_text_cfg and not has_speaker_cfg
340360 )
341361 if use_text_uncond :
342362 v_uncond = model .forward_with_conditions (
343- x_t = x_t , t = t_arr ,
344- text_state = text_state_uncond , text_mask = text_mask_uncond ,
345- speaker_state = speaker_state_cond , speaker_mask = speaker_mask_cond ,
346- kv_text = kv_text_uncond_alt , kv_speaker = kv_speaker_cond ,
363+ x_t = x_t ,
364+ t = t_arr ,
365+ text_state = text_state_uncond ,
366+ text_mask = text_mask_uncond ,
367+ speaker_state = speaker_state_cond ,
368+ speaker_mask = speaker_mask_cond ,
369+ kv_text = kv_text_uncond_alt ,
370+ kv_speaker = kv_speaker_cond ,
347371 )
348372 v_pred = v_cond + cfg_scale_text * (v_cond - v_uncond )
349373 else :
350374 v_uncond = model .forward_with_conditions (
351- x_t = x_t , t = t_arr ,
352- text_state = text_state_cond , text_mask = text_mask_cond ,
353- speaker_state = speaker_state_uncond , speaker_mask = speaker_mask_uncond ,
354- kv_text = kv_text_cond , kv_speaker = kv_speaker_uncond_alt ,
375+ x_t = x_t ,
376+ t = t_arr ,
377+ text_state = text_state_cond ,
378+ text_mask = text_mask_cond ,
379+ speaker_state = speaker_state_uncond ,
380+ speaker_mask = speaker_mask_uncond ,
381+ kv_text = kv_text_cond ,
382+ kv_speaker = kv_speaker_uncond_alt ,
355383 )
356384 v_pred = v_cond + cfg_scale_speaker * (v_cond - v_uncond )
357385
358386 else :
359387 # no CFG this step
360388 v_pred = model .forward_with_conditions (
361- x_t = x_t , t = t_arr ,
362- text_state = text_state_cond , text_mask = text_mask_cond ,
363- speaker_state = speaker_state_cond , speaker_mask = speaker_mask_cond ,
364- kv_text = kv_text_cond , kv_speaker = kv_speaker_cond ,
389+ x_t = x_t ,
390+ t = t_arr ,
391+ text_state = text_state_cond ,
392+ text_mask = text_mask_cond ,
393+ speaker_state = speaker_state_cond ,
394+ speaker_mask = speaker_mask_cond ,
395+ kv_text = kv_text_cond ,
396+ kv_speaker = kv_speaker_cond ,
365397 )
366398
367399 # optional temporal score rescaling
0 commit comments