Skip to content

Commit 527b8c4

Browse files
Support local transformer in longform magpietts (#15296)
Signed-off-by: subhankar-ghosh <[email protected]>
1 parent 798f0c6 commit 527b8c4

File tree

2 files changed

+66
-13
lines changed

2 files changed

+66
-13
lines changed

nemo/collections/tts/models/magpietts.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3658,6 +3658,7 @@ def do_tts(
36583658
end_of_text=end_of_text,
36593659
beginning_of_text=beginning_of_text,
36603660
use_cfg=use_cfg,
3661+
use_local_transformer_for_inference=True,
36613662
)
36623663
if output.predicted_codes_lens[0] > 0:
36633664
all_codes.append(output.predicted_codes[0, :, : output.predicted_codes_lens[0]])
@@ -4011,7 +4012,7 @@ def _run_longform_forward_with_cfg(
40114012
dummy_additional_decoder_input: Optional[torch.Tensor],
40124013
dummy_addition_dec_mask: Optional[torch.Tensor],
40134014
batch_size: int,
4014-
) -> Tuple[torch.Tensor, Any]:
4015+
) -> Tuple[torch.Tensor, Any, torch.Tensor]:
40154016
"""
40164017
Run forward pass with optional classifier-free guidance.
40174018
@@ -4029,7 +4030,7 @@ def _run_longform_forward_with_cfg(
40294030
batch_size: Number of items in the batch.
40304031
40314032
Returns:
4032-
Tuple of (logits, attention_probs).
4033+
Tuple of (logits, attention_probs, decoder_output).
40334034
"""
40344035
if use_cfg:
40354036
# Combine conditional and unconditional inputs
@@ -4049,7 +4050,7 @@ def _run_longform_forward_with_cfg(
40494050
)
40504051
cfg_audio_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = dummy_addition_dec_mask
40514052

4052-
combined_logits, attn_probs, _ = self.forward(
4053+
combined_logits, attn_probs, dec_out = self.forward(
40534054
dec_input_embedded=cfg_audio_embedded,
40544055
dec_input_mask=cfg_audio_mask,
40554056
cond=cfg_cond,
@@ -4061,8 +4062,9 @@ def _run_longform_forward_with_cfg(
40614062
cond_logits = combined_logits[:batch_size]
40624063
uncond_logits = combined_logits[batch_size:]
40634064
all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits
4065+
# NOTE: Keep dec_out doubled for local transformer CFG handling
40644066
else:
4065-
all_code_logits, attn_probs, _ = self.forward(
4067+
all_code_logits, attn_probs, dec_out = self.forward(
40664068
dec_input_embedded=audio_codes_embedded,
40674069
dec_input_mask=audio_codes_mask,
40684070
cond=context_tensors.cond,
@@ -4071,7 +4073,7 @@ def _run_longform_forward_with_cfg(
40714073
multi_encoder_mapping=context_tensors.multi_encoder_mapping,
40724074
)
40734075

4074-
return all_code_logits, attn_probs
4076+
return all_code_logits, attn_probs, dec_out
40754077

40764078
def _initialize_longform_attn_prior(
40774079
self,
@@ -4242,6 +4244,12 @@ def generate_long_form_speech(
42424244
end_of_text,
42434245
beginning_of_text,
42444246
use_cfg=True,
4247+
use_local_transformer_for_inference=False,
4248+
maskgit_n_steps=3,
4249+
maskgit_noise_scale=0.0,
4250+
maskgit_fixed_schedule=None,
4251+
maskgit_dynamic_cfg_scale=False,
4252+
maskgit_sampling_type=None,
42454253
):
42464254
"""
42474255
Generates speech for long-form text by progressively shifting through text tokens.
@@ -4258,6 +4266,12 @@ def generate_long_form_speech(
42584266
end_of_text (List[bool]): Whether entire text has been provided for each batch item.
42594267
beginning_of_text (bool): Whether this is the first chunk.
42604268
use_cfg (bool): Whether to use classifier-free guidance.
4269+
use_local_transformer_for_inference (bool): Whether to use local transformer for sampling.
4270+
maskgit_n_steps (int): Number of MaskGit refinement steps.
4271+
maskgit_noise_scale (float): Noise scale for MaskGit sampling.
4272+
maskgit_fixed_schedule (Optional[List[int]]): Fixed schedule for MaskGit.
4273+
maskgit_dynamic_cfg_scale (bool): Whether to use dynamic CFG scale in MaskGit.
4274+
maskgit_sampling_type (Optional[str]): Type of MaskGit sampling.
42614275
42624276
Returns:
42634277
InferBatchOutput: Contains predicted_codes, predicted_codes_lens, and empty audio fields.
@@ -4365,7 +4379,7 @@ def generate_long_form_speech(
43654379
attn_prior = [attn_prior, None]
43664380

43674381
# Run forward pass with optional CFG
4368-
all_code_logits, attn_probs = self._run_longform_forward_with_cfg(
4382+
all_code_logits, attn_probs, dec_out = self._run_longform_forward_with_cfg(
43694383
context_tensors=context_tensors,
43704384
audio_codes_embedded=_audio_codes_embedded,
43714385
audio_codes_mask=_audio_codes_mask,
@@ -4439,13 +4453,47 @@ def generate_long_form_speech(
44394453
unfinished_items = {k: v for k, v in state.unfinished_texts.items() if v}
44404454

44414455
all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook)
4442-
audio_codes_next = self.sample_codes_from_logits(
4443-
all_code_logits_t,
4444-
temperature=self.inference_parameters.temperature,
4445-
topk=self.inference_parameters.topk,
4446-
unfinished_items=unfinished_items,
4447-
finished_items=finished_items,
4448-
) # (B, num_codebooks)
4456+
4457+
if use_local_transformer_for_inference:
4458+
if self.local_transformer_type == LocalTransformerType.AR:
4459+
# Autoregressive sampling with local transformer
4460+
audio_codes_next = self.local_transformer_sample_autoregressive(
4461+
dec_output=dec_out[:, -1, :],
4462+
temperature=self.inference_parameters.temperature,
4463+
topk=self.inference_parameters.topk,
4464+
unfinished_items=unfinished_items,
4465+
finished_items=finished_items,
4466+
use_cfg=use_cfg,
4467+
cfg_scale=cfg_scale,
4468+
use_kv_cache=self.inference_parameters.use_LT_kv_cache,
4469+
)
4470+
elif self.local_transformer_type == LocalTransformerType.MASKGIT:
4471+
audio_codes_next = self.local_transformer_sample_maskgit(
4472+
dec_output=dec_out[:, -1, :],
4473+
temperature=self.inference_parameters.temperature,
4474+
topk=self.inference_parameters.topk,
4475+
unfinished_items=unfinished_items,
4476+
finished_items=finished_items,
4477+
use_cfg=use_cfg,
4478+
cfg_scale=cfg_scale,
4479+
n_steps=maskgit_n_steps,
4480+
noise_scale=maskgit_noise_scale,
4481+
fixed_schedule=maskgit_fixed_schedule,
4482+
dynamic_cfg_scale=maskgit_dynamic_cfg_scale,
4483+
sampling_type=maskgit_sampling_type,
4484+
)
4485+
else:
4486+
raise ValueError(
4487+
f"Local transformer inference requested but local transformer type is {self.local_transformer_type}"
4488+
)
4489+
else:
4490+
audio_codes_next = self.sample_codes_from_logits(
4491+
all_code_logits_t,
4492+
temperature=self.inference_parameters.temperature,
4493+
topk=self.inference_parameters.topk,
4494+
unfinished_items=unfinished_items,
4495+
finished_items=finished_items,
4496+
) # (B, num_codebooks)
44494497
all_codes_next_argmax = self.sample_codes_from_logits(
44504498
all_code_logits_t,
44514499
temperature=self.longform_config.argmax_temperature,

nemo/collections/tts/modules/magpietts_inference/inference.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,11 @@ def _run_longform_inference(
649649
end_of_text=is_end_of_text,
650650
beginning_of_text=beginning_of_text,
651651
use_cfg=self.config.use_cfg,
652+
use_local_transformer_for_inference=self.config.use_local_transformer,
653+
maskgit_n_steps=self.config.maskgit_n_steps,
654+
maskgit_noise_scale=self.config.maskgit_noise_scale,
655+
maskgit_fixed_schedule=self.config.maskgit_fixed_schedule,
656+
maskgit_sampling_type=self.config.maskgit_sampling_type,
652657
)
653658

654659
# Unpack output - generate_long_form_speech returns InferBatchOutput

0 commit comments

Comments (0)