Merged

21 commits
44c4644
Longform TTS using magpietts
subhankar-ghosh Dec 19, 2025
335007f
Apply isort and black reformatting
subhankar-ghosh Dec 19, 2025
e3584d8
Using LongformDecoderState in LongForm Magpietts
subhankar-ghosh Dec 19, 2025
eff4703
Using LongformDecoderState in LongForm Magpietts
subhankar-ghosh Dec 19, 2025
73ddc3d
Apply isort and black reformatting
subhankar-ghosh Dec 19, 2025
d2714a7
Potential fix for code scanning alert no. 16815: Non-standard excepti…
subhankar-ghosh Dec 19, 2025
f7fce00
Update nemo/collections/tts/data/text_to_speech_dataset.py
subhankar-ghosh Dec 19, 2025
132434b
Update nemo/collections/tts/models/magpietts.py
subhankar-ghosh Dec 19, 2025
8a4e144
Update nemo/collections/tts/models/magpietts.py
subhankar-ghosh Dec 19, 2025
0f41d39
Update nemo/collections/tts/data/text_to_speech_dataset.py
subhankar-ghosh Dec 20, 2025
312f127
Combining Inference runner, using data classes in longform
subhankar-ghosh Dec 20, 2025
bd187b5
Merge branch 'magpietts_os_longform' of github.com:NVIDIA-NeMo/NeMo i…
subhankar-ghosh Dec 20, 2025
577eff9
Apply isort and black reformatting
subhankar-ghosh Dec 20, 2025
a8610d8
make LongFormTTSInferenceDataset a subclass of MagpieTTSDataset
subhankar-ghosh Dec 20, 2025
9e139c4
make LongFormTTSInferenceDataset a subclass of MagpieTTSDataset
subhankar-ghosh Dec 20, 2025
fe64ea8
Apply isort and black reformatting
subhankar-ghosh Dec 20, 2025
e6a248d
Update nemo/collections/tts/models/magpietts.py
subhankar-ghosh Dec 20, 2025
294439e
Update nemo/collections/tts/models/magpietts.py
subhankar-ghosh Dec 20, 2025
1e9004b
Remove redundant code from inference.py
subhankar-ghosh Dec 20, 2025
99ffb47
Merge branch 'magpietts_os_longform' of github.com:NVIDIA-NeMo/NeMo i…
subhankar-ghosh Dec 20, 2025
7faab9e
Adding longform test cases.
subhankar-ghosh Dec 22, 2025
2 changes: 2 additions & 0 deletions .github/workflows/cicd-main-speech.yml
@@ -193,6 +193,8 @@ jobs:
script: L2_TTS_InferEvaluate_Magpietts_ZeroShot
- runner: self-hosted-azure
script: L2_TTS_InferEvaluate_Magpietts_SeenSpeakers
- runner: self-hosted-azure
script: L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot
needs: [unit-tests]
runs-on: ${{ matrix.runner }}
name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }}
38 changes: 37 additions & 1 deletion examples/tts/magpietts_inference.py
@@ -132,6 +132,10 @@ def run_inference_and_evaluation(
) -> Tuple[Optional[float], Optional[float]]:
"""Run inference and optional evaluation on specified datasets.

Longform inference is automatically detected based on text characteristics
when longform_mode="auto" (default). Use longform_mode="always" or "never"
for explicit control.

Args:
model_config: Configuration for loading the model.
inference_config: Configuration for inference.
@@ -166,7 +170,8 @@
# Build full checkpoint identifier
full_checkpoint_name = f"{checkpoint_name}_{inference_config.build_identifier()}_SV_{eval_config.sv_model}"

# Create inference runner
# Create inference runner (auto-detects longform based on config.longform_mode)
logging.info(f"Longform mode: {inference_config.longform_mode}")
runner = MagpieInferenceRunner(model, inference_config)

# Tracking metrics across datasets
@@ -396,6 +401,25 @@ def create_argument_parser() -> argparse.ArgumentParser:
infer_group.add_argument('--batch_size', type=int, default=32)
infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance')
infer_group.add_argument('--cfg_scale', type=float, default=2.5)
infer_group.add_argument(
'--longform_mode',
type=str,
default='auto',
choices=['auto', 'always', 'never'],
help='Longform inference mode: auto (detect from text), always, or never',
)
infer_group.add_argument(
'--longform_word_threshold',
type=int,
default=40,
help='Word threshold for auto-detection of longform text',
)
infer_group.add_argument(
'--longform_max_decoder_steps',
type=int,
default=50000,
help='Maximum decoder steps for longform inference',
)
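Taken together, these three flags only imply the dispatch logic; a minimal sketch of how --longform_mode and --longform_word_threshold could combine (a hypothetical helper, not code from this PR):

def choose_inference_path(texts, longform_mode, longform_word_threshold):
    """Hypothetical dispatch mirroring the flag semantics above."""
    if longform_mode == 'always':
        return 'longform'
    if longform_mode == 'never':
        return 'standard'
    # 'auto': any text above the word threshold triggers the longform path
    is_long = any(len(t.split()) > longform_word_threshold for t in texts)
    return 'longform' if is_long else 'standard'

# e.g. choose_inference_path(["A short sentence."], 'auto', 40) -> 'standard'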

# Attention prior arguments
prior_group = parser.add_argument_group('Attention Prior')
@@ -495,12 +519,22 @@ def main():
parser.error("You must provide either:\n" " 1. --hparams_files and --checkpoint_files\n" " 2. --nemo_files")

# Build configurations
# Use higher max_decoder_steps for longform inference when mode is 'always'
if args.longform_mode == 'always':
max_decoder_steps = args.longform_max_decoder_steps
elif args.longform_mode == 'auto':
# Use longform steps if any text appears long (will be checked in runner)
max_decoder_steps = args.longform_max_decoder_steps
else: # 'never'
max_decoder_steps = 440

inference_config = InferenceConfig(
temperature=args.temperature,
topk=args.topk,
batch_size=args.batch_size,
use_cfg=args.use_cfg,
cfg_scale=args.cfg_scale,
max_decoder_steps=max_decoder_steps,
apply_attention_prior=args.apply_attention_prior,
attention_prior_epsilon=args.attention_prior_epsilon,
attention_prior_lookahead_window=args.attention_prior_lookahead_window,
@@ -509,6 +543,8 @@
start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps,
use_local_transformer=args.use_local_transformer,
maskgit_n_steps=args.maskgit_n_steps,
longform_mode=args.longform_mode,
longform_word_threshold=args.longform_word_threshold,
maskgit_noise_scale=args.maskgit_noise_scale,
maskgit_fixed_schedule=args.maskgit_fixed_schedule,
maskgit_sampling_type=args.maskgit_sampling_type,
158 changes: 158 additions & 0 deletions nemo/collections/tts/data/text_to_speech_dataset.py
@@ -30,6 +30,7 @@
from nemo.collections.tts.parts.utils.tts_dataset_utils import (
_read_audio,
beta_binomial_prior_distribution,
chunk_and_tokenize_text_by_sentence,
filter_dataset_by_duration,
get_weighted_sampler,
load_audio,
@@ -629,6 +630,10 @@ def __getitem__(self, index):
if "reward" in data.manifest_entry:
example["reward"] = data.manifest_entry["reward"]

# Get speaker index if available (for models with baked context embeddings)
if 'speaker_index' in data.manifest_entry:
example['speaker_index'] = data.manifest_entry['speaker_index']

return example

def collate_fn(self, batch: List[dict]):
@@ -653,6 +658,7 @@ def collate_fn(self, batch: List[dict]):
reward_list = []
raw_text_list = []
language_list = []
speaker_indices_list = []
for example in batch:
dataset_name_list.append(example["dataset_name"])
audio_filepath_list.append(example["audio_filepath"])
@@ -690,6 +696,9 @@
if 'reward' in example:
reward_list.append(example['reward'])

if 'speaker_index' in example:
speaker_indices_list.append(example['speaker_index'])

if self.include_align_prior:
prior_list.append(example["align_prior"])

@@ -762,6 +771,9 @@
if len(reward_list) > 0:
batch_dict['rewards'] = torch.FloatTensor(reward_list)

if len(speaker_indices_list) > 0:
batch_dict['speaker_indices'] = torch.tensor(speaker_indices_list, dtype=torch.int64)

# Assert only ONE of context_audio or context_audio_codes in the batch
assert ('audio' in batch_dict) ^ ('audio_codes' in batch_dict)

@@ -798,3 +810,149 @@ def collate_fn(self, batch: List[dict]):
chosen_collated = super().collate_fn(chosen_batch)
rejected_collated = super().collate_fn(rejected_batch)
return {"chosen": chosen_collated, "rejected": rejected_collated}


class LongFormTTSInferenceDataset(MagpieTTSDataset):
"""
Dataset for longform TTS inference with sentence-level text chunking.

Inherits from MagpieTTSDataset to reuse context audio loading, text conditioning,
and other preprocessing logic. Adds sentence-level text chunking on top.

Args:
dataset_meta: Dataset metadata dictionary (same format as MagpieTTSDataset).
sample_rate: Audio sample rate.
tokenizer_name: Name of the tokenizer to use for sentence chunking.
codec_model_samples_per_frame: Samples per codec frame.
eos_id: End-of-sequence token ID.
audio_bos_id: Audio BOS token ID (for target audio).
audio_eos_id: Audio EOS token ID (for target audio).
context_audio_bos_id: Context audio BOS token ID.
context_audio_eos_id: Context audio EOS token ID.
num_audio_codebooks: Number of audio codebooks.
context_duration_min: Minimum context duration in seconds.
context_duration_max: Maximum context duration in seconds.
use_text_conditioning_tokenizer: Whether model uses text conditioning encoder.
text_conditioning_tokenizer_name: Name of text conditioning tokenizer.
pad_context_text_to_max_duration: Whether to pad context text.
load_16khz_audio: Whether to load 16kHz audio for SV model.
"""

def __init__(
self,
dataset_meta: Dict[str, Any],
sample_rate: int,
tokenizer_name: str,
codec_model_samples_per_frame: int,
eos_id: int,
audio_bos_id: int,
audio_eos_id: int,
context_audio_bos_id: int,
context_audio_eos_id: int,
num_audio_codebooks: int,
context_duration_min: float = 3.0,
context_duration_max: float = 10.0,
use_text_conditioning_tokenizer: bool = False,
text_conditioning_tokenizer_name: str = None,
pad_context_text_to_max_duration: bool = False,
load_16khz_audio: bool = False,
**kwargs,
):
# Initialize parent - handles manifest reading and context audio loading
super().__init__(
dataset_meta=dataset_meta,
sample_rate=sample_rate,
codec_model_samples_per_frame=codec_model_samples_per_frame,
eos_id=eos_id,
audio_bos_id=audio_bos_id,
audio_eos_id=audio_eos_id,
context_audio_bos_id=context_audio_bos_id,
context_audio_eos_id=context_audio_eos_id,
num_audio_codebooks=num_audio_codebooks,
context_duration_min=context_duration_min,
context_duration_max=context_duration_max,
use_text_conditioning_tokenizer=use_text_conditioning_tokenizer,
text_conditioning_tokenizer_name=text_conditioning_tokenizer_name,
pad_context_text_to_max_duration=pad_context_text_to_max_duration,
load_16khz_audio=load_16khz_audio,
load_cached_codes_if_available=True, # Prefer codes for inference
dataset_type='test',
**kwargs,
)
self.tokenizer_name = tokenizer_name

def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
Add sentence chunking on top of parent's __getitem__.

Returns:
Dictionary containing all parent fields plus:
- idx: Sample index
- chunked_tokens: List of tokenized text chunks (per sentence)
- chunked_tokens_len: List of token lengths
- entry: Original manifest entry
"""
# Get text for sentence chunking
data = self.data_samples[idx]
text = data.text # entry.get("normalized_text", entry.get("text", ""))

# Sentence chunking (longform-specific)
chunked_tokens, chunked_tokens_len, _ = chunk_and_tokenize_text_by_sentence(
text,
self.tokenizer_name,
self.text_tokenizer,
self.eos_id,
)

# Handle empty text edge case
if not chunked_tokens:
chunked_tokens = [torch.tensor([self.eos_id], dtype=torch.int32)]
chunked_tokens_len = [1]

# Call parent to get ALL the context audio, text conditioning, etc.
example = super().__getitem__(idx)

# Add longform-specific fields
example['idx'] = idx
example['chunked_tokens'] = chunked_tokens
example['chunked_tokens_len'] = chunked_tokens_len

return example
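chunk_and_tokenize_text_by_sentence is imported from tts_dataset_utils, but its body is not part of this diff. A rough sketch of the behavior the call site implies, assuming a naive regex sentence split and an encode-style tokenizer API (the real helper may differ):

import re
from typing import List, Tuple

import torch


def chunk_and_tokenize_by_sentence_sketch(
    text: str, tokenizer, eos_id: int
) -> Tuple[List[torch.Tensor], List[int], List[str]]:
    # Split on sentence-final punctuation followed by whitespace.
    sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
    chunked_tokens, chunked_tokens_len = [], []
    for sentence in sentences:
        # Assumption: tokenizer.encode returns a list of int token IDs.
        token_ids = list(tokenizer.encode(sentence)) + [eos_id]
        chunked_tokens.append(torch.tensor(token_ids, dtype=torch.int32))
        chunked_tokens_len.append(len(token_ids))
    return chunked_tokens, chunked_tokens_len, sentences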

def collate_fn(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Collate function for batching longform samples.

Calls parent's collate_fn to handle context audio, text conditioning, etc.,
then adds longform-specific fields (chunked_tokens).
"""
# Call parent's collate_fn to handle all standard fields
batch_dict = super().collate_fn(batch)

# Add longform-specific fields
indices = []
chunked_tokens_list = []
chunked_tokens_lens_list = []

# Find max number of chunks across batch
max_num_chunks = max(len(sample['chunked_tokens']) for sample in batch)

for sample in batch:
indices.append(sample['idx'])

# Pad chunked tokens to max_num_chunks with single EOS token
num_padding = max_num_chunks - len(sample['chunked_tokens'])
padded_tokens = sample['chunked_tokens'] + [
torch.tensor([self.eos_id], dtype=torch.int32) for _ in range(num_padding)
]
padded_lens = sample['chunked_tokens_len'] + [1] * num_padding

chunked_tokens_list.append(padded_tokens)
chunked_tokens_lens_list.append(padded_lens)

# Add longform-specific fields to batch_dict
batch_dict['idx'] = indices
batch_dict['chunked_tokens'] = chunked_tokens_list
batch_dict['chunked_tokens_lens'] = chunked_tokens_lens_list

return batch_dict
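A hypothetical usage sketch wiring LongFormTTSInferenceDataset into a DataLoader with its collate_fn; the paths, token IDs, and dataset_meta contents below are placeholders, not values from this PR:

from torch.utils.data import DataLoader

dataset = LongFormTTSInferenceDataset(
    dataset_meta={'my_set': {'manifest_path': 'manifest.json', 'audio_dir': 'audio/'}},
    sample_rate=22050,
    tokenizer_name='english_phoneme',
    codec_model_samples_per_frame=1024,
    eos_id=0,
    audio_bos_id=1,
    audio_eos_id=2,
    context_audio_bos_id=3,
    context_audio_eos_id=4,
    num_audio_codebooks=8,
)
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collate_fn)
for batch in loader:
    # batch['chunked_tokens'] is a per-sample list of per-sentence token
    # tensors, padded with single-EOS chunks up to the batch-wide max.
    break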