diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md b/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md new file mode 100644 index 0000000000..b167e3d0f8 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/IMPLEMENTATION.md @@ -0,0 +1,269 @@ +# NVIDIA Riva TTS Extension - Implementation Details + +## Overview + +This document describes the implementation of the NVIDIA Riva TTS extension for TEN Framework. The extension provides high-quality, GPU-accelerated text-to-speech synthesis using NVIDIA Riva Speech Skills. + +## Architecture + +### Component Structure + +``` +nvidia_riva_tts_python/ +├── extension.py # Main extension class +├── riva_tts.py # Riva client implementation +├── config.py # Configuration model +├── addon.py # Extension registration +├── manifest.json # Extension metadata +├── property.json # Default properties +├── requirements.txt # Python dependencies +├── README.md # User documentation +└── tests/ # Test suite + ├── test_config.py + └── test_extension.py +``` + +### Class Hierarchy + +``` +AsyncTTS2BaseExtension (base class from ten_ai_base.tts2) + └── NvidiaRivaTTSExtension + └── uses NvidiaRivaTTSClient + └── uses riva.client.SpeechSynthesisService +``` + +## Implementation Details + +### 1. Extension Class (`extension.py`) + +The `NvidiaRivaTTSExtension` class inherits from `AsyncTTS2BaseExtension` and overrides the following methods, among others: + +- **`on_init()`**: Parses the JSON properties into `NvidiaRivaTTSConfig` and creates the `NvidiaRivaTTSClient` +- **`request_tts()`**: Streams audio chunks for each `TTSTextInput` and emits audio-start, TTFB, and audio-end events +- **`vendor()`**: Returns "nvidia_riva" as the vendor identifier +- **`synthesize_audio_sample_rate()`**: Returns the configured sample rate (default: 16000) + +### 2. 
Client Implementation (`riva_tts.py`) + +The `NvidiaRivaTTSClient` class handles the actual TTS synthesis: + +#### Initialization +- Creates Riva Auth object with server URI and SSL settings +- Initializes `SpeechSynthesisService` for TTS operations +- Raises `RuntimeError` if client construction fails + +#### Synthesis Method +```python +async def synthesize(self, text: str, request_id: str) -> AsyncIterator[bytes]: ... +``` + +**Flow:** +1. Validates input text (non-empty) +2. Calls `tts_service.synthesize_online()` for streaming synthesis +3. Iterates through audio chunks from Riva +4. Converts audio data to PCM bytes +5. Yields audio chunks for streaming playback +6. Stops iterating when a cancellation has been requested + +**Key Features:** +- Streaming synthesis for low latency +- Cancellation support via `_is_cancelled` flag +- Logging at initialization, per chunk, and on completion +- Synthesis failures re-raised as `RuntimeError` with the Riva error message + +### 3. Configuration (`config.py`) + +The `NvidiaRivaTTSConfig` class extends `AsyncTTSConfig`: + +**Required Parameters:** +- `server`: Riva server address (host:port) +- `language_code`: Language identifier (e.g., "en-US") +- `voice_name`: Voice identifier (e.g., "English-US.Female-1") + +**Optional Parameters:** +- `sample_rate`: Audio sample rate in Hz (default: 16000) +- `use_ssl`: Enable SSL for gRPC (default: false) + +**Validation:** +- Ensures all required parameters are present +- Rejects empty values for required parameters + +### 4. Addon Registration (`addon.py`) + +Registers the extension with TEN Framework using the `@register_addon_as_extension` decorator. + +## Integration with TEN Framework + +### TTS Interface Compliance + +The extension implements the standard TEN Framework TTS interface defined in `ten_ai_base/api/tts-interface.json`: + +- **Input**: Text data via TEN data messages +- **Output**: PCM audio frames via TEN audio frame messages +- **Control**: Start/stop/cancel commands via TEN commands + +### Message Flow + +``` +1. Text Input → Extension receives text data +2. 
Configuration → Loads voice, language, sample rate +3. Synthesis → Calls Riva API with streaming +4. Audio Output → Yields PCM audio chunks +5. Completion → Signals end of synthesis +``` + +## NVIDIA Riva Integration + +### gRPC API Usage + +The extension uses the official `nvidia-riva-client` Python package which provides: + +- **Auth**: Authentication and connection management +- **SpeechSynthesisService**: TTS API wrapper +- **AudioEncoding**: Audio format specifications + +### Streaming vs Batch + +The implementation uses **streaming synthesis** (`synthesize_online`) for: +- Lower latency (first audio chunk arrives quickly) +- Better user experience in real-time applications +- Efficient memory usage + +Alternative batch mode (`synthesize`) is available but not used by default. + +### Audio Format + +- **Encoding**: LINEAR_PCM (16-bit signed integer) +- **Sample Rate**: Configurable (default 16000 Hz) +- **Channels**: Mono +- **Byte Order**: Little-endian + +## Error Handling + +### Initialization Errors +- Server unreachable → RuntimeError with connection details +- Invalid credentials → Authentication error +- Missing dependencies → Import error + +### Runtime Errors +- Empty text → Warning logged, no synthesis +- Synthesis failure → RuntimeError with Riva error message +- Cancellation → Graceful stop, log cancellation + +### Logging Strategy + +- **INFO**: Initialization, configuration +- **DEBUG**: Synthesis progress, chunk details +- **WARN**: Empty text, unusual conditions +- **ERROR**: Failures, exceptions + +## Testing + +### Test Coverage + +1. **Configuration Tests** (`test_config.py`) + - Valid configuration creation + - Missing required parameters + - Default values + - Validation logic + +2. **Extension Tests** (`test_extension.py`) + - Extension initialization + - Config creation from JSON + - Sample rate retrieval + - Client creation + +3. 
**Client Tests** (`test_extension.py`) + - Client initialization with mocked Riva + - Cancellation handling + - Empty text handling + - Synthesis with mocked responses + +### Running Tests + +```bash +# Install test dependencies +pip install pytest pytest-asyncio + +# Run all tests +pytest nvidia_riva_tts_python/tests/ -v + +# Run with coverage +pytest nvidia_riva_tts_python/tests/ --cov=nvidia_riva_tts_python +``` + +## Performance Considerations + +### Latency +- **First chunk**: ~100-200ms (depends on text length and server) +- **Streaming**: Continuous audio delivery +- **GPU acceleration**: Significantly faster than CPU-only TTS + +### Resource Usage +- **Memory**: Minimal (streaming mode) +- **Network**: gRPC connection to Riva server +- **CPU**: Low (Riva does GPU processing) + +### Optimization Tips +1. Use streaming mode for real-time applications +2. Keep Riva server close to application (low network latency) +3. Reuse client connections (handled by extension) +4. Configure appropriate sample rate for use case + +## Deployment + +### Prerequisites +1. NVIDIA Riva server running (see README.md for setup) +2. Network connectivity to Riva server +3. Python 3.8+ with nvidia-riva-client + +### Configuration Example + +```json +{ + "params": { + "server": "riva-server.example.com:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 22050, + "use_ssl": true + } +} +``` + +### Environment Variables + +```bash +export NVIDIA_RIVA_SERVER=localhost:50051 +``` + +## Future Enhancements + +Potential improvements for future versions: + +1. **SSML Support**: Full SSML tag support for advanced speech control +2. **Voice Cloning**: Custom voice model support +3. **Multi-language**: Automatic language detection +4. **Caching**: Cache frequently synthesized phrases +5. **Metrics**: Detailed performance metrics and monitoring +6. 
**Fallback**: Automatic fallback to alternative TTS if Riva unavailable + +## References + +- [NVIDIA Riva Documentation](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/index.html) +- [Riva Python Client](https://pypi.org/project/nvidia-riva-client/) +- [TEN Framework TTS Interface](https://github.com/TEN-framework/ten-framework) +- [gRPC Python](https://grpc.io/docs/languages/python/) + +## License + +Apache 2.0 - See LICENSE file in the TEN Framework repository. + +## Contributing + +Contributions are welcome! Please: +1. Follow the existing code style +2. Add tests for new features +3. Update documentation +4. Submit PR to TEN Framework repository + diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/README.md b/agents/ten_packages/extension/nvidia_riva_tts_python/README.md new file mode 100644 index 0000000000..d0452fc9c0 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/README.md @@ -0,0 +1,93 @@ +# NVIDIA Riva TTS Python Extension + +This extension provides text-to-speech functionality using NVIDIA Riva Speech Skills. 
+ +## Features + +- High-quality speech synthesis using NVIDIA Riva +- Support for multiple languages and voices +- Streaming synthesis via `synthesize_online` (batch mode is not used) +- SSML support is planned (see Future Enhancements in IMPLEMENTATION.md) +- GPU-accelerated inference for low latency + +## Prerequisites + +- NVIDIA Riva server running and accessible +- Python 3.8+ +- nvidia-riva-client package + +## Configuration + +The extension can be configured through your property.json: + +```json +{ + "params": { + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": false + } +} +``` + +### Configuration Options + +**Parameters inside `params` object:** +- `server` (required): Riva server address (format: "host:port") +- `language_code` (required): Language code (e.g., "en-US", "es-ES") +- `voice_name` (required): Voice identifier (e.g., "English-US.Female-1") +- `sample_rate` (optional): Audio sample rate in Hz (default: 16000) +- `use_ssl` (optional): Use SSL for gRPC connection (default: false) + +### Available Voices + +Common voice names include: +- `English-US.Female-1` +- `English-US.Male-1` +- `English-GB.Female-1` +- `Spanish-US.Female-1` + +Check your Riva server documentation for the full list of available voices. + +## Setting up NVIDIA Riva Server + +Follow the [NVIDIA Riva Quick Start Guide](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/quick-start-guide.html) to set up a Riva server. 
+ +Quick setup with Docker: + +```bash +# Download Riva Quick Start scripts +ngc registry resource download-version nvidia/riva/riva_quickstart:2.17.0 + +# Initialize and start Riva +cd riva_quickstart_v2.17.0 +bash riva_init.sh +bash riva_start.sh +``` + +## Environment Variables + +Set the Riva server address via environment variable: + +```bash +export NVIDIA_RIVA_SERVER=localhost:50051 +``` + +## Architecture + +This extension follows the TEN Framework TTS extension pattern: + +- `extension.py`: Main extension class +- `riva_tts.py`: Client implementation with Riva SDK integration +- `config.py`: Configuration model +- `addon.py`: Extension addon registration + +## License + +Apache 2.0 + +## Contributing + +Contributions are welcome! Please submit issues and pull requests to the TEN Framework repository. diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py b/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py new file mode 100644 index 0000000000..2718464193 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/__init__.py @@ -0,0 +1,7 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +from . import addon + +__all__ = ["addon"] diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py b/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py new file mode 100644 index 0000000000..f880f2d97c --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/addon.py @@ -0,0 +1,18 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. 
+# +from ten_runtime import ( + Addon, + register_addon_as_extension, + TenEnv, +) + + +@register_addon_as_extension("nvidia_riva_tts_python") +class NvidiaRivaTTSExtensionAddon(Addon): + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import NvidiaRivaTTSExtension + + ten_env.log_info("NvidiaRivaTTSExtensionAddon on_create_instance") + ten_env.on_create_instance_done(NvidiaRivaTTSExtension(name), context) diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/config.py b/agents/ten_packages/extension/nvidia_riva_tts_python/config.py new file mode 100644 index 0000000000..ed044bceed --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/config.py @@ -0,0 +1,44 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +from typing import Any +import copy +from pydantic import Field +from pathlib import Path +from ten_ai_base import utils +from ten_ai_base.tts import AsyncTTSConfig + + +class NvidiaRivaTTSConfig(AsyncTTSConfig): + """NVIDIA Riva TTS Config""" + + dump: bool = Field(default=False, description="NVIDIA Riva TTS dump") + dump_path: str = Field( + default_factory=lambda: str(Path(__file__).parent / "nvidia_riva_tts_in.pcm"), + description="NVIDIA Riva TTS dump path", + ) + params: dict[str, Any] = Field( + default_factory=dict, description="NVIDIA Riva TTS params" + ) + + def update_params(self) -> None: + """Update configuration from params dictionary""" + pass + + def to_str(self, sensitive_handling: bool = True) -> str: + """Convert config to string with optional sensitive data handling.""" + if not sensitive_handling: + return f"{self}" + + config = copy.deepcopy(self) + return f"{config}" + + def validate(self) -> None: + """Validate NVIDIA Riva-specific configuration.""" + if "server" not in self.params or not self.params["server"]: + raise ValueError("Server address is required for NVIDIA Riva TTS") + if 
"language_code" not in self.params or not self.params["language_code"]: + raise ValueError("Language code is required for NVIDIA Riva TTS") + if "voice_name" not in self.params or not self.params["voice_name"]: + raise ValueError("Voice name is required for NVIDIA Riva TTS") diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py b/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py new file mode 100644 index 0000000000..44523b89d9 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/extension.py @@ -0,0 +1,264 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +""" +NVIDIA Riva TTS Extension + +This extension implements text-to-speech using NVIDIA Riva Speech Skills. +It provides high-quality, GPU-accelerated speech synthesis. +""" + +import asyncio +import time +import traceback +from typing import Optional + +from ten_ai_base.message import ( + ModuleError, + ModuleErrorCode, + ModuleType, + TTSAudioEndReason, +) +from ten_ai_base.struct import TTSTextInput +from ten_ai_base.tts2 import AsyncTTS2BaseExtension, RequestState +from ten_ai_base.const import LOG_CATEGORY_KEY_POINT, LOG_CATEGORY_VENDOR +from ten_runtime import AsyncTenEnv + +from .config import NvidiaRivaTTSConfig +from .riva_tts import NvidiaRivaTTSClient + + +class NvidiaRivaTTSExtension(AsyncTTS2BaseExtension): + """ + NVIDIA Riva TTS Extension implementation. + + Provides text-to-speech synthesis using NVIDIA Riva's gRPC API. + Inherits all common TTS functionality from AsyncTTS2BaseExtension. 
+ """ + + def __init__(self, name: str) -> None: + super().__init__(name) + self.config: Optional[NvidiaRivaTTSConfig] = None + self.client: Optional[NvidiaRivaTTSClient] = None + self.current_request_id: Optional[str] = None + self.request_start_ts: float = 0 + self.first_chunk_ts: float = 0 + self.request_total_audio_duration: int = 0 + self.flush_request_id: Optional[str] = None + self.last_end_request_id: Optional[str] = None + self.audio_start_sent: set[str] = set() + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + """Initialize the extension""" + await super().on_init(ten_env) + ten_env.log_debug("NVIDIA Riva TTS on_init") + + try: + # Load configuration + config_json, _ = await ten_env.get_property_to_json("") + self.config = NvidiaRivaTTSConfig.model_validate_json(config_json) + + ten_env.log_info( + f"config: {self.config.model_dump_json()}", + category=LOG_CATEGORY_KEY_POINT, + ) + + # Create client + self.client = NvidiaRivaTTSClient( + config=self.config, + ten_env=ten_env, + ) + + except Exception as e: + ten_env.log_error(f"on_init failed: {traceback.format_exc()}") + await self.send_tts_error( + request_id="", + error=ModuleError( + message=str(e), + module=ModuleType.TTS, + code=ModuleErrorCode.FATAL_ERROR, + vendor_info={"vendor": "nvidia_riva"}, + ), + ) + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + """Stop the extension""" + await super().on_stop(ten_env) + ten_env.log_debug("NVIDIA Riva TTS on_stop") + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + """Deinitialize the extension""" + await super().on_deinit(ten_env) + ten_env.log_debug("NVIDIA Riva TTS on_deinit") + + def vendor(self) -> str: + """Return vendor name""" + return "nvidia_riva" + + def synthesize_audio_sample_rate(self) -> int: + """Return audio sample rate""" + return self.config.params.get("sample_rate", 16000) if self.config else 16000 + + def synthesize_audio_channels(self) -> int: + """Return number of audio channels""" + return 1 + + def 
synthesize_audio_sample_width(self) -> int: + """Return sample width in bytes""" + return 2 # 16-bit PCM + + async def request_tts(self, t: TTSTextInput) -> None: + """Handle TTS request""" + try: + self.ten_env.log_info( + f"TTS request: text_length={len(t.text)}, " + f"text_input_end={t.text_input_end}, request_id={t.request_id}" + ) + + # Skip if request already completed + if t.request_id == self.flush_request_id: + self.ten_env.log_debug( + f"Request {t.request_id} was flushed, ignoring" + ) + return + + if t.request_id == self.last_end_request_id: + self.ten_env.log_debug( + f"Request {t.request_id} was ended, ignoring" + ) + return + + # Handle new request + is_new_request = self.current_request_id != t.request_id + if is_new_request: + self.ten_env.log_debug(f"New TTS request: {t.request_id}") + self.current_request_id = t.request_id + self.request_total_audio_duration = 0 + self.request_start_ts = time.time() + + if self.client is None: + raise ValueError("TTS client not initialized") + + # Synthesize audio + received_first_chunk = False + async for chunk in self.client.synthesize(t.text, t.request_id): + # Calculate audio duration + duration = self._calculate_audio_duration(len(chunk)) + + self.ten_env.log_debug( + f"receive_audio: duration={duration}ms, request_id={self.current_request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + if not received_first_chunk: + received_first_chunk = True + # Send audio start + if t.request_id not in self.audio_start_sent: + await self.send_tts_audio_start(t.request_id) + self.audio_start_sent.add(t.request_id) + if is_new_request: + # Send TTFB metrics + self.first_chunk_ts = time.time() + elapsed_time = int( + (self.first_chunk_ts - self.request_start_ts) * 1000 + ) + await self.send_tts_ttfb_metrics( + request_id=t.request_id, + ttfb_ms=elapsed_time, + extra_metadata={ + "voice_name": self.config.params["voice_name"], + "language_code": self.config.params["language_code"], + }, + ) + + if t.request_id == 
self.flush_request_id: + break + + self.request_total_audio_duration += duration + await self.send_tts_audio_data(chunk) + + # Handle completion + if t.text_input_end or t.request_id == self.flush_request_id: + reason = TTSAudioEndReason.REQUEST_END + if t.request_id == self.flush_request_id: + reason = TTSAudioEndReason.INTERRUPTED + + if self.first_chunk_ts > 0: + await self._handle_completed_request(reason) + + except Exception as e: + self.ten_env.log_error(f"Error in request_tts: {traceback.format_exc()}") + await self.send_tts_error( + request_id=t.request_id, + error=ModuleError( + message=str(e), + module=ModuleType.TTS, + code=ModuleErrorCode.NON_FATAL_ERROR, + vendor_info={"vendor": "nvidia_riva"}, + ), + ) + + # Check if we've received text_input_end + has_received_text_input_end = False + if t.request_id and t.request_id in self.request_states: + if self.request_states[t.request_id] == RequestState.FINALIZING: + has_received_text_input_end = True + + if has_received_text_input_end: + await self._handle_completed_request(TTSAudioEndReason.ERROR) + + async def cancel_tts(self) -> None: + """Cancel current TTS request""" + self.ten_env.log_info(f"cancel_tts current_request_id: {self.current_request_id}") + if self.current_request_id is not None: + self.flush_request_id = self.current_request_id + + if self.client: + await self.client.cancel() + + if self.current_request_id and self.first_chunk_ts > 0: + await self._handle_completed_request(TTSAudioEndReason.INTERRUPTED) + + async def _handle_completed_request(self, reason: TTSAudioEndReason) -> None: + """Handle completed TTS request""" + if not self.current_request_id: + return + + self.last_end_request_id = self.current_request_id + + # Calculate metrics + request_event_interval = 0 + if self.first_chunk_ts > 0: + request_event_interval = int( + (time.time() - self.first_chunk_ts) * 1000 + ) + + # Send audio end + await self.send_tts_audio_end( + request_id=self.current_request_id, + 
request_event_interval_ms=request_event_interval, + request_total_audio_duration_ms=self.request_total_audio_duration, + reason=reason, + ) + + self.ten_env.log_debug( + f"Sent tts_audio_end: reason={reason.name}, request_id={self.current_request_id}" + ) + + # Finish request + await self.finish_request(request_id=self.current_request_id, reason=reason) + + # Reset state + self.first_chunk_ts = 0 + self.audio_start_sent.discard(self.current_request_id) + + def _calculate_audio_duration(self, bytes_length: int) -> int: + """Calculate audio duration in milliseconds""" + bytes_per_second = ( + self.synthesize_audio_sample_rate() + * self.synthesize_audio_channels() + * self.synthesize_audio_sample_width() + ) + return int((bytes_length / bytes_per_second) * 1000) diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json b/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json new file mode 100644 index 0000000000..e48a071a40 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/manifest.json @@ -0,0 +1,57 @@ +{ + "type": "extension", + "name": "nvidia_riva_tts_python", + "version": "0.1.0", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.11" + }, + { + "type": "system", + "name": "ten_ai_base", + "version": "0.7" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "**.py", + "README.md", + "requirements.txt" + ] + }, + "api": { + "interface": [ + { + "import_uri": "../../system/ten_ai_base/api/tts-interface.json" + } + ], + "property": { + "properties": { + "params": { + "type": "object", + "properties": { + "server": { + "type": "string" + }, + "language_code": { + "type": "string" + }, + "voice_name": { + "type": "string" + }, + "sample_rate": { + "type": "int64" + }, + "use_ssl": { + "type": "bool" + } + } + } + } + } + } +} diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/property.json 
b/agents/ten_packages/extension/nvidia_riva_tts_python/property.json new file mode 100644 index 0000000000..022a606664 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/property.json @@ -0,0 +1,9 @@ +{ + "params": { + "server": "${env:NVIDIA_RIVA_SERVER|localhost:50051}", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": false + } +} diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt b/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt new file mode 100644 index 0000000000..f178d839f4 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/requirements.txt @@ -0,0 +1,2 @@ +nvidia-riva-client>=2.17.0 +numpy>=1.21.0 diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py b/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py new file mode 100644 index 0000000000..04f7e5921b --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/riva_tts.py @@ -0,0 +1,143 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. 
+# +from typing import AsyncIterator +import numpy as np +import riva.client +from ten_runtime import AsyncTenEnv +from ten_ai_base.const import LOG_CATEGORY_VENDOR + +from .config import NvidiaRivaTTSConfig + + +class NvidiaRivaTTSClient: + """NVIDIA Riva TTS Client implementation""" + + def __init__( + self, + config: NvidiaRivaTTSConfig, + ten_env: AsyncTenEnv, + ): + self.config = config + self.ten_env: AsyncTenEnv = ten_env + self._is_cancelled = False + self.auth = None + self.tts_service = None + + try: + # Initialize Riva client + server = config.params["server"] + use_ssl = config.params.get("use_ssl", False) + + self.ten_env.log_info( + f"Initializing NVIDIA Riva TTS client with server: {server}, SSL: {use_ssl}", + category=LOG_CATEGORY_VENDOR, + ) + + self.auth = riva.client.Auth(ssl_cert=None, use_ssl=use_ssl, uri=server) + self.tts_service = riva.client.SpeechSynthesisService(self.auth) + + self.ten_env.log_info( + "NVIDIA Riva TTS client initialized successfully", + category=LOG_CATEGORY_VENDOR, + ) + except Exception as e: + ten_env.log_error( + f"Error when initializing NVIDIA Riva TTS: {e}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError(f"Error when initializing NVIDIA Riva TTS: {e}") from e + + async def cancel(self): + """Cancel the current TTS request""" + self.ten_env.log_debug("NVIDIA Riva TTS: cancel() called.") + self._is_cancelled = True + + async def synthesize(self, text: str, request_id: str) -> AsyncIterator[bytes]: + """ + Synthesize speech from text using NVIDIA Riva TTS. 
+ + Args: + text: Text to synthesize + request_id: Unique request identifier + + Yields: + Audio data as bytes (PCM format) + """ + self._is_cancelled = False + + if not self.tts_service: + self.ten_env.log_error( + f"NVIDIA Riva TTS: service not initialized for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError( + f"NVIDIA Riva TTS: service not initialized for request_id: {request_id}" + ) + + if len(text.strip()) == 0: + self.ten_env.log_warn( + f"NVIDIA Riva TTS: empty text for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + return + + try: + language_code = self.config.params["language_code"] + voice_name = self.config.params["voice_name"] + sample_rate = self.config.params.get("sample_rate", 16000) + + self.ten_env.log_debug( + f"NVIDIA Riva TTS: synthesizing text (length: {len(text)}) " + f"with voice: {voice_name}, language: {language_code}, " + f"sample_rate: {sample_rate}, request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + # Use streaming synthesis for lower latency + responses = self.tts_service.synthesize_online( + text, + voice_name=voice_name, + language_code=language_code, + sample_rate_hz=sample_rate, + encoding=riva.client.AudioEncoding.LINEAR_PCM, + ) + + # Stream audio chunks + for response in responses: + if self._is_cancelled: + self.ten_env.log_debug( + f"Cancellation detected, stopping TTS stream for request_id: {request_id}" + ) + break + + # Convert audio bytes to numpy array and back to bytes + # This ensures proper format + audio_data = np.frombuffer(response.audio, dtype=np.int16) + + self.ten_env.log_debug( + f"NVIDIA Riva TTS: yielding audio chunk, " + f"length: {len(audio_data)} samples, request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + yield audio_data.tobytes() + + if not self._is_cancelled: + self.ten_env.log_debug( + f"NVIDIA Riva TTS: synthesis completed for request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + + except Exception as e: + 
error_message = str(e) + self.ten_env.log_error( + f"NVIDIA Riva TTS: error during synthesis: {error_message}, " + f"request_id: {request_id}", + category=LOG_CATEGORY_VENDOR, + ) + raise RuntimeError( + f"NVIDIA Riva TTS synthesis failed: {error_message}" + ) from e diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py new file mode 100644 index 0000000000..b8c07eef1c --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/__init__.py @@ -0,0 +1,4 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py new file mode 100644 index 0000000000..3384bb95d5 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_compliance.py @@ -0,0 +1,294 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +""" +Compliance tests to ensure the extension correctly implements NVIDIA Riva TTS API. +These tests validate against the official NVIDIA Riva client API specifications. 
+""" +import pytest +from unittest.mock import Mock, patch, MagicMock +import numpy as np +from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig +from nvidia_riva_tts_python.riva_tts import NvidiaRivaTTSClient + + +class TestNvidiaRivaAPICompliance: + """Test compliance with NVIDIA Riva TTS API specifications""" + + @pytest.fixture + def mock_ten_env(self): + """Create a mock TenEnv""" + env = Mock() + env.log_info = Mock() + env.log_debug = Mock() + env.log_warn = Mock() + env.log_error = Mock() + return env + + @pytest.fixture + def valid_config(self): + """Create a valid configuration""" + return NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": False, + } + ) + + def test_auth_initialization_parameters(self, valid_config, mock_ten_env): + """Verify Auth is initialized with correct parameters per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Verify Auth called with correct parameters + mock_auth.assert_called_once_with( + ssl_cert=None, + use_ssl=False, + uri="localhost:50051" + ) + + def test_speech_synthesis_service_initialization(self, valid_config, mock_ten_env): + """Verify SpeechSynthesisService is initialized with Auth object""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + mock_auth_instance = Mock() + mock_auth.return_value = mock_auth_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Verify SpeechSynthesisService called with Auth instance + mock_service.assert_called_once_with(mock_auth_instance) + + @pytest.mark.asyncio + async def 
test_synthesize_online_parameters(self, valid_config, mock_ten_env): + """Verify synthesize_online is called with correct parameters per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.AudioEncoding') as mock_encoding: + + # Setup mocks + mock_service_instance = Mock() + mock_response = Mock() + mock_response.audio = b'\x00\x01' * 100 + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + mock_encoding.LINEAR_PCM = "LINEAR_PCM" + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize text + text = "Hello world" + chunks = [chunk async for chunk in client.synthesize(text, "test_request")] + + # Verify synthesize_online called with correct parameters + mock_service_instance.synthesize_online.assert_called_once_with( + text, + voice_name="English-US.Female-1", + language_code="en-US", + sample_rate_hz=16000, + encoding="LINEAR_PCM" + ) + + @pytest.mark.asyncio + async def test_audio_encoding_linear_pcm(self, valid_config, mock_ten_env): + """Verify LINEAR_PCM encoding is used per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.AudioEncoding') as mock_encoding: + + mock_service_instance = Mock() + mock_response = Mock() + mock_response.audio = b'\x00\x01' * 100 + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + mock_encoding.LINEAR_PCM = "LINEAR_PCM" + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # 
Synthesize + chunks = [chunk async for chunk in client.synthesize("Test", "req1")] + + # Verify encoding parameter + call_args = mock_service_instance.synthesize_online.call_args + assert call_args[1]['encoding'] == "LINEAR_PCM" + + @pytest.mark.asyncio + async def test_audio_format_int16(self, valid_config, mock_ten_env): + """Verify audio is processed as int16 per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + # Create mock audio data (int16 format) + mock_audio = np.array([100, -100, 200, -200], dtype=np.int16).tobytes() + mock_response = Mock() + mock_response.audio = mock_audio + + mock_service_instance = Mock() + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize + chunks = [chunk async for chunk in client.synthesize("Test", "req1")] + + # Verify output is bytes + assert len(chunks) == 1 + assert isinstance(chunks[0], bytes) + + # Verify can be converted back to int16 + audio_array = np.frombuffer(chunks[0], dtype=np.int16) + assert audio_array.dtype == np.int16 + assert len(audio_array) == 4 + + @pytest.mark.asyncio + async def test_streaming_response_iteration(self, valid_config, mock_ten_env): + """Verify streaming responses are iterated correctly per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + # Create multiple response chunks + mock_responses = [] + for i in range(3): + mock_response = Mock() + mock_response.audio = np.array([i] * 10, dtype=np.int16).tobytes() + mock_responses.append(mock_response) + + mock_service_instance = Mock() + mock_service_instance.synthesize_online = 
Mock(return_value=mock_responses) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize + chunks = [chunk async for chunk in client.synthesize("Test", "req1")] + + # Verify all chunks received + assert len(chunks) == 3 + for chunk in chunks: + assert isinstance(chunk, bytes) + assert len(chunk) > 0 + + def test_required_config_parameters(self): + """Verify all required parameters are validated per Riva API""" + # Missing server + with pytest.raises(ValueError, match="Server address is required"): + config = NvidiaRivaTTSConfig( + params={"language_code": "en-US", "voice_name": "English-US.Female-1"} + ) + config.validate() + + # Missing language_code + with pytest.raises(ValueError, match="Language code is required"): + config = NvidiaRivaTTSConfig( + params={"server": "localhost:50051", "voice_name": "English-US.Female-1"} + ) + config.validate() + + # Missing voice_name + with pytest.raises(ValueError, match="Voice name is required"): + config = NvidiaRivaTTSConfig( + params={"server": "localhost:50051", "language_code": "en-US"} + ) + config.validate() + + def test_optional_config_parameters(self, valid_config): + """Verify optional parameters have correct defaults per Riva API""" + # sample_rate defaults to 16000 + assert valid_config.params.get("sample_rate", 16000) == 16000 + + # use_ssl defaults to False + assert valid_config.params.get("use_ssl", False) is False + + def test_supported_sample_rates(self): + """Verify common sample rates are supported per Riva API""" + supported_rates = [8000, 16000, 22050, 24000, 44100, 48000] + + for rate in supported_rates: + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": rate, + } + ) + config.validate() # Should not raise + assert config.params["sample_rate"] == rate + + def 
test_ssl_configuration(self, mock_ten_env): + """Verify SSL can be enabled per Riva API""" + config_with_ssl = NvidiaRivaTTSConfig( + params={ + "server": "secure-server:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "use_ssl": True, + } + ) + + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') as mock_auth, \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): + + client = NvidiaRivaTTSClient(config=config_with_ssl, ten_env=mock_ten_env) + + # Verify SSL enabled in Auth + call_args = mock_auth.call_args + assert call_args[1]['use_ssl'] is True + + @pytest.mark.asyncio + async def test_empty_text_handling(self, valid_config, mock_ten_env): + """Verify empty text is handled gracefully per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService'): + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Empty string + chunks = [chunk async for chunk in client.synthesize("", "req1")] + assert len(chunks) == 0 + + # Whitespace only + chunks = [chunk async for chunk in client.synthesize(" ", "req1")] + assert len(chunks) == 0 + + @pytest.mark.asyncio + async def test_cancellation_support(self, valid_config, mock_ten_env): + """Verify cancellation is supported per Riva API""" + with patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth'), \ + patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') as mock_service: + + # Create multiple responses to simulate long synthesis + mock_responses = [Mock(audio=b'\x00\x01' * 100) for _ in range(10)] + mock_service_instance = Mock() + mock_service_instance.synthesize_online = Mock(return_value=mock_responses) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Start synthesis and cancel 
mid-stream + chunks = [] + async for chunk in client.synthesize("Long text", "req1"): + chunks.append(chunk) + if len(chunks) == 3: # Cancel after 3 chunks + await client.cancel() + + # Verify cancellation stopped the stream + assert len(chunks) < 10 # Should not receive all chunks + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) + diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py new file mode 100644 index 0000000000..b29bf51adc --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_config.py @@ -0,0 +1,67 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +import pytest +from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig + + +def test_config_validation(): + """Test configuration validation""" + # Valid config + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + } + ) + config.validate() # Should not raise + + # Missing server + with pytest.raises(ValueError, match="Server address is required"): + config = NvidiaRivaTTSConfig( + params={ + "language_code": "en-US", + "voice_name": "English-US.Female-1", + } + ) + config.validate() + + # Missing language_code + with pytest.raises(ValueError, match="Language code is required"): + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "voice_name": "English-US.Female-1", + } + ) + config.validate() + + # Missing voice_name + with pytest.raises(ValueError, match="Voice name is required"): + config = NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + } + ) + config.validate() + + +def test_config_defaults(): + """Test default configuration values""" + config = NvidiaRivaTTSConfig( + params={ +
"server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + } + ) + + assert config.dump is False + assert "nvidia_riva_tts_in.pcm" in config.dump_path + assert config.params["server"] == "localhost:50051" + assert config.params.get("sample_rate", 16000) == 16000 + assert config.params.get("use_ssl", False) is False diff --git a/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py new file mode 100644 index 0000000000..517a600945 --- /dev/null +++ b/agents/ten_packages/extension/nvidia_riva_tts_python/tests/test_extension.py @@ -0,0 +1,134 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from nvidia_riva_tts_python.extension import NvidiaRivaTTSExtension +from nvidia_riva_tts_python.config import NvidiaRivaTTSConfig +from nvidia_riva_tts_python.riva_tts import NvidiaRivaTTSClient + + +@pytest.fixture +def mock_ten_env(): + """Create a mock TenEnv for testing""" + env = Mock() + env.log_info = Mock() + env.log_debug = Mock() + env.log_warn = Mock() + env.log_error = Mock() + return env + + +@pytest.fixture +def valid_config(): + """Create a valid configuration for testing""" + return NvidiaRivaTTSConfig( + params={ + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000, + "use_ssl": False, + } + ) + + +class TestNvidiaRivaTTSExtension: + """Test cases for NvidiaRivaTTSExtension""" + + def test_extension_initialization(self): + """Test extension can be initialized""" + extension = NvidiaRivaTTSExtension("test_extension") + assert extension is not None + assert extension.vendor() == "nvidia_riva" + + @pytest.mark.asyncio + async def test_create_config(self): + """Test configuration creation from JSON""" + extension = 
NvidiaRivaTTSExtension("test_extension") + config_json = """{ + "params": { + "server": "localhost:50051", + "language_code": "en-US", + "voice_name": "English-US.Female-1", + "sample_rate": 16000 + } + }""" + + config = await extension.create_config(config_json) + assert isinstance(config, NvidiaRivaTTSConfig) + assert config.params["server"] == "localhost:50051" + assert config.params["language_code"] == "en-US" + + def test_synthesize_audio_sample_rate(self, valid_config): + """Test sample rate retrieval""" + extension = NvidiaRivaTTSExtension("test_extension") + extension.config = valid_config + + sample_rate = extension.synthesize_audio_sample_rate() + assert sample_rate == 16000 + + +class TestNvidiaRivaTTSClient: + """Test cases for NvidiaRivaTTSClient""" + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + def test_client_initialization(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test client initialization""" + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + assert client is not None + assert client.config == valid_config + mock_auth.assert_called_once() + mock_service.assert_called_once() + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + @pytest.mark.asyncio + async def test_cancel(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test cancellation""" + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + await client.cancel() + assert client._is_cancelled is True + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + @pytest.mark.asyncio + async def test_synthesize_empty_text(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test synthesis with empty text""" + client = 
NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + + # Should return without yielding anything + result = [chunk async for chunk in client.synthesize("", "test_request")] + assert len(result) == 0 + + @patch('nvidia_riva_tts_python.riva_tts.riva.client.Auth') + @patch('nvidia_riva_tts_python.riva_tts.riva.client.SpeechSynthesisService') + @pytest.mark.asyncio + async def test_synthesize_with_text(self, mock_service, mock_auth, valid_config, mock_ten_env): + """Test synthesis with valid text""" + # Mock the service response + mock_response = Mock() + mock_response.audio = b'\x00\x01' * 100 # Mock audio data + + mock_service_instance = Mock() + mock_service_instance.synthesize_online = Mock(return_value=[mock_response]) + mock_service.return_value = mock_service_instance + + client = NvidiaRivaTTSClient(config=valid_config, ten_env=mock_ten_env) + client.tts_service = mock_service_instance + + # Synthesize text + chunks = [chunk async for chunk in client.synthesize("Hello world", "test_request")] + + assert len(chunks) > 0 + assert isinstance(chunks[0], bytes) + mock_service_instance.synthesize_online.assert_called_once() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) +
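A note on testing mid-stream cancellation: the built-in `enumerate()` does not accept an async generator, so a test that cancels after N chunks has to count the chunks itself inside an `async for` loop. The sketch below illustrates that pattern under stated assumptions. `FakeStreamingClient` and `collect_until_cancelled` are hypothetical names invented for this example; the fake only mimics the `_is_cancelled` flag and the `synthesize()`/`cancel()` shape used in the tests above, and it is not the real `NvidiaRivaTTSClient`.

```python
import asyncio
from typing import AsyncIterator, List


class FakeStreamingClient:
    """Hypothetical stand-in for a streaming TTS client (not the real class)."""

    def __init__(self, total_chunks: int = 10) -> None:
        self._is_cancelled = False
        self._total_chunks = total_chunks

    async def cancel(self) -> None:
        # Assumption: cancellation is a flag checked between chunks.
        self._is_cancelled = True

    async def synthesize(self, text: str, request_id: str) -> AsyncIterator[bytes]:
        for _ in range(self._total_chunks):
            if self._is_cancelled:
                return  # Stop the stream once cancellation was requested.
            yield b"\x00\x01" * 100
            await asyncio.sleep(0)  # Give other tasks a chance to run.


async def collect_until_cancelled(
    client: FakeStreamingClient, cancel_after: int
) -> List[bytes]:
    """Consume the stream and request cancellation after `cancel_after` chunks."""
    chunks: List[bytes] = []
    # enumerate() cannot wrap an async generator, so count via the list length.
    async for chunk in client.synthesize("Long text", "req1"):
        chunks.append(chunk)
        if len(chunks) == cancel_after:
            await client.cancel()
    return chunks


chunks = asyncio.run(collect_until_cancelled(FakeStreamingClient(), cancel_after=3))
```

With this fake, the stream ends as soon as the flag is observed, so an exact count can be asserted. Against a real client that buffers chunks, a looser bound such as `len(chunks) < 10` may be the safer assertion.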