diff --git a/.github/workflows/cicd-main-speech.yml b/.github/workflows/cicd-main-speech.yml index 53c79af0b5c8..ef85c2115a8f 100644 --- a/.github/workflows/cicd-main-speech.yml +++ b/.github/workflows/cicd-main-speech.yml @@ -193,6 +193,8 @@ jobs: script: L2_TTS_InferEvaluate_Magpietts_ZeroShot - runner: self-hosted-azure script: L2_TTS_InferEvaluate_Magpietts_SeenSpeakers + - runner: self-hosted-azure + script: L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot needs: [unit-tests] runs-on: ${{ matrix.runner }} name: ${{ matrix.is-optional && 'PLEASEFIXME_' || '' }}${{ matrix.script }} diff --git a/examples/tts/magpietts_inference.py b/examples/tts/magpietts_inference.py index f581555b9739..fe8e4068923d 100644 --- a/examples/tts/magpietts_inference.py +++ b/examples/tts/magpietts_inference.py @@ -132,6 +132,10 @@ def run_inference_and_evaluation( ) -> Tuple[Optional[float], Optional[float]]: """Run inference and optional evaluation on specified datasets. + Longform inference is automatically detected based on text characteristics + when longform_mode="auto" (default). Use longform_mode="always" or "never" + for explicit control. + Args: model_config: Configuration for loading the model. inference_config: Configuration for inference. @@ -166,7 +170,8 @@ def run_inference_and_evaluation( # Build full checkpoint identifier full_checkpoint_name = f"{checkpoint_name}_{inference_config.build_identifier()}_SV_{eval_config.sv_model}" - # Create inference runner + # Create inference runner (auto-detects longform based on config.longform_mode) + logging.info(f"Longform mode: {inference_config.longform_mode}") runner = MagpieInferenceRunner(model, inference_config) # Tracking metrics across datasets @@ -396,6 +401,25 @@ def create_argument_parser() -> argparse.ArgumentParser: infer_group.add_argument('--batch_size', type=int, default=32) infer_group.add_argument('--use_cfg', action='store_true', help='Enable classifier-free guidance') infer_group.add_argument('--cfg_scale', type=float, default=2.5) + infer_group.add_argument( + '--longform_mode', + type=str, + default='auto', + choices=['auto', 'always', 'never'], + help='Longform inference mode: auto (detect from text), always, or never', + ) + infer_group.add_argument( + '--longform_word_threshold', + type=int, + default=40, + help='Word threshold for auto-detection of longform text', + ) + infer_group.add_argument( + '--longform_max_decoder_steps', + type=int, + default=50000, + help='Maximum decoder steps for longform inference', + ) # Attention prior arguments prior_group = parser.add_argument_group('Attention Prior') @@ -495,12 +519,22 @@ def main(): parser.error("You must provide either:\n" " 1. --hparams_files and --checkpoint_files\n" " 2. 
--nemo_files") # Build configurations + # Use higher max_decoder_steps for longform inference when mode is 'always' + if args.longform_mode == 'always': + max_decoder_steps = args.longform_max_decoder_steps + elif args.longform_mode == 'auto': + # Use longform steps if any text appears long (will be checked in runner) + max_decoder_steps = args.longform_max_decoder_steps + else: # 'never' + max_decoder_steps = 440 + inference_config = InferenceConfig( temperature=args.temperature, topk=args.topk, batch_size=args.batch_size, use_cfg=args.use_cfg, cfg_scale=args.cfg_scale, + max_decoder_steps=max_decoder_steps, apply_attention_prior=args.apply_attention_prior, attention_prior_epsilon=args.attention_prior_epsilon, attention_prior_lookahead_window=args.attention_prior_lookahead_window, @@ -509,6 +543,8 @@ def main(): start_prior_after_n_audio_steps=args.start_prior_after_n_audio_steps, use_local_transformer=args.use_local_transformer, maskgit_n_steps=args.maskgit_n_steps, + longform_mode=args.longform_mode, + longform_word_threshold=args.longform_word_threshold, maskgit_noise_scale=args.maskgit_noise_scale, maskgit_fixed_schedule=args.maskgit_fixed_schedule, maskgit_sampling_type=args.maskgit_sampling_type, diff --git a/nemo/collections/tts/data/text_to_speech_dataset.py b/nemo/collections/tts/data/text_to_speech_dataset.py index c8d615dae7da..c98e3b50fb76 100644 --- a/nemo/collections/tts/data/text_to_speech_dataset.py +++ b/nemo/collections/tts/data/text_to_speech_dataset.py @@ -30,6 +30,7 @@ from nemo.collections.tts.parts.utils.tts_dataset_utils import ( _read_audio, beta_binomial_prior_distribution, + chunk_and_tokenize_text_by_sentence, filter_dataset_by_duration, get_weighted_sampler, load_audio, @@ -629,6 +630,10 @@ def __getitem__(self, index): if "reward" in data.manifest_entry: example["reward"] = data.manifest_entry["reward"] + # Get speaker index if available (for models with baked context embeddings) + if 'speaker_index' in data.manifest_entry: + example['speaker_index'] = data.manifest_entry['speaker_index'] + return example def collate_fn(self, batch: List[dict]): @@ -653,6 +658,7 @@ def collate_fn(self, batch: List[dict]): reward_list = [] raw_text_list = [] language_list = [] + speaker_indices_list = [] for example in batch: dataset_name_list.append(example["dataset_name"]) audio_filepath_list.append(example["audio_filepath"]) @@ -690,6 +696,9 @@ def collate_fn(self, batch: List[dict]): if 'reward' in example: reward_list.append(example['reward']) + if 'speaker_index' in example: + speaker_indices_list.append(example['speaker_index']) + if self.include_align_prior: prior_list.append(example["align_prior"]) @@ -762,6 +771,9 @@ def collate_fn(self, batch: List[dict]): if len(reward_list) > 0: batch_dict['rewards'] = torch.FloatTensor(reward_list) + if len(speaker_indices_list) > 0: + batch_dict['speaker_indices'] = torch.tensor(speaker_indices_list, dtype=torch.int64) + # Assert only ONE of context_audio or context_audio_codes in the batch assert ('audio' in batch_dict) ^ ('audio_codes' in batch_dict) @@ -798,3 +810,149 @@ def collate_fn(self, batch: List[dict]): chosen_collated = super().collate_fn(chosen_batch) rejected_collated = super().collate_fn(rejected_batch) return {"chosen": chosen_collated, "rejected": rejected_collated} + + +class LongFormTTSInferenceDataset(MagpieTTSDataset): + """ + Dataset for longform TTS inference with sentence-level text chunking. 
+ + Inherits from MagpieTTSDataset to reuse context audio loading, text conditioning, + and other preprocessing logic. Adds sentence-level text chunking on top. + + Args: + dataset_meta: Dataset metadata dictionary (same format as MagpieTTSDataset). + sample_rate: Audio sample rate. + tokenizer_name: Name of the tokenizer to use for sentence chunking. + codec_model_samples_per_frame: Samples per codec frame. + eos_id: End-of-sequence token ID. + audio_bos_id: Audio BOS token ID (for target audio). + audio_eos_id: Audio EOS token ID (for target audio). + context_audio_bos_id: Context audio BOS token ID. + context_audio_eos_id: Context audio EOS token ID. + num_audio_codebooks: Number of audio codebooks. + context_duration_min: Minimum context duration in seconds. + context_duration_max: Maximum context duration in seconds. + use_text_conditioning_tokenizer: Whether model uses text conditioning encoder. + text_conditioning_tokenizer_name: Name of text conditioning tokenizer. + pad_context_text_to_max_duration: Whether to pad context text. + load_16khz_audio: Whether to load 16kHz audio for SV model. + """ + + def __init__( + self, + dataset_meta: Dict[str, Any], + sample_rate: int, + tokenizer_name: str, + codec_model_samples_per_frame: int, + eos_id: int, + audio_bos_id: int, + audio_eos_id: int, + context_audio_bos_id: int, + context_audio_eos_id: int, + num_audio_codebooks: int, + context_duration_min: float = 3.0, + context_duration_max: float = 10.0, + use_text_conditioning_tokenizer: bool = False, + text_conditioning_tokenizer_name: str = None, + pad_context_text_to_max_duration: bool = False, + load_16khz_audio: bool = False, + **kwargs, + ): + # Initialize parent - handles manifest reading and context audio loading + super().__init__( + dataset_meta=dataset_meta, + sample_rate=sample_rate, + codec_model_samples_per_frame=codec_model_samples_per_frame, + eos_id=eos_id, + audio_bos_id=audio_bos_id, + audio_eos_id=audio_eos_id, + context_audio_bos_id=context_audio_bos_id, + context_audio_eos_id=context_audio_eos_id, + num_audio_codebooks=num_audio_codebooks, + context_duration_min=context_duration_min, + context_duration_max=context_duration_max, + use_text_conditioning_tokenizer=use_text_conditioning_tokenizer, + text_conditioning_tokenizer_name=text_conditioning_tokenizer_name, + pad_context_text_to_max_duration=pad_context_text_to_max_duration, + load_16khz_audio=load_16khz_audio, + load_cached_codes_if_available=True, # Prefer codes for inference + dataset_type='test', + **kwargs, + ) + self.tokenizer_name = tokenizer_name + + def __getitem__(self, idx: int) -> Dict[str, Any]: + """ + Add sentence chunking on top of parent's __getitem__. + + Returns: + Dictionary containing all parent fields plus: + - idx: Sample index + - chunked_tokens: List of tokenized text chunks (per sentence) + - chunked_tokens_len: List of token lengths + - entry: Original manifest entry + """ + # Get text for sentence chunking + data = self.data_samples[idx] + text = data.text # entry.get("normalized_text", entry.get("text", "")) + + # Sentence chunking (longform-specific) + chunked_tokens, chunked_tokens_len, _ = chunk_and_tokenize_text_by_sentence( + text, + self.tokenizer_name, + self.text_tokenizer, + self.eos_id, + ) + + # Handle empty text edge case + if not chunked_tokens: + chunked_tokens = [torch.tensor([self.eos_id], dtype=torch.int32)] + chunked_tokens_len = [1] + + # Call parent to get ALL the context audio, text conditioning, etc. 
+ example = super().__getitem__(idx) + + # Add longform-specific fields + example['idx'] = idx + example['chunked_tokens'] = chunked_tokens + example['chunked_tokens_len'] = chunked_tokens_len + + return example + + def collate_fn(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: + """ + Collate function for batching longform samples. + + Calls parent's collate_fn to handle context audio, text conditioning, etc., + then adds longform-specific fields (chunked_tokens). + """ + # Call parent's collate_fn to handle all standard fields + batch_dict = super().collate_fn(batch) + + # Add longform-specific fields + indices = [] + chunked_tokens_list = [] + chunked_tokens_lens_list = [] + + # Find max number of chunks across batch + max_num_chunks = max(len(sample['chunked_tokens']) for sample in batch) + + for sample in batch: + indices.append(sample['idx']) + + # Pad chunked tokens to max_num_chunks with single EOS token + num_padding = max_num_chunks - len(sample['chunked_tokens']) + padded_tokens = sample['chunked_tokens'] + [ + torch.tensor([self.eos_id], dtype=torch.int32) for _ in range(num_padding) + ] + padded_lens = sample['chunked_tokens_len'] + [1] * num_padding + + chunked_tokens_list.append(padded_tokens) + chunked_tokens_lens_list.append(padded_lens) + + # Add longform-specific fields to batch_dict + batch_dict['idx'] = indices + batch_dict['chunked_tokens'] = chunked_tokens_list + batch_dict['chunked_tokens_lens'] = chunked_tokens_lens_list + + return batch_dict diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 31b872368a3f..5e62a2b48dff 100644 --- a/nemo/collections/tts/models/magpietts.py +++ b/nemo/collections/tts/models/magpietts.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import json import os import random import time -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -49,6 +50,7 @@ get_mask_from_lengths, plot_alignment_to_numpy, ) +from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging @@ -130,6 +132,106 @@ class ContextTensorsOutput: beta_binomial_attn_prior: Optional[torch.Tensor] = None +@dataclass +class LongformDecoderState: + """Tracks state during longform speech generation. + + This dataclass encapsulates all the mutable state variables used in the + autoregressive decoding loop of generate_long_form_speech, reducing parameter + passing and improving code organization. + + Attributes: + audio_codes_input: Current audio codes buffer. Shape: (B, num_codebooks, T). + audio_codes_lens: Length of each audio sequence. Shape: (B,). + audio_codes_mask: Mask for audio codes. Shape: (B, T). + attended_timestep_counter: List of dicts tracking attention counts per timestep. + all_predictions: List of predicted audio code tensors. + chunk_end_dict: Maps batch indices to their chunk end timesteps. + unfinished_texts: Maps batch indices to whether text is still being processed. + finished_texts_counter: Maps batch indices to counts of timesteps near text end. + attn_prior: Current attention prior tensor. Shape: (B, 1, T_text). 
+ """ + + audio_codes_input: torch.Tensor + audio_codes_lens: torch.Tensor + audio_codes_mask: torch.Tensor + attended_timestep_counter: List[Dict[int, int]] + all_predictions: List[torch.Tensor] + chunk_end_dict: Dict[int, int] + unfinished_texts: Dict[int, bool] + finished_texts_counter: Dict[int, int] + attn_prior: Optional[torch.Tensor] = None + + +@dataclass +class LongformConfig: + """Immutable configuration for longform inference tuning parameters. + + These parameters control the behavior of longform (multi-chunk) speech generation. + Initialized once in MagpieTTSModel.__init__ and accessed via self.longform_config. + + Attributes: + history_len_heuristic: Maximum history tokens to retain across chunks. + prior_weights_init: Attention prior weights for chunk initialization. + prior_weights: Attention prior weights during generation (history, current, +1, +2, +3, +4). + finished_limit_with_eot: Steps after text end before allowing EOS. + finished_limit_without_eot: Steps after chunk end before allowing EOS. + forceful_chunk_end_threshold: Threshold for forceful chunk termination. + argmax_temperature: Temperature for argmax sampling in EOS detection. + short_sentence_threshold: Sentences shorter than this skip attention prior. + attention_sink_threshold: Times attended before position is considered a sink. + near_end_threshold: Positions from text end to consider "near end". + """ + + history_len_heuristic: int = 20 + prior_weights_init: Tuple[float, ...] = (0.5, 1.0, 0.8, 0.2, 0.2) + prior_weights: Tuple[float, ...] = (0.2, 1.0, 0.6, 0.4, 0.2, 0.2) + finished_limit_with_eot: int = 5 + finished_limit_without_eot: int = 1 + forceful_chunk_end_threshold: int = 3 + argmax_temperature: float = 0.01 + short_sentence_threshold: int = 35 + attention_sink_threshold: int = 10 + near_end_threshold: int = 3 + + +@dataclass +class LongformChunkState: + """Mutable state persisting across chunks during longform generation. + + Created by the inference runner via model.create_longform_chunk_state(), + passed to generate_long_form_speech(), and updated in-place across chunk iterations. + + Attributes: + batch_size: Number of items in the batch. + history_text: Text tokens from previous chunks. Shape: (B, T). + history_text_lens: Lengths of history text per batch item. Shape: (B,). + history_context_tensor: Encoder output from previous chunks. Shape: (B, T, E). + end_indices: Maps batch indices to overall timestep where they ended. + overall_idx: Global timestep counter across all chunks. + left_offset: Sliding window offset per batch item for attention tracking. + previous_attn_len: Attention lengths from previous chunk per batch item. + last_attended_timesteps: Tracking of attended positions across decoding. 
+ """ + + batch_size: int + history_text: Optional[torch.Tensor] = None + history_text_lens: Optional[torch.Tensor] = None + history_context_tensor: Optional[torch.Tensor] = None + end_indices: Dict[int, int] = field(default_factory=dict) + overall_idx: int = 0 + left_offset: List[int] = field(default_factory=list) + previous_attn_len: List[int] = field(default_factory=list) + last_attended_timesteps: List[List[int]] = field(default_factory=list) + + def __post_init__(self): + """Initialize batch-sized lists if not provided.""" + if not self.left_offset: + self.left_offset = [0] * self.batch_size + if not self.last_attended_timesteps: + self.last_attended_timesteps = [[1] * self.batch_size] + + def worker_init_fn(worker_id): # For mp.set_start_method("spawn", force=True) # The dataset class should be picklable, so we initialize non-picklable objects here @@ -451,6 +553,9 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): # Class-level cache for text normalizers. Used during inference. self._text_normalizers: Dict[str, Any] = {} + # Longform inference configuration (immutable tuning parameters) + self.longform_config = LongformConfig() + def _register_tokenizer_artifacts(self, cfg: DictConfig) -> None: """ Register tokenizer file artifacts (phoneme_dict, heteronyms, etc.) for .nemo packaging. @@ -2547,22 +2652,54 @@ def get_most_attended_text_timestep( lookahead_window_size, attended_timestep_counter, batch_size, + left_offset=[], ): """ Returns the most attended timestep for each batch item + + This method identifies which text token is most attended to within a lookahead window, starting from + the last attended timestep. It includes logic to detect attention sinks (tokens attended to excessively) + and move past them. The method also tracks how many times each timestep has been attended. + + Args: + alignment_attention_scores (torch.Tensor): Attention scores between audio and text tokens. + Shape: (batch_size, text_length). + last_attended_timesteps (list): List containing the last attended timestep for each batch item. + The last element [-1] should be a list/tensor of length batch_size. + text_lens (torch.Tensor): Length of text sequence for each batch item. Shape: (batch_size,). + lookahead_window_size (int): Size of the forward-looking window to search for the next attended + timestep. Determines how far ahead from the last attended timestep to look. + attended_timestep_counter (list): List of dictionaries (one per batch item) tracking how many + times each timestep has been attended. Used to detect attention sinks. + batch_size (int): Number of items in the batch. + left_offset (list, optional): List of offsets to adjust timestep indices for each batch item, + used in longform inference when text is provided in chunks. Relevant only in longform + generation. + + Returns: + tuple: A tuple containing: + - text_time_step_attended (list): List of integers, one per batch item, indicating the most + attended text timestep for that item. + - attended_timestep_counter (list): Updated counter tracking attendance frequency for each + timestep across all batch items. """ + if len(left_offset) == 0: + left_offset = [0 for _ in range(batch_size)] text_time_step_attended = [] for bidx in range(batch_size): last_attended_timestep = last_attended_timesteps[-1][bidx] if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8: # This is probably an attention sink! 
Move to the next timestep last_attended_timestep += 1 + last_attended_timestep_in_this_window = last_attended_timestep - left_offset[bidx] window_size = lookahead_window_size - window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps - item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end] + window_end = min( + last_attended_timestep_in_this_window + window_size, text_lens[bidx] - 3 + ) # Ignore the last 3 timesteps + item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep_in_this_window:window_end] if item_attention_scores.size(0) == 0: # This means the sentence has ended - attended_timestep = text_lens[bidx].item() - 1 + attended_timestep = text_lens[bidx].item() - 1 + left_offset[bidx] else: attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep text_time_step_attended.append(attended_timestep) @@ -3369,7 +3506,7 @@ def do_tts( "en": ["english_phoneme", "english"], "de": ["german_phoneme", "german"], "es": ["spanish_phoneme", "spanish"], - "fr": ["french_phoneme", "french"], + "fr": ["french_chartokenizer", "french"], "it": ["italian_phoneme", "italian"], "vi": ["vietnamese_phoneme", "vietnamese"], "zh": ["mandarin_phoneme", "mandarin", "chinese"], @@ -3419,3 +3556,820 @@ def do_tts( @classmethod def list_available_models(cls) -> List[PretrainedModelInfo]: return [] + + def create_longform_chunk_state(self, batch_size: int) -> LongformChunkState: + """Create fresh state for longform inference over a batch. + + This method creates a LongformChunkState dataclass instance that tracks + mutable state across multiple calls to generate_long_form_speech() when + processing long text in chunks. + + The returned state object should be: + 1. Created once per batch by the inference runner + 2. Passed to each call of generate_long_form_speech() + 3. Updated in-place during generation + + Args: + batch_size: Number of items in the batch. + + Returns: + LongformChunkState with initialized state for the batch. + + Example: + >>> chunk_state = model.create_longform_chunk_state(batch_size=4) + >>> for chunk in text_chunks: + ... output = model.generate_long_form_speech(batch, chunk_state, ...) + """ + return LongformChunkState(batch_size=batch_size) + + def _set_attention_prior_weights( + self, + attn_prior: torch.Tensor, + batch_idx: int, + attended_pos: int, + text_len: int, + eps_sq: float, + ) -> None: + """ + Set attention prior weights around the currently attended position. + + Creates a distribution that: + - Strongly suppresses positions before (attended - 1) + - Peaks at the current attended position + - Gradually decays for lookahead positions + - Suppresses far-future positions + + Args: + attn_prior: Prior tensor to modify in-place. Shape: (B, 1, T_text). + batch_idx: Index of current batch item. + attended_pos: Currently attended text position (chunk-relative). + text_len: Length of text for this batch item. + eps_sq: Squared epsilon for strong suppression. 
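+
+        With the default ``prior_weights = (0.2, 1.0, 0.6, 0.4, 0.2, 0.2)``, the resulting
+        layout around an attended position ``p`` looks like the following sketch
+        (illustrative, assuming ``p >= 2`` and enough text remaining):
+
+            index: ...p-2 | p-1 | p   | p+1 | p+2 | p+3 | p+4 | p+5...
+            prior: eps_sq | 0.2 | 1.0 | 0.6 | 0.4 | 0.2 | 0.2 | eps_sq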
+ """ + prior_weights = self.longform_config.prior_weights + + # Suppress history (before attended - 1) + history_end = max(1, attended_pos - 1) + attn_prior[batch_idx, 0, :history_end] = eps_sq + + # Set weights around attended position + attn_prior[batch_idx, 0, history_end] = prior_weights[0] # History exposure + attn_prior[batch_idx, 0, attended_pos] = prior_weights[1] # Current (peak) + + # Lookahead positions with bounds checking + for offset, weight in enumerate(prior_weights[2:], start=1): + pos = attended_pos + offset + if pos < text_len: + attn_prior[batch_idx, 0, pos] = weight + + # Suppress far future (position +5 onwards) + future_start = attended_pos + len(prior_weights) - 1 + if future_start < text_len: + attn_prior[batch_idx, 0, future_start:] = eps_sq + + def _penalize_attention_sinks( + self, + attn_prior: torch.Tensor, + batch_idx: int, + attended_timestep_counter: Dict[int, int], + left_offset: int, + eps_sq: float, + ) -> None: + """ + Penalize timesteps that have been over-attended (attention sinks). + + When a position is attended more than the threshold, suppress all + positions up to and including it to force the model to move forward. + + Args: + attn_prior: Prior tensor to modify in-place. Shape: (B, 1, T_text). + batch_idx: Index of current batch item. + attended_timestep_counter: Dict tracking attention counts per timestep. + left_offset: Chunk offset for this batch item. + eps_sq: Squared epsilon for strong suppression. + """ + threshold = self.longform_config.attention_sink_threshold + + for timestep, count in attended_timestep_counter.items(): + if timestep > left_offset and count >= threshold: + logging.debug(f"Attention sink at timestep {timestep} for batch {batch_idx}, count: {count}") + relative_pos = timestep - left_offset + attn_prior[batch_idx, 0, : relative_pos + 1] = eps_sq + + def _update_text_completion_state( + self, + batch_idx: int, + attended_pos: int, + text_len: int, + is_finished: bool, + unfinished_texts: Dict[int, bool], + finished_texts_counter: Dict[int, int], + ) -> None: + """ + Update tracking state for text completion detection. + + A text is considered "near end" when the attended position is within + `longform_near_end_threshold` positions of the text end. + + Args: + batch_idx: Index of current batch item. + attended_pos: Currently attended text position (chunk-relative). + text_len: Length of text for this batch item. + is_finished: Whether this batch item has already finished. + unfinished_texts: Dict to update in-place. + finished_texts_counter: Dict to update in-place. + """ + is_near_end = attended_pos >= text_len - self.longform_config.near_end_threshold + + # Text is unfinished if not near end AND not already marked finished + unfinished_texts[batch_idx] = not is_near_end and not is_finished + + # Start counting when near end or already finished + if is_near_end or is_finished: + finished_texts_counter.setdefault(batch_idx, 0) + + def construct_longform_inference_prior( + self, + prior_epsilon: float, + cross_attention_scores: torch.Tensor, + text_lens: torch.Tensor, + text_time_step_attended: List[int], + attended_timestep_counter: List[Dict[int, int]], + unfinished_texts: Dict[int, bool], + finished_texts_counter: Dict[int, int], + end_indices: Dict[int, int], + chunk_end_dict: Dict[int, int], + batch_size: int, + left_offset: Optional[List[int]] = None, + ) -> Tuple[torch.Tensor, Dict[int, bool], Dict[int, int]]: + """ + Construct attention prior for longform inference with chunked text. 
+ + Builds a soft attention prior that guides the decoder to attend to appropriate + text positions, preventing attention drift and encouraging monotonic progression. + + Args: + prior_epsilon: Base probability for non-targeted positions. + cross_attention_scores: Attention scores for shape/device inference. + Shape: (effective_batch, text_length). + text_lens: Length of text for each batch item. Shape: (batch_size,). + text_time_step_attended: Most attended text position (absolute) per batch item. + attended_timestep_counter: Per-batch dicts tracking attention counts per timestep. + unfinished_texts: Updated in-place. True if text still being processed. + finished_texts_counter: Updated in-place. Counts consecutive near-end timesteps. + end_indices: Batch indices that have reached end-of-sequence. + chunk_end_dict: Batch indices that have reached chunk end. + batch_size: Number of items in the batch. + left_offset: Chunk offset for each batch item. Defaults to zeros. + + Returns: + Tuple of (attention_prior, unfinished_texts, finished_texts_counter). + """ + # Initialize with safe default (avoid mutable default argument) + if left_offset is None: + left_offset = [0] * batch_size + + # Extract shape info and create prior tensor + device = cross_attention_scores.device + effective_batch = cross_attention_scores.shape[0] # 2 * batch_size if CFG else batch_size + text_dim = cross_attention_scores.shape[1] + eps_sq = prior_epsilon * prior_epsilon + + attn_prior = torch.full((effective_batch, 1, text_dim), prior_epsilon, device=device) + + # Process each batch item + for bidx in range(min(effective_batch, batch_size)): + text_len = int(text_lens[bidx]) + attended_pos = text_time_step_attended[bidx] - left_offset[bidx] + is_finished = bidx in end_indices or bidx in chunk_end_dict + + # Short sentences: uniform prior (no guidance needed) + if text_len <= self.longform_config.short_sentence_threshold: + attn_prior[bidx, 0, :] = 1.0 + else: + # Set attention weights around attended position + self._set_attention_prior_weights(attn_prior, bidx, attended_pos, text_len, eps_sq) + + # Penalize attention sinks (stuck positions) + if not is_finished: + self._penalize_attention_sinks( + attn_prior, bidx, attended_timestep_counter[bidx], left_offset[bidx], eps_sq + ) + + # Update text completion tracking + self._update_text_completion_state( + bidx, attended_pos, text_len, is_finished, unfinished_texts, finished_texts_counter + ) + + return attn_prior, unfinished_texts, finished_texts_counter + + @staticmethod + def _to_int(value: Union[int, torch.Tensor]) -> int: + """Convert tensor scalar to Python int if needed.""" + return value.item() if not isinstance(value, int) else value + + def _check_eos_and_update_state( + self, + chunk_state: LongformChunkState, + audio_codes_next: torch.Tensor, + all_codes_next_argmax: torch.Tensor, + chunk_end_dict: Dict[int, int], + finished_texts_counter: Dict[int, int], + end_of_text: List[bool], + eos_detection_method: 'EOSDetectionMethod', + current_step: int, + batch_size: int, + ) -> None: + """ + Check for EOS tokens and update chunk/end tracking state. + + Args: + chunk_state: Mutable state object tracking history across chunks. + audio_codes_next: Sampled audio codes. Shape: (B, num_codebooks). + all_codes_next_argmax: Argmax sampled codes for EOS detection. + chunk_end_dict: Maps batch indices to chunk end timesteps. + finished_texts_counter: Counter for near-end timesteps. + end_of_text: Whether text has ended for each batch item. 
+ eos_detection_method: Method for detecting end-of-sequence. + current_step: Current decoding step index. + batch_size: Number of items in the batch. + """ + for item_idx in range(batch_size): + if item_idx in chunk_state.end_indices or item_idx in chunk_end_dict: + continue + + end_frame_index = self.detect_eos( + audio_codes_next[item_idx], all_codes_next_argmax[item_idx], eos_detection_method + ) + + # End of speech detected. Update the state. + if end_frame_index != float('inf'): + if end_of_text[item_idx]: + # Speech for entire longform text has ended. Update the state. + chunk_state.end_indices[item_idx] = chunk_state.overall_idx + chunk_end_dict[item_idx] = current_step + logging.info( + f"End detected for item {item_idx} at local timestep {current_step} " + f"and overall timestep {chunk_state.overall_idx}" + ) + elif item_idx not in chunk_end_dict: + # Chunk end detected. Update the state. + chunk_end_dict[item_idx] = current_step + logging.info(f"Chunk end detected for item {item_idx} at local timestep {current_step}") + elif ( + not end_of_text[item_idx] + and finished_texts_counter.get(item_idx, -1) >= self.longform_config.forceful_chunk_end_threshold + ): + chunk_end_dict[item_idx] = current_step + logging.info(f"Forceful chunk end detected for item {item_idx} at local timestep {current_step}") + + def _should_terminate_loop( + self, + chunk_state: LongformChunkState, + chunk_end_dict: Dict[int, int], + end_of_text: List[bool], + batch_size: int, + ) -> bool: + """ + Check if all batch items have reached their end condition. + + Args: + chunk_state: Mutable state object tracking history across chunks. + chunk_end_dict: Maps batch indices to chunk end timesteps. + end_of_text: Whether text has ended for each batch item. + batch_size: Number of items in the batch. + + Returns: + True if all items have reached end, False otherwise. + """ + if len(chunk_state.end_indices) == batch_size: + logging.info("All ends reached") + return True + + completed_count = 0 + for bidx in range(batch_size): + if not end_of_text[bidx] and bidx in chunk_end_dict: + completed_count += 1 + elif end_of_text[bidx] and bidx in chunk_state.end_indices: + completed_count += 1 + + if completed_count == batch_size: + logging.info("All ends reached via chunk end") + return True + + return False + + def _run_longform_forward_with_cfg( + self, + context_tensors: Dict[str, Any], + audio_codes_embedded: torch.Tensor, + audio_codes_mask: torch.Tensor, + attn_prior: Any, + use_cfg: bool, + cfg_scale: float, + dummy_cond: Optional[torch.Tensor], + dummy_cond_mask: Optional[torch.Tensor], + dummy_additional_decoder_input: Optional[torch.Tensor], + dummy_addition_dec_mask: Optional[torch.Tensor], + batch_size: int, + ) -> Tuple[torch.Tensor, Any]: + """ + Run forward pass with optional classifier-free guidance. + + Args: + context_tensors: Context tensors from prepare_context_tensors. + audio_codes_embedded: Embedded audio codes. Shape: (B, T, E). + audio_codes_mask: Mask for audio codes. Shape: (B, T). + attn_prior: Attention prior tensor or list. + use_cfg: Whether to use classifier-free guidance. + cfg_scale: Scale factor for CFG. + dummy_cond: Dummy conditioning for unconditional branch. + dummy_cond_mask: Mask for dummy conditioning. + dummy_additional_decoder_input: Dummy additional decoder input. + dummy_addition_dec_mask: Mask for dummy additional input. + batch_size: Number of items in the batch. + + Returns: + Tuple of (logits, attention_probs). 
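+
+        When ``use_cfg`` is True, the conditional and unconditional branches are run in a
+        single batched forward pass and blended as in the implementation below (sketch):
+
+            all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits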
+ """ + if use_cfg: + # Combine conditional and unconditional inputs + if isinstance(context_tensors.cond, list): + cfg_cond = [torch.cat([c, d], dim=0) for c, d in zip(context_tensors.cond, dummy_cond)] + cfg_cond_mask = [torch.cat([c, d], dim=0) for c, d in zip(context_tensors.cond_mask, dummy_cond_mask)] + else: + cfg_cond = torch.cat([context_tensors.cond, dummy_cond], dim=0) + cfg_cond_mask = torch.cat([context_tensors.cond_mask, dummy_cond_mask], dim=0) + + cfg_audio_embedded = torch.cat([audio_codes_embedded, audio_codes_embedded], dim=0) + cfg_audio_mask = torch.cat([audio_codes_mask, audio_codes_mask], dim=0) + + if dummy_additional_decoder_input is not None: + cfg_audio_embedded[batch_size:, : dummy_additional_decoder_input.size(1)] = ( + dummy_additional_decoder_input + ) + cfg_audio_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = dummy_addition_dec_mask + + combined_logits, attn_probs, _ = self.forward( + dec_input_embedded=cfg_audio_embedded, + dec_input_mask=cfg_audio_mask, + cond=cfg_cond, + cond_mask=cfg_cond_mask, + attn_prior=attn_prior, + multi_encoder_mapping=context_tensors.multi_encoder_mapping, + ) + + cond_logits = combined_logits[:batch_size] + uncond_logits = combined_logits[batch_size:] + all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits + else: + all_code_logits, attn_probs, _ = self.forward( + dec_input_embedded=audio_codes_embedded, + dec_input_mask=audio_codes_mask, + cond=context_tensors.cond, + cond_mask=context_tensors.cond_mask, + attn_prior=attn_prior, + multi_encoder_mapping=context_tensors.multi_encoder_mapping, + ) + + return all_code_logits, attn_probs + + def _initialize_longform_attn_prior( + self, + chunk_state: LongformChunkState, + current_chunk_len: torch.Tensor, + batch_text_lens: torch.Tensor, + max_text_len: int, + batch_size: int, + use_cfg: bool, + prior_epsilon: float, + device: torch.device, + ) -> Optional[torch.Tensor]: + """ + Initialize attention prior for longform generation with left offset tracking. + + This method constructs the initial attention prior when continuing from + previous chunks, accounting for the sliding window over text history. + + Args: + chunk_state: Mutable state object tracking history across chunks. + current_chunk_len: Length of the current text chunk for each batch item. + batch_text_lens: Text lengths for each batch item. + max_text_len: Maximum text length in the batch. + batch_size: Number of items in the batch. + use_cfg: Whether classifier-free guidance is being used. + prior_epsilon: Base epsilon value for attention prior. + device: Target device for tensors. + + Returns: + Attention prior tensor or None if no history exists. 
+ """ + if len(chunk_state.previous_attn_len) == 0: + return None + + # Initialize prior tensor + cfg_multiplier = 2 if use_cfg else 1 + _attn_prior = torch.zeros(batch_size * cfg_multiplier, 1, max_text_len).to(device) + prior_epsilon + + for _idx in range(batch_size): + # Calculate left offset for sliding window + delta_in_len = self._to_int(current_chunk_len[_idx]) + len_to_delete = self._to_int(chunk_state.previous_attn_len[_idx] + delta_in_len - batch_text_lens[_idx]) + chunk_state.left_offset[_idx] = self._to_int(chunk_state.left_offset[_idx] + len_to_delete) + + # Skip if text has ended + if _idx in chunk_state.end_indices and chunk_state.end_indices[_idx] is not None: + continue + + # Set prior weights for new chunk + current_starting_point = batch_text_lens[_idx] - current_chunk_len[_idx] + prior_weights = self.longform_config.prior_weights_init + _attn_prior[_idx, :, :current_starting_point] = prior_epsilon * prior_epsilon + _attn_prior[_idx, :, current_starting_point] = prior_weights[0] + _attn_prior[_idx, :, current_starting_point + 1] = prior_weights[1] + _attn_prior[_idx, :, current_starting_point + 2] = prior_weights[2] + _attn_prior[_idx, :, current_starting_point + 3] = prior_weights[3] + _attn_prior[_idx, :, current_starting_point + 4] = prior_weights[4] + + return _attn_prior + + def _update_context_from_history( + self, + chunk_state: LongformChunkState, + context_tensors: Dict[str, Any], + current_chunk_len: torch.Tensor, + max_text_len: int, + beginning_of_text: bool, + batch_text_lens: torch.Tensor, + batch_size: int, + ) -> None: + """ + Update context tensors with cached history for longform generation. + + This method splices historical context embeddings into the current context + tensors to maintain continuity across text chunks. + + Args: + chunk_state: Mutable state object tracking history across chunks. + context_tensors: ContextTensorsOutput containing 'cond' tensor to update. + current_chunk_len: Length of the current text chunk for each batch item. + max_text_len: Maximum text length in the batch. + beginning_of_text: Whether this is the first chunk. + batch_text_lens: Text lengths for each batch item. + batch_size: Number of items in the batch. + """ + for _idx in range(batch_size): + # Skip if text has ended + if _idx in chunk_state.end_indices and chunk_state.end_indices[_idx] is not None: + continue + if not beginning_of_text: + pad_len_idx = max_text_len - batch_text_lens[_idx] + context_tensors.cond[_idx, : -current_chunk_len[_idx] - pad_len_idx] = ( + chunk_state.history_context_tensor[ + _idx, -(context_tensors.cond[_idx].shape[0] - current_chunk_len[_idx] - pad_len_idx) : + ] + ) + chunk_state.history_context_tensor = context_tensors.cond + + def _prepare_longform_text_tensors( + self, + chunk_state: LongformChunkState, + batch: Dict[str, torch.Tensor], + current_chunk_len: torch.Tensor, + beginning_of_text: bool, + device: torch.device, + ) -> Tuple[Dict[str, torch.Tensor], int]: + """ + Prepare text tensors with history for longform inference. + + This method handles the sliding window logic for text tokens, combining + historical text with new chunks and applying window size constraints. + + Args: + chunk_state: Mutable state object tracking history across chunks. + batch: Input batch containing 'text' and 'text_lens'. + current_chunk_len: Length of the current text chunk for each batch item. + beginning_of_text: Whether this is the first chunk. + device: Target device for tensors. + + Returns: + Tuple of (modified batch, max_text_len). 
+ """ + batch_size = batch["text"].size(0) + text_tensors = [] + + for _idx in range(batch_size): + # If text has ended, use minimal placeholder + if _idx in chunk_state.end_indices and chunk_state.end_indices[_idx] is not None: + batch['text_lens'][_idx] = torch.tensor(1).to(device).long() + text_tensors.append(batch['text'][_idx]) + continue + + # Combine history with current chunk + if chunk_state.history_text is not None: + current_text = torch.cat( + [ + chunk_state.history_text[_idx][: chunk_state.history_text_lens[_idx]], + batch["text"][_idx][: current_chunk_len[_idx]], + ] + ) + else: + current_text = batch["text"][_idx][: current_chunk_len[_idx]] + + # Apply sliding window + history_len = min(current_chunk_len[_idx], self.longform_config.history_len_heuristic) + true_window_size = current_chunk_len[_idx] + history_len + if not beginning_of_text: + current_text = current_text[max(0, current_text.shape[0] - true_window_size) :] + + current_text_lens = current_text.shape[0] + text_tensors.append(current_text) + batch['text_lens'][_idx] = torch.tensor(current_text_lens).to(device).long() + + # Pad and stack text tensors + max_text_len = max(batch['text_lens']).item() + batch['text'] = stack_tensors(text_tensors, max_lens=[max_text_len]) + + # Update history + chunk_state.history_text = batch['text'] + chunk_state.history_text_lens = batch['text_lens'] + + return batch, max_text_len + + def generate_long_form_speech( + self, + batch, + chunk_state: LongformChunkState, + end_of_text, + beginning_of_text, + max_decoder_steps=2000, + temperature=0.7, + topk=80, + use_cfg=True, + cfg_scale=1.0, + estimate_alignment_from_layers: Optional[List[int]] = None, + lookahead_window_size=5, + apply_attention_prior=False, + apply_prior_to_layers: Optional[List[int]] = None, + prior_epsilon=1e-5, + eos_detection_method="argmax_or_multinomial_any", + ignore_finished_sentence_tracking=False, + ): + """ + Generates speech for long-form text by progressively shifting through text tokens. + + This method processes long text inputs by generating a fixed number of audio tokens per text token, + then shifting to the next text token. It maintains a sliding window over text and audio histories, + tracking how many audio tokens were generated for each text position. + + Args: + batch (dict): Input batch containing 'text' and 'text_lens'. + chunk_state (LongformChunkState): Mutable state object tracking history across chunks. + Created via model.create_longform_chunk_state() and updated in-place. + end_of_text (List[bool]): Whether entire text has been provided for each batch item. + beginning_of_text (bool): Whether this is the first chunk. + max_decoder_steps (int): Maximum total audio tokens to generate. + temperature (float): Sampling temperature for audio code generation. + topk (int): Top-k sampling parameter. + use_cfg (bool): Whether to use classifier-free guidance. + cfg_scale (float): CFG scale factor. + estimate_alignment_from_layers (list, optional): Layers to use for alignment estimation. + lookahead_window_size (int): Forward-looking window size for attention prior. + apply_attention_prior (bool): Whether to apply attention prior. + apply_prior_to_layers (list, optional): Layers to apply prior to. + prior_epsilon (float): Base prior probability for non-targeted positions. + eos_detection_method (str): Method for detecting end-of-sequence. + ignore_finished_sentence_tracking (bool): Whether to ignore finished sentence tracking. 
+ + Returns: + InferBatchOutput: Contains predicted_codes, predicted_codes_lens, and empty audio fields. + """ + eos_detection_method = EOSDetectionMethod(eos_detection_method) + device = batch['text'].device + with torch.no_grad(): + current_chunk_len = copy.deepcopy(batch['text_lens'].detach()) + batch_size = batch["text"].size(0) + + # Prepare text tensors with history + batch, max_text_len = self._prepare_longform_text_tensors( + chunk_state, batch, current_chunk_len, beginning_of_text, device + ) + context_tensors = self.prepare_context_tensors(batch) + + # Update context with historical embeddings + self._update_context_from_history( + chunk_state, + context_tensors, + current_chunk_len, + max_text_len, + beginning_of_text, + batch['text_lens'], + batch_size, + ) + + audio_codes_input = ( + torch.full((batch_size, self.num_audio_codebooks, 1), self.audio_bos_id).long().to(device) + ) + audio_codes_lens = torch.full((batch_size,), audio_codes_input.size(2), device=device).long().to(device) + audio_codes_mask = get_mask_from_lengths(audio_codes_lens) + + # Initialize dummy variables for CFG + dummy_cond = None + dummy_cond_mask = None + dummy_additional_decoder_input = None + dummy_addition_dec_mask = None + if use_cfg: + dummy_cond, dummy_cond_mask, dummy_additional_decoder_input, dummy_addition_dec_mask, _ = ( + self.prepare_dummy_cond_for_cfg( + context_tensors.cond, + context_tensors.cond_mask, + context_tensors.additional_decoder_input, + context_tensors.additional_decoder_mask, + ) + ) + + # Initialize attention prior for longform generation + initial_attn_prior = self._initialize_longform_attn_prior( + chunk_state, + current_chunk_len, + batch['text_lens'], + max_text_len, + batch_size, + use_cfg, + prior_epsilon, + device, + ) + chunk_state.previous_attn_len = copy.deepcopy(batch['text_lens'].detach().tolist()) + + # Create decoder state object to track all local mutable state + state = LongformDecoderState( + audio_codes_input=audio_codes_input, + audio_codes_lens=audio_codes_lens, + audio_codes_mask=audio_codes_mask, + attended_timestep_counter=[{} for _ in range(batch_size)], + all_predictions=[], + chunk_end_dict={}, + unfinished_texts={}, + finished_texts_counter={}, + attn_prior=initial_attn_prior, + ) + + for idx in range(max_decoder_steps): + if idx % 30 == 0: + logging.info(f"Longform decoding timestep {idx}") + + # Embed audio codes and concatenate with additional decoder input + audio_codes_embedded = self.embed_audio_tokens(state.audio_codes_input) + if context_tensors.additional_decoder_input is not None: + _audio_codes_embedded = torch.cat( + [context_tensors.additional_decoder_input, audio_codes_embedded], dim=1 + ) + _audio_codes_mask = torch.cat( + [context_tensors.additional_decoder_mask, state.audio_codes_mask], dim=1 + ) + else: + _audio_codes_embedded = audio_codes_embedded + _audio_codes_mask = state.audio_codes_mask + + # Prepare attention prior for layers + if apply_prior_to_layers is not None: + attn_prior = [None for _ in range(self.cfg.decoder.n_layers)] + for layer_idx in apply_prior_to_layers: + attn_prior[layer_idx] = state.attn_prior + else: + attn_prior = state.attn_prior + + if self.model_type == 'multi_encoder_context_tts': + attn_prior = [attn_prior, None] + + # Run forward pass with optional CFG + all_code_logits, attn_probs = self._run_longform_forward_with_cfg( + context_tensors=context_tensors, + audio_codes_embedded=_audio_codes_embedded, + audio_codes_mask=_audio_codes_mask, + attn_prior=attn_prior, + use_cfg=use_cfg, + 
cfg_scale=cfg_scale, + dummy_cond=dummy_cond, + dummy_cond_mask=dummy_cond_mask, + dummy_additional_decoder_input=dummy_additional_decoder_input, + dummy_addition_dec_mask=dummy_addition_dec_mask, + batch_size=batch_size, + ) + + if apply_attention_prior: + # Get cross-attention scores (optionally from specific layers for alignment) + alignment_attention_scores, _ = self.get_cross_attention_scores( + attn_probs, filter_layers=estimate_alignment_from_layers + ) # B, text_timesteps + + text_time_step_attended, state.attended_timestep_counter = self.get_most_attended_text_timestep( + alignment_attention_scores=alignment_attention_scores, + last_attended_timesteps=chunk_state.last_attended_timesteps, + text_lens=context_tensors.text_lens, + lookahead_window_size=lookahead_window_size, + attended_timestep_counter=state.attended_timestep_counter, + batch_size=batch_size, + left_offset=chunk_state.left_offset, + ) + chunk_state.last_attended_timesteps.append( + text_time_step_attended.detach() + if isinstance(text_time_step_attended, torch.Tensor) + else text_time_step_attended + ) + + (state.attn_prior, state.unfinished_texts, state.finished_texts_counter) = ( + self.construct_longform_inference_prior( + prior_epsilon=prior_epsilon, + cross_attention_scores=alignment_attention_scores, + text_lens=context_tensors.text_lens, + text_time_step_attended=text_time_step_attended, + attended_timestep_counter=state.attended_timestep_counter, + unfinished_texts=state.unfinished_texts, + finished_texts_counter=state.finished_texts_counter, + end_indices=chunk_state.end_indices, + chunk_end_dict=state.chunk_end_dict, + batch_size=batch_size, + left_offset=chunk_state.left_offset, + ) + ) + + for key in state.finished_texts_counter: + state.finished_texts_counter[key] += 1 + limit = ( + self.longform_config.finished_limit_with_eot + if end_of_text[key] + else self.longform_config.finished_limit_without_eot + ) + if state.finished_texts_counter[key] > limit: + # We should allow EOS to be predicted now. 
+ state.unfinished_texts[key] = False + + if ignore_finished_sentence_tracking: + finished_items = {} + unfinished_items = {} + else: + finished_items = { + k: v + for k, v in state.finished_texts_counter.items() + if v >= self.longform_config.finished_limit_with_eot + } + unfinished_items = {k: v for k, v in state.unfinished_texts.items() if v} + + all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) + audio_codes_next = self.sample_codes_from_logits( + all_code_logits_t, + temperature=temperature, + topk=topk, + unfinished_items=unfinished_items, + finished_items=finished_items, + ) # (B, num_codebooks) + all_codes_next_argmax = self.sample_codes_from_logits( + all_code_logits_t, + temperature=self.longform_config.argmax_temperature, + topk=1, + unfinished_items=unfinished_items, + finished_items=finished_items, + ) # (B, num_codebooks) + + # Check for EOS and update state + self._check_eos_and_update_state( + chunk_state, + audio_codes_next, + all_codes_next_argmax, + state.chunk_end_dict, + state.finished_texts_counter, + end_of_text, + eos_detection_method, + idx, + batch_size, + ) + + state.all_predictions.append(audio_codes_next) + + state.audio_codes_input = torch.cat([state.audio_codes_input, audio_codes_next], dim=-1) # (B, C, T') + state.audio_codes_lens = state.audio_codes_lens + 1 + state.audio_codes_mask = get_mask_from_lengths(state.audio_codes_lens) + + # Check termination condition + if self._should_terminate_loop(chunk_state, state.chunk_end_dict, end_of_text, batch_size): + break + + chunk_state.overall_idx += 1 + + predicted_codes = torch.stack(state.all_predictions, dim=-1) + predicted_codes = predicted_codes.squeeze(2) + predicted_codes_lens = torch.tensor( + [state.chunk_end_dict.get(item_idx, predicted_codes.size(-1)) for item_idx in range(batch_size)], + device=device, + ) + + return InferBatchOutput( + predicted_audio=torch.empty(0, device=device), + predicted_audio_lens=torch.empty(0, device=device, dtype=torch.long), + predicted_codes=predicted_codes, + predicted_codes_lens=predicted_codes_lens, + rtf_metrics={}, + cross_attention_maps=[], + headwise_cross_attention_maps=[], + ) diff --git a/nemo/collections/tts/modules/magpietts_inference/inference.py b/nemo/collections/tts/modules/magpietts_inference/inference.py index f571681c57a6..f7809fdb7ded 100644 --- a/nemo/collections/tts/modules/magpietts_inference/inference.py +++ b/nemo/collections/tts/modules/magpietts_inference/inference.py @@ -17,6 +17,7 @@ This module provides: - InferenceConfig: Dataclass for inference hyperparameters - MagpieInferenceRunner: Class for running batch inference with a loaded model + (supports auto-detection of longform text via longform_mode="auto") """ from __future__ import annotations @@ -25,15 +26,17 @@ import shutil import time from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import soundfile as sf import torch from PIL import Image +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import AggregatedTTSTokenizer, IPATokenizer -from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset +from nemo.collections.tts.data.text_to_speech_dataset import LongFormTTSInferenceDataset, MagpieTTSDataset from nemo.collections.tts.models import MagpieTTSModel +from nemo.collections.tts.parts.utils.tts_dataset_utils import stack_tensors from 
nemo.utils import logging @@ -67,6 +70,10 @@ class InferenceConfig: # EOS detection eos_detection_method: Method for detecting end-of-sequence. ignore_finished_sentence_tracking: Whether to ignore sentence tracking. + + # Longform inference mode + longform_mode: Longform inference mode ("auto", "always", "never"). + longform_word_threshold: Word threshold for auto-detection. """ # Core sampling parameters @@ -96,6 +103,10 @@ class InferenceConfig: eos_detection_method: str = "argmax_or_multinomial_any" ignore_finished_sentence_tracking: bool = False + # Longform inference mode + longform_mode: str = "auto" # "auto" | "always" | "never" + longform_word_threshold: int = 40 # Word threshold for auto-detection + def build_identifier(self) -> str: """Build a unique identifier string for this configuration. @@ -166,6 +177,11 @@ def __init__( # Set phoneme probability to 1 for inference self._configure_tokenizer() + # Cached state from create_dataset (set when create_dataset is called) + self._use_longform: Optional[bool] = None + self._manifest_records: Optional[List[dict]] = None + self._audio_base_dir: Optional[str] = None + def _configure_tokenizer(self) -> None: """Configure the tokenizer for inference (phoneme prob = 1.0).""" g2p = None @@ -177,21 +193,52 @@ def _configure_tokenizer(self) -> None: if g2p is not None: g2p.phoneme_probability = 1.0 + def _needs_longform_inference(self, manifest_records: List[dict]) -> bool: + """Determine if any manifest entry needs longform inference. + + Checks if text exceeds character threshold OR has multiple sentences. + + Args: + manifest_records: List of manifest record dictionaries. + + Returns: + True if longform inference should be used, False otherwise. + """ + if self.config.longform_mode == "always": + return True + if self.config.longform_mode == "never": + return False + + # Auto-detection based on text characteristics + for record in manifest_records: + text = record.get('text', '') + + # Check word count + word_count = len(text.split()) + if word_count >= self.config.longform_word_threshold: + return True + + return False + def create_dataset( self, dataset_meta: dict, context_duration_min: Optional[float] = None, context_duration_max: Optional[float] = None, - ) -> MagpieTTSDataset: + ) -> Union[MagpieTTSDataset, LongFormTTSInferenceDataset]: """Create a dataset for inference. + Automatically creates the appropriate dataset type based on longform detection: + - LongFormTTSInferenceDataset if longform text is detected + - MagpieTTSDataset for standard inference + Args: - dataset_meta: Dataset metadata dictionary. + dataset_meta: Dataset metadata dictionary with 'manifest_path' and 'audio_dir'. context_duration_min: Minimum context duration (uses model default if None). context_duration_max: Maximum context duration (uses model default if None). Returns: - Configured MagpieTTSDataset instance. + Configured dataset instance (MagpieTTSDataset or LongFormTTSInferenceDataset). 
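+
+        A minimal sketch of the expected ``dataset_meta`` layout (the dataset name, paths and
+        the ``runner`` instance below are placeholders, not real data):
+
+        Example:
+            >>> dataset_meta = {
+            ...     "my_eval_set": {
+            ...         "manifest_path": "/data/my_eval_set/manifest.json",
+            ...         "audio_dir": "/data/my_eval_set/audio",
+            ...     }
+            ... }
+            >>> dataset = runner.create_dataset(dataset_meta)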
""" # Use model defaults if not specified if context_duration_min is None: @@ -204,37 +251,106 @@ def create_dataset( context_duration_min = 5.0 context_duration_max = 5.0 - dataset = MagpieTTSDataset( - dataset_meta=dataset_meta, - sample_rate=self.model.sample_rate, - min_duration=0.5, - max_duration=20, - codec_model_samples_per_frame=self.model.codec_model_samples_per_frame, - bos_id=self.model.bos_id, - eos_id=self.model.eos_id, - context_audio_bos_id=self.model.context_audio_bos_id, - context_audio_eos_id=self.model.context_audio_eos_id, - audio_bos_id=self.model.audio_bos_id, - audio_eos_id=self.model.audio_eos_id, - num_audio_codebooks=self.model.num_audio_codebooks, - prior_scaling_factor=None, - load_cached_codes_if_available=False, - dataset_type='test', - tokenizer_config=None, - load_16khz_audio=self.model.model_type == 'single_encoder_sv_tts', - use_text_conditioning_tokenizer=self.model.use_text_conditioning_encoder, - text_conditioning_tokenizer_name=self.model.text_conditioning_tokenizer_name, - pad_context_text_to_max_duration=self.model.pad_context_text_to_max_duration, - context_duration_min=context_duration_min, - context_duration_max=context_duration_max, - ) - - # Attach model's tokenizer - dataset.text_tokenizer = self.model.tokenizer + # Read manifest and cache for later use + dataset_name = list(dataset_meta.keys())[0] + dataset_info = dataset_meta[dataset_name] + manifest_path = dataset_info.get('manifest_path') + audio_dir = dataset_info.get('audio_dir', '') + logging.info(f"Dataset name: {dataset_name}, manifest_path: {manifest_path}, audio_dir: {audio_dir}") + + self._manifest_records = read_manifest(manifest_path) + self._audio_base_dir = audio_dir + # Determine longform mode and cache + self._use_longform = self._needs_longform_inference(self._manifest_records) + logging.info(f"Longform detection: {self._use_longform} (mode: {self.config.longform_mode})") + + # Create appropriate dataset type based on longform detection + if self._use_longform: + logging.info("Creating LongFormTTSInferenceDataset for longform inference") + dataset = self._create_longform_dataset(dataset_meta, context_duration_min, context_duration_max) + else: + logging.info("Creating MagpieTTSDataset for standard inference") + dataset = MagpieTTSDataset( + dataset_meta=dataset_meta, + sample_rate=self.model.sample_rate, + min_duration=0.5, + max_duration=20, + codec_model_samples_per_frame=self.model.codec_model_samples_per_frame, + bos_id=self.model.bos_id, + eos_id=self.model.eos_id, + context_audio_bos_id=self.model.context_audio_bos_id, + context_audio_eos_id=self.model.context_audio_eos_id, + audio_bos_id=self.model.audio_bos_id, + audio_eos_id=self.model.audio_eos_id, + num_audio_codebooks=self.model.num_audio_codebooks, + prior_scaling_factor=None, + load_cached_codes_if_available=False, + dataset_type='test', + tokenizer_config=None, + load_16khz_audio=self.model.model_type == 'single_encoder_sv_tts', + use_text_conditioning_tokenizer=self.model.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=self.model.text_conditioning_tokenizer_name, + pad_context_text_to_max_duration=self.model.pad_context_text_to_max_duration, + context_duration_min=context_duration_min, + context_duration_max=context_duration_max, + ) + # Attach model's tokenizer for standard dataset + dataset.text_tokenizer = self.model.tokenizer return dataset def run_inference_on_dataset( + self, + dataset: Union[MagpieTTSDataset, LongFormTTSInferenceDataset], + output_dir: str, + manifest_records: 
Optional[List[dict]] = None, + audio_base_dir: Optional[str] = None, + save_cross_attention_maps: bool = True, + save_context_audio: bool = True, + ) -> Tuple[List[dict], List[str]]: + """Run inference on a dataset. + + Routes to standard or longform inference based on the cached detection + from create_dataset(). Uses cached manifest_records and audio_base_dir + if not provided. + + Args: + dataset: The inference dataset (created by create_dataset()). + output_dir: Directory to save generated audio and artifacts. + manifest_records: Original manifest records (uses cached if None). + audio_base_dir: Base directory for audio paths (uses cached if None). + save_cross_attention_maps: Whether to save attention map images. + save_context_audio: Whether to copy context audio files. + + Returns: + Tuple of: + - rtf_metrics: List of real-time factor metrics per batch. + - generated_audio_paths: List of paths to generated audio files. + """ + # Use cached values if not provided + if manifest_records is None: + if self._manifest_records is None: + raise ValueError("manifest_records not provided and not cached from create_dataset()") + manifest_records = self._manifest_records + + if audio_base_dir is None: + if self._audio_base_dir is None: + raise ValueError("audio_base_dir not provided and not cached from create_dataset()") + audio_base_dir = self._audio_base_dir + + # Route based on cached longform detection + if self._use_longform: + logging.info("Using longform inference path") + return self._run_longform_inference( + dataset, output_dir, manifest_records, audio_base_dir, save_context_audio + ) + else: + logging.info("Using standard inference path") + return self._run_standard_inference( + dataset, output_dir, manifest_records, audio_base_dir, save_cross_attention_maps, save_context_audio + ) + + def _run_standard_inference( self, dataset: MagpieTTSDataset, output_dir: str, @@ -243,7 +359,7 @@ def run_inference_on_dataset( save_cross_attention_maps: bool = True, save_context_audio: bool = True, ) -> Tuple[List[dict], List[str]]: - """Run inference on a dataset and save outputs. + """Run standard single-pass inference on a dataset. Args: dataset: The inference dataset. @@ -400,3 +516,268 @@ def compute_mean_rtf_metrics(rtf_metrics_list: List[dict]) -> Dict[str, float]: mean_metrics[key] = float(sum(values) / len(values)) if values else 0.0 return mean_metrics + + def _create_longform_dataset( + self, + dataset_meta: dict, + context_duration_min: Optional[float] = None, + context_duration_max: Optional[float] = None, + ) -> LongFormTTSInferenceDataset: + """Create a longform dataset for inference. + + Args: + dataset_meta: Dataset metadata dictionary (same format as MagpieTTSDataset). + context_duration_min: Minimum context duration (uses model default if None). + context_duration_max: Maximum context duration (uses model default if None). + + Returns: + Configured LongFormTTSInferenceDataset instance. 
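+
+        Note:
+            This helper is called by create_dataset() when longform text is detected;
+            it is not usually invoked directly.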
+ """ + # Use model defaults if not specified + if context_duration_min is None: + context_duration_min = self.model.cfg.get('context_duration_min', 5.0) + if context_duration_max is None: + context_duration_max = self.model.cfg.get('context_duration_max', 5.0) + + # For multi-encoder models, use fixed 5s context for fair evaluation + if context_duration_min < 5.0 and context_duration_max > 5.0: + context_duration_min = 5.0 + context_duration_max = 5.0 + + # Determine tokenizer name + tokenizer_name = "english_phoneme" + if isinstance(self.model.tokenizer, AggregatedTTSTokenizer): + tokenizer_name = "english_phoneme" + + # Create dataset - inherits from MagpieTTSDataset, so uses same dataset_meta format + dataset = LongFormTTSInferenceDataset( + dataset_meta=dataset_meta, + sample_rate=self.model.sample_rate, + tokenizer_name=tokenizer_name, + codec_model_samples_per_frame=self.model.codec_model_samples_per_frame, + eos_id=self.model.eos_id, + audio_bos_id=self.model.audio_bos_id, + audio_eos_id=self.model.audio_eos_id, + context_audio_bos_id=self.model.context_audio_bos_id, + context_audio_eos_id=self.model.context_audio_eos_id, + num_audio_codebooks=self.model.num_audio_codebooks, + context_duration_min=context_duration_min, + context_duration_max=context_duration_max, + use_text_conditioning_tokenizer=self.model.use_text_conditioning_encoder, + text_conditioning_tokenizer_name=self.model.text_conditioning_tokenizer_name, + pad_context_text_to_max_duration=self.model.pad_context_text_to_max_duration, + load_16khz_audio=self.model.model_type == 'single_encoder_sv_tts', + ) + + # Attach model's tokenizer + dataset.text_tokenizer = self.model.tokenizer + + return dataset + + def _run_longform_inference( + self, + dataset: LongFormTTSInferenceDataset, + output_dir: str, + manifest_records: List[dict], + audio_base_dir: str, + save_context_audio: bool = True, + ) -> Tuple[List[dict], List[str]]: + """Run longform inference with automatic sentence chunking. + + Processes text sentence-by-sentence using generate_long_form_speech(). + + Args: + dataset: LongFormTTSInferenceDataset created by create_dataset(). + output_dir: Directory to save generated audio and artifacts. + manifest_records: List of manifest record dictionaries. + audio_base_dir: Base directory for resolving audio paths. + save_context_audio: Whether to copy context audio files. + + Returns: + Tuple of: + - rtf_metrics: List of real-time factor metrics per batch. + - generated_audio_paths: List of paths to generated audio files. 
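+
+        Example of a single rtf_metrics entry (illustrative values):
+            {"inference_time": 12.3, "audio_seconds": 45.6, "rtf": 0.27}
+
+        Generated audio is written to ``output_dir`` as ``predicted_audio_<idx>.wav``.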
+ """ + os.makedirs(output_dir, exist_ok=True) + self._delete_old_generated_files(output_dir) + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.config.batch_size, + collate_fn=dataset.collate_fn, + num_workers=0, # Avoid multiprocessing issues with CUDA + shuffle=False, + ) + + all_rtf_metrics = [] + generated_audio_paths = [] + global_item_idx = 0 + + for batch_idx, batch in enumerate(dataloader): + logging.info(f"Processing batch {batch_idx + 1}/{len(dataloader)} (longform)") + + # Move batch tensors to CUDA + batch = self._batch_to_cuda(batch) + + batch_size = len(batch['chunked_tokens']) + max_num_chunks = max(len(tokens) for tokens in batch['chunked_tokens']) + + # Create longform chunk state for this batch + chunk_state = self.model.create_longform_chunk_state(batch_size=batch_size) + + # Accumulators for predicted codes + predicted_codes_per_sample = [[] for _ in range(batch_size)] + predicted_codes_lens = [0 for _ in range(batch_size)] + + start_time = time.time() + + # Iterate over text chunks (sentences) + for chunk_idx in range(max_num_chunks): + # Extract current chunk tokens for each sample + current_tokens = [] + current_tokens_lens = [] + for b_idx in range(batch_size): + current_tokens.append(batch['chunked_tokens'][b_idx][chunk_idx]) + current_tokens_lens.append(batch['chunked_tokens_lens'][b_idx][chunk_idx]) + + # Pad tokens to max length in this chunk + max_len = max(current_tokens_lens) + batch['text'] = stack_tensors(current_tokens, max_lens=[max_len]).cuda() + batch['text_lens'] = torch.tensor(current_tokens_lens, dtype=torch.int32).cuda() + + # Compute is_end_of_text flags + is_end_of_text = self._compute_end_of_text_flags( + batch, chunk_idx, max_num_chunks, current_tokens_lens, batch_size + ) + + beginning_of_text = chunk_idx == 0 + + # Call generate_long_form_speech + output = self.model.generate_long_form_speech( + batch, + chunk_state=chunk_state, + end_of_text=is_end_of_text, + beginning_of_text=beginning_of_text, + max_decoder_steps=self.config.max_decoder_steps, + temperature=self.config.temperature, + topk=self.config.topk, + use_cfg=self.config.use_cfg, + cfg_scale=self.config.cfg_scale, + apply_attention_prior=self.config.apply_attention_prior, + prior_epsilon=self.config.attention_prior_epsilon, + lookahead_window_size=self.config.attention_prior_lookahead_window, + estimate_alignment_from_layers=self.config.estimate_alignment_from_layers, + apply_prior_to_layers=self.config.apply_prior_to_layers, + eos_detection_method=self.config.eos_detection_method, + ignore_finished_sentence_tracking=self.config.ignore_finished_sentence_tracking, + ) + + # Unpack output - generate_long_form_speech returns InferBatchOutput + chunk_codes = output.predicted_codes + chunk_codes_lens = output.predicted_codes_lens + + # Accumulate codes for each sample + for b_idx in range(batch_size): + # Skip if this sample's text has ended (padding chunks) + if is_end_of_text[b_idx] and current_tokens_lens[b_idx] == 1: + continue + + code_len = chunk_codes_lens[b_idx] + if code_len > 0: + codes_slice = chunk_codes[b_idx][:, :code_len] + predicted_codes_per_sample[b_idx].append(codes_slice) + predicted_codes_lens[b_idx] += code_len + + elapsed = time.time() - start_time + logging.info(f"Batch longform inference time: {elapsed:.2f}s") + + # Concatenate codes and convert to audio + predicted_codes_list = [] + for b_idx in range(batch_size): + if predicted_codes_per_sample[b_idx]: + concatenated = torch.cat(predicted_codes_per_sample[b_idx], dim=1).cuda() + else: + 
# Empty placeholder
+                    concatenated = torch.zeros((self.model.num_audio_codebooks, 1), dtype=torch.long, device='cuda')
+                predicted_codes_list.append(concatenated)
+
+            # Stack and convert to audio
+            max_code_len = max(predicted_codes_lens) if any(predicted_codes_lens) else 1
+            predicted_codes = stack_tensors(predicted_codes_list, max_lens=[max_code_len]).cuda()
+            predicted_codes_lens_tensor = torch.tensor(predicted_codes_lens, dtype=torch.long, device='cuda')
+
+            predicted_audio, predicted_audio_lens = self.model.codes_to_audio(
+                predicted_codes, predicted_codes_lens_tensor
+            )
+
+            # Compute RTF metrics
+            total_audio_samples = sum(predicted_audio_lens.cpu().tolist())
+            total_audio_seconds = total_audio_samples / self.model.sample_rate
+            rtf = elapsed / total_audio_seconds if total_audio_seconds > 0 else 0.0
+            rtf_metrics = {
+                'inference_time': elapsed,
+                'audio_seconds': total_audio_seconds,
+                'rtf': rtf,
+            }
+            all_rtf_metrics.append(rtf_metrics)
+
+            # Save outputs
+            predicted_audio_np = predicted_audio.float().detach().cpu().numpy()
+
+            for b_idx in range(batch_size):
+                sample_idx = batch['idx'][b_idx]
+                audio_len = predicted_audio_lens[b_idx].item()
+                audio_np = predicted_audio_np[b_idx, :audio_len]
+
+                audio_path = os.path.join(output_dir, f"predicted_audio_{sample_idx}.wav")
+                sf.write(audio_path, audio_np, self.model.sample_rate)
+                generated_audio_paths.append(audio_path)
+
+                # Copy reference audio if requested
+                if save_context_audio and sample_idx < len(manifest_records):
+                    self._copy_reference_audio(
+                        manifest_records[sample_idx],
+                        audio_base_dir,
+                        output_dir,
+                        sample_idx,
+                    )
+
+                global_item_idx += 1
+
+        return all_rtf_metrics, generated_audio_paths
+
+    def _compute_end_of_text_flags(
+        self,
+        batch: Dict[str, Any],
+        chunk_idx: int,
+        max_num_chunks: int,
+        current_tokens_lens: List[int],
+        batch_size: int,
+    ) -> List[bool]:
+        """Compute end-of-text flags for each sample in batch.
+
+        Args:
+            batch: Current batch dictionary.
+            chunk_idx: Current chunk index.
+            max_num_chunks: Maximum number of chunks in this batch.
+            current_tokens_lens: Token lengths for current chunk per sample.
+            batch_size: Number of samples in batch.
+
+        Returns:
+            List of booleans indicating if each sample has reached end of text.
+        """
+        is_end_of_text = []
+        for b_idx in range(batch_size):
+            if chunk_idx == max_num_chunks - 1:
+                # Last chunk
+                is_end_of_text.append(True)
+            elif current_tokens_lens[b_idx] == 1:
+                # Current chunk is padding
+                is_end_of_text.append(True)
+            elif batch['chunked_tokens_lens'][b_idx][chunk_idx + 1] == 1:
+                # Next chunk is padding
+                is_end_of_text.append(True)
+            else:
+                is_end_of_text.append(False)
+
+        return is_end_of_text
diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py
index 96806f633a54..700f1fa3a2bf 100644
--- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py
+++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py
@@ -358,3 +358,102 @@ def sample_audio(
    audio = normalize_volume(audio)

    return audio, audio_filepath_abs, audio_filepath_rel
+
+
+def split_by_sentence(
+    paragraph: str,
+    sentence_separators: Optional[List[str]] = None,
+) -> List[str]:
+    """
+    Split a paragraph into sentences based on sentence-ending punctuation.
+
+    A split only happens when a separator is immediately followed by a space, so
+    punctuation inside a token (e.g., the internal periods of "a.m." or a decimal
+    point such as "2.5") does not end a sentence. Abbreviations that are followed
+    by a space (e.g., "Dr.", "Mr.") will still trigger a split. Sentence-ending
+    punctuation is preserved with each sentence.
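+
+    Before splitting, the text is lightly normalized: hyphens are replaced with
+    spaces and asterisks are removed. The first letter of each returned sentence
+    is upper-cased.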
+
+    Args:
+        paragraph: The input text paragraph to split into sentences.
+        sentence_separators: A list of strings representing sentence-ending
+            punctuation marks. Defaults to ['.', '?', '!', '...'].
+
+    Returns:
+        List of sentence strings with punctuation preserved.
+
+    Examples:
+        >>> split_by_sentence("Hello world. How are you?")
+        ["Hello world.", "How are you?"]
+
+        >>> split_by_sentence("Dr. Smith is here. Good morning!")
+        ["Dr.", "Smith is here.", "Good morning!"]
+    """
+    if sentence_separators is None:
+        sentence_separators = ['.', '?', '!', '...']
+
+    if not paragraph or not paragraph.strip():
+        return []
+
+    # Normalize text: replace hyphens with spaces, remove asterisks
+    paragraph = paragraph.replace('-', ' ')
+    paragraph = paragraph.replace('*', '')
+
+    sentences = []
+    last_sep_idx = -1
+
+    for i, char in enumerate(paragraph):
+        # Check if current char is a separator and next char is a space
+        # This keeps punctuation that is not followed by a space (e.g., "a.m.", "2.5") intact
+        next_char = paragraph[i + 1] if i + 1 < len(paragraph) else ""
+        if char in sentence_separators and next_char == " ":
+            sentences.append(paragraph[last_sep_idx + 1 : i + 1].strip())
+            last_sep_idx = i + 1
+
+    # Add remaining text as the last sentence
+    if last_sep_idx < len(paragraph):
+        remaining = paragraph[last_sep_idx + 1 :].strip()
+        if remaining:
+            sentences.append(remaining)
+
+    # Remove empty sentences and capitalize first letter
+    sentences = [sent for sent in sentences if len(sent) > 0]
+    sentences = [sent if sent[0].isupper() else sent[0].upper() + sent[1:] for sent in sentences if sent]
+
+    return sentences
+
+
+def chunk_and_tokenize_text_by_sentence(
+    text: str,
+    tokenizer_name: str,
+    text_tokenizer: Any,
+    eos_token_id: int,
+) -> Tuple[List[torch.Tensor], List[int], List[str]]:
+    """
+    Tokenize text split by sentences, adding EOS token after each sentence.
+
+    Args:
+        text: Input text to tokenize.
+        tokenizer_name: Name of the tokenizer to use (e.g., "english_phoneme").
+        text_tokenizer: The tokenizer instance.
+        eos_token_id: End-of-sequence token ID to append.
+
+    Returns:
+        Tuple of:
+            - chunked_tokens: List of token tensors, one per sentence.
+            - chunked_tokens_len: List of token lengths.
+            - chunked_text: List of sentence strings.
+    """
+    split_sentences = split_by_sentence(text)
+
+    chunked_tokens = []
+    chunked_tokens_len = []
+    chunked_text = []
+
+    for sentence in split_sentences:
+        chunked_text.append(sentence)
+        tokens = text_tokenizer.encode(text=sentence, tokenizer_name=tokenizer_name)
+        tokens = tokens + [eos_token_id]
+        tokens = torch.tensor(tokens, dtype=torch.int32)
+        tokens_len = tokens.shape[0]
+        chunked_tokens.append(tokens)
+        chunked_tokens_len.append(tokens_len)
+
+    return chunked_tokens, chunked_tokens_len, chunked_text
diff --git a/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh
new file mode 100644
index 000000000000..f73f21ace31a
--- /dev/null
+++ b/tests/functional_tests/L2_TTS_InferEvaluatelongform_Magpietts_ZeroShot.sh
@@ -0,0 +1,33 @@
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +coverage run -a --data-file=/workspace/.coverage --source=/workspace/nemo examples/tts/magpietts_inference.py \ + --codecmodel_path /home/TestData/tts/AudioCodec_21Hz_no_eliz_without_wavlm_disc.nemo \ + --datasets_json_path examples/tts/evalset_config.json \ + --out_dir ./mplf_zs_0 \ + --batch_size 4 \ + --longform_mode always \ + --longform_max_decoder_steps 50000 \ + --use_cfg \ + --cfg_scale 2.5 \ + --num_repeats 1 \ + --temperature 0.6 \ + --hparams_files /home/TestData/tts/2506_ZeroShot/lrhm_short_yt_prioralways_alignement_0.002_priorscale_0.1.yaml \ + --checkpoint_files /home/TestData/tts/2506_ZeroShot/dpo-T5TTS--val_loss=0.4513-epoch=3.ckpt \ + --legacy_codebooks \ + --legacy_text_conditioning \ + --apply_attention_prior \ + --run_evaluation \ + --clean_up_disk \ + --cer_target 0.25 \ + --ssim_target 0.7
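+
+# Note: --longform_mode always forces the sentence-chunked longform path for every
+# utterance. The default mode is "auto", which only switches to longform when a
+# manifest entry contains at least --longform_word_threshold words (40 by default).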