@@ -3658,6 +3658,7 @@ def do_tts(
36583658 end_of_text = end_of_text ,
36593659 beginning_of_text = beginning_of_text ,
36603660 use_cfg = use_cfg ,
3661+ use_local_transformer_for_inference = True ,
36613662 )
36623663 if output .predicted_codes_lens [0 ] > 0 :
36633664 all_codes .append (output .predicted_codes [0 , :, : output .predicted_codes_lens [0 ]])
@@ -4011,7 +4012,7 @@ def _run_longform_forward_with_cfg(
40114012 dummy_additional_decoder_input : Optional [torch .Tensor ],
40124013 dummy_addition_dec_mask : Optional [torch .Tensor ],
40134014 batch_size : int ,
4014- ) -> Tuple [torch .Tensor , Any ]:
4015+ ) -> Tuple [torch .Tensor , Any , torch . Tensor ]:
40154016 """
40164017 Run forward pass with optional classifier-free guidance.
40174018
@@ -4029,7 +4030,7 @@ def _run_longform_forward_with_cfg(
40294030 batch_size: Number of items in the batch.
40304031
40314032 Returns:
4032- Tuple of (logits, attention_probs).
4033+ Tuple of (logits, attention_probs, decoder_output ).
40334034 """
40344035 if use_cfg :
40354036 # Combine conditional and unconditional inputs
@@ -4049,7 +4050,7 @@ def _run_longform_forward_with_cfg(
40494050 )
40504051 cfg_audio_mask [batch_size :, : dummy_additional_decoder_input .size (1 )] = dummy_addition_dec_mask
40514052
4052- combined_logits , attn_probs , _ = self .forward (
4053+ combined_logits , attn_probs , dec_out = self .forward (
40534054 dec_input_embedded = cfg_audio_embedded ,
40544055 dec_input_mask = cfg_audio_mask ,
40554056 cond = cfg_cond ,
@@ -4061,8 +4062,9 @@ def _run_longform_forward_with_cfg(
40614062 cond_logits = combined_logits [:batch_size ]
40624063 uncond_logits = combined_logits [batch_size :]
40634064 all_code_logits = (1 - cfg_scale ) * uncond_logits + cfg_scale * cond_logits
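4065+ # NOTE: dec_out is deliberately left un-split here (conditional rows first, then unconditional rows, i.e. 2 * batch_size), unlike the logits above; the caller passes it to the local transformer sampler together with use_cfg / cfg_scale.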
40644066 else :
4065- all_code_logits , attn_probs , _ = self .forward (
4067+ all_code_logits , attn_probs , dec_out = self .forward (
40664068 dec_input_embedded = audio_codes_embedded ,
40674069 dec_input_mask = audio_codes_mask ,
40684070 cond = context_tensors .cond ,
@@ -4071,7 +4073,7 @@ def _run_longform_forward_with_cfg(
40714073 multi_encoder_mapping = context_tensors .multi_encoder_mapping ,
40724074 )
40734075
4074- return all_code_logits , attn_probs
4076+ return all_code_logits , attn_probs , dec_out
40754077
40764078 def _initialize_longform_attn_prior (
40774079 self ,
@@ -4242,6 +4244,12 @@ def generate_long_form_speech(
42424244 end_of_text ,
42434245 beginning_of_text ,
42444246 use_cfg = True ,
4247+ use_local_transformer_for_inference = False ,
4248+ maskgit_n_steps = 3 ,
4249+ maskgit_noise_scale = 0.0 ,
4250+ maskgit_fixed_schedule = None ,
4251+ maskgit_dynamic_cfg_scale = False ,
4252+ maskgit_sampling_type = None ,
42454253 ):
42464254 """
42474255 Generates speech for long-form text by progressively shifting through text tokens.
@@ -4258,6 +4266,12 @@ def generate_long_form_speech(
42584266 end_of_text (List[bool]): Whether entire text has been provided for each batch item.
42594267 beginning_of_text (bool): Whether this is the first chunk.
42604268 use_cfg (bool): Whether to use classifier-free guidance.
4269+ use_local_transformer_for_inference (bool): Whether to use local transformer for sampling.
4270+ maskgit_n_steps (int): Number of MaskGit refinement steps.
4271+ maskgit_noise_scale (float): Noise scale for MaskGit sampling.
4272+ maskgit_fixed_schedule (Optional[List[int]]): Fixed schedule for MaskGit.
4273+ maskgit_dynamic_cfg_scale (bool): Whether to use dynamic CFG scale in MaskGit.
4274+ maskgit_sampling_type (Optional[str]): Type of MaskGit sampling.
42614275
42624276 Returns:
42634277 InferBatchOutput: Contains predicted_codes, predicted_codes_lens, and empty audio fields.
@@ -4365,7 +4379,7 @@ def generate_long_form_speech(
43654379 attn_prior = [attn_prior , None ]
43664380
43674381 # Run forward pass with optional CFG
4368- all_code_logits , attn_probs = self ._run_longform_forward_with_cfg (
4382+ all_code_logits , attn_probs , dec_out = self ._run_longform_forward_with_cfg (
43694383 context_tensors = context_tensors ,
43704384 audio_codes_embedded = _audio_codes_embedded ,
43714385 audio_codes_mask = _audio_codes_mask ,
@@ -4439,13 +4453,47 @@ def generate_long_form_speech(
44394453 unfinished_items = {k : v for k , v in state .unfinished_texts .items () if v }
44404454
44414455 all_code_logits_t = all_code_logits [:, - 1 , :] # (B, num_codebooks * num_tokens_per_codebook)
4442- audio_codes_next = self .sample_codes_from_logits (
4443- all_code_logits_t ,
4444- temperature = self .inference_parameters .temperature ,
4445- topk = self .inference_parameters .topk ,
4446- unfinished_items = unfinished_items ,
4447- finished_items = finished_items ,
4448- ) # (B, num_codebooks)
4456+
4457+ if use_local_transformer_for_inference :
4458+ if self .local_transformer_type == LocalTransformerType .AR :
4459+ # Autoregressive sampling with local transformer
4460+ audio_codes_next = self .local_transformer_sample_autoregressive (
4461+ dec_output = dec_out [:, - 1 , :],
4462+ temperature = self .inference_parameters .temperature ,
4463+ topk = self .inference_parameters .topk ,
4464+ unfinished_items = unfinished_items ,
4465+ finished_items = finished_items ,
4466+ use_cfg = use_cfg ,
4467+ cfg_scale = cfg_scale ,
4468+ use_kv_cache = self .inference_parameters .use_LT_kv_cache ,
4469+ )
4470+ elif self .local_transformer_type == LocalTransformerType .MASKGIT :
4471+ audio_codes_next = self .local_transformer_sample_maskgit (
4472+ dec_output = dec_out [:, - 1 , :],
4473+ temperature = self .inference_parameters .temperature ,
4474+ topk = self .inference_parameters .topk ,
4475+ unfinished_items = unfinished_items ,
4476+ finished_items = finished_items ,
4477+ use_cfg = use_cfg ,
4478+ cfg_scale = cfg_scale ,
4479+ n_steps = maskgit_n_steps ,
4480+ noise_scale = maskgit_noise_scale ,
4481+ fixed_schedule = maskgit_fixed_schedule ,
4482+ dynamic_cfg_scale = maskgit_dynamic_cfg_scale ,
4483+ sampling_type = maskgit_sampling_type ,
4484+ )
4485+ else :
4486+ raise ValueError (
4487+ f"Local transformer inference requested but local transformer type is { self .local_transformer_type } "
4488+ )
4489+ else :
4490+ audio_codes_next = self .sample_codes_from_logits (
4491+ all_code_logits_t ,
4492+ temperature = self .inference_parameters .temperature ,
4493+ topk = self .inference_parameters .topk ,
4494+ unfinished_items = unfinished_items ,
4495+ finished_items = finished_items ,
4496+ ) # (B, num_codebooks)
44494497 all_codes_next_argmax = self .sample_codes_from_logits (
44504498 all_code_logits_t ,
44514499 temperature = self .longform_config .argmax_temperature ,
0 commit comments