diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/README.md b/ai_agents/agents/ten_packages/extension/camb_tts_python/README.md new file mode 100644 index 0000000000..f10c8dc27a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/README.md @@ -0,0 +1,57 @@ +# camb_tts_python + +Camb.ai TTS extension for TEN Framework using the MARS-8 text-to-speech API. + +## Features + +- MARS-8 model family (mars-8, mars-8-flash, mars-8-instruct) +- 140+ languages supported +- Voice cloning capabilities +- Real-time HTTP streaming +- High-quality 24kHz audio output + +## API + +Refer to `api` definition in [manifest.json](manifest.json) and default values in [property.json](property.json). + +### Configuration Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| api_key | string | (required) | Camb.ai API key | +| voice_id | int32 | 2681 | Voice ID (default: Attic voice) | +| language | string | "en-us" | Language code (BCP-47 format) | +| speech_model | string | "mars-8-flash" | Model selection | +| speed | float64 | 1.0 | Speech speed multiplier | +| format | string | "pcm_s16le" | Output format | +| endpoint | string | (optional) | API endpoint override | + +### Available Models + +- `mars-8` - Default balanced model +- `mars-8-flash` - Faster inference (recommended) +- `mars-8-instruct` - Supports user instructions + +## Development + +### Setup + +1. Get your API key from [Camb.ai](https://camb.ai) +2. Set environment variable: + ```bash + export CAMB_API_KEY=your_key_here + ``` + +### Build + +Follow the standard TEN Framework extension build process. + +### Unit test + +Run tests using the standard TEN Framework testing approach. 
+
+## Resources
+
+- [Camb.ai API Documentation](https://camb.mintlify.app/)
+- [Getting Started](https://camb.mintlify.app/getting-started)
+- [API Reference](https://camb.mintlify.app/api-reference/endpoint/create-tts-stream)
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/__init__.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/__init__.py
new file mode 100644
index 0000000000..72593ab225
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/__init__.py
@@ -0,0 +1,6 @@
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
+from . import addon
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/addon.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/addon.py
new file mode 100644
index 0000000000..122d30c003
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/addon.py
@@ -0,0 +1,25 @@
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
+from ten_runtime import (
+    Addon,
+    register_addon_as_extension,
+    TenEnv,
+)
+
+
+@register_addon_as_extension("camb_tts_python")
+class CambTTSExtensionAddon(Addon):
+    """Addon registered as "camb_tts_python" that builds the TTS extension."""
+
+    def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None:
+        """Create a ``CambTTSExtension`` called ``name`` and report it done."""
+        # Imported here rather than at module level, so the extension module
+        # is loaded only when an instance is actually requested.
+        from .extension import CambTTSExtension
+
+        ten_env.log_info("CambTTSExtensionAddon on_create_instance")
+        extension = CambTTSExtension(name)
+        ten_env.on_create_instance_done(extension, context)
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/camb_tts.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/camb_tts.py
new file mode 100644
index 0000000000..daa9a16d5c
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/camb_tts.py
@@ -0,0 +1,215 @@
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
+from typing import Any, AsyncIterator, Tuple
+from httpx import AsyncClient, Timeout, Limits
+
+from .config import CambTTSConfig
+from ten_runtime import AsyncTenEnv
+from ten_ai_base.const import LOG_CATEGORY_VENDOR
+from ten_ai_base.struct import TTS2HttpResponseEventType
+from ten_ai_base.tts2_http import AsyncTTS2HttpClient
+
+
+BYTES_PER_SAMPLE = 2
+NUMBER_OF_CHANNELS = 1
+SAMPLE_RATE = 24000
+
+
+class CambTTSClient(AsyncTTS2HttpClient):
+    """Streaming HTTP client for the Camb.ai ``tts-stream`` endpoint.
+
+    ``get`` yields ``(payload, event)`` tuples: audio bytes with RESPONSE,
+    ``None`` with END or FLUSH, or a UTF-8 encoded message with an error event.
+    """
+
+    def __init__(
+        self,
+        config: CambTTSConfig,
+        ten_env: AsyncTenEnv,
+    ):
+        super().__init__()
+        self.config = config
+        self.api_key = config.params.get("api_key", "")
+        self.ten_env: AsyncTenEnv = ten_env
+        self._is_cancelled = False
+        # `or` also replaces an empty/None endpoint override with the default.
+        self.endpoint = (
+            config.params.get("endpoint")
+            or "https://client.camb.ai/apis/tts-stream"
+        )
+        self.headers = {
+            "x-api-key": self.api_key,
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        # Camb.ai TTS requires longer timeout (minimum 60s recommended)
+        self.client = AsyncClient(
+            timeout=Timeout(timeout=60.0),
+            limits=Limits(
+                max_connections=100,
+                max_keepalive_connections=20,
+                keepalive_expiry=600.0,  # 10 minutes keepalive
+            ),
+            http2=True,  # Enable HTTP/2 if server supports it
+        )
+
+    async def cancel(self):
+        """Flag the in-flight ``get`` stream to stop at its next chunk."""
+        self.ten_env.log_debug("CambTTS: cancel() called.")
+        self._is_cancelled = True
+
+    def _error_event(
+        self,
+        message: str,
+        request_id: str,
+        event: TTS2HttpResponseEventType,
+    ) -> Tuple[bytes, TTS2HttpResponseEventType]:
+        """Log a vendor error and build the tuple ``get`` yields for it."""
+        self.ten_env.log_error(
+            f"CambTTS: {message} for request_id: {request_id}.",
+            category=LOG_CATEGORY_VENDOR,
+        )
+        return message.encode("utf-8"), event
+
+    async def get(
+        self, text: str, request_id: str
+    ) -> AsyncIterator[Tuple[bytes | None, TTS2HttpResponseEventType]]:
+        """Stream synthesized audio for one request.
+
+        Yields audio chunks as RESPONSE events, then a single END (or FLUSH
+        when ``cancel`` was called). Blank or too-short text ends immediately;
+        HTTP and transport failures are yielded as error events, not raised.
+
+        Raises:
+            RuntimeError: if the HTTP client is not initialized.
+        """
+        self._is_cancelled = False
+        if not self.client:
+            self.ten_env.log_error(
+                f"CambTTS: client not initialized for request_id: {request_id}.",
+                category=LOG_CATEGORY_VENDOR,
+            )
+            raise RuntimeError(
+                f"CambTTS: client not initialized for request_id: {request_id}."
+            )
+
+        # Camb.ai requires 3-3000 characters; blank text is rejected first.
+        text_len = len(text.strip())
+        if text_len == 0:
+            self.ten_env.log_warn(
+                f"CambTTS: empty text for request_id: {request_id}.",
+                category=LOG_CATEGORY_VENDOR,
+            )
+            yield None, TTS2HttpResponseEventType.END
+            return
+
+        if text_len < 3:
+            self.ten_env.log_warn(
+                f"CambTTS: text too short ({text_len} chars, min 3) for request_id: {request_id}.",
+                category=LOG_CATEGORY_VENDOR,
+            )
+            yield None, TTS2HttpResponseEventType.END
+            return
+
+        # Measure the raw text: it is what gets sent, and surrounding
+        # whitespace can push it past the limit even when the stripped text fits.
+        if len(text) > 3000:
+            self.ten_env.log_warn(
+                f"CambTTS: text too long ({len(text)} chars, max 3000), truncating for request_id: {request_id}.",
+                category=LOG_CATEGORY_VENDOR,
+            )
+            text = text[:3000]
+
+        try:
+            # Build payload with Camb.ai's nested structure
+            payload = {
+                "text": text,
+                "voice_id": self.config.params.get("voice_id", 2681),
+                "language": self.config.params.get("language", "en-us"),
+                "speech_model": self.config.params.get("speech_model", "mars-8-flash"),
+                "output_configuration": {
+                    "format": self.config.params.get("format", "pcm_s16le"),
+                },
+                "voice_settings": {
+                    "speed": self.config.params.get("speed", 1.0),
+                },
+            }
+
+            async with self.client.stream(
+                "POST",
+                self.endpoint,
+                headers=self.headers,
+                json=payload,
+            ) as response:
+                # Check for HTTP errors before streaming
+                status = response.status_code
+                if status >= 400:
+                    event = TTS2HttpResponseEventType.ERROR
+                    if status == 401:
+                        event = TTS2HttpResponseEventType.INVALID_KEY_ERROR
+                        error_message = "Invalid Camb.ai API key. Set CAMB_API_KEY environment variable with your API key from https://camb.ai"
+                    elif status == 403:
+                        voice_id = self.config.params.get("voice_id", 2681)
+                        error_message = f"Voice ID {voice_id} is not accessible with your API key. Use list_voices() to see available voices."
+                    elif status == 429:
+                        error_message = "Rate limit exceeded. Please wait before making more requests."
+                    else:
+                        error_body = await response.aread()
+                        error_message = f"API Error {status}: {error_body.decode('utf-8', errors='replace')}"
+                    yield self._error_event(error_message, request_id, event)
+                    return
+
+                async for chunk in response.aiter_bytes(chunk_size=8192):
+                    if self._is_cancelled:
+                        self.ten_env.log_debug(
+                            f"Cancellation flag detected, sending flush event and stopping TTS stream of request_id: {request_id}."
+                        )
+                        yield None, TTS2HttpResponseEventType.FLUSH
+                        break
+
+                    if not chunk:
+                        # Not end-of-stream: END is sent once, after the loop.
+                        continue
+
+                    self.ten_env.log_debug(
+                        f"CambTTS: sending EVENT_TTS_RESPONSE, length: {len(chunk)} of request_id: {request_id}."
+                    )
+                    yield bytes(chunk), TTS2HttpResponseEventType.RESPONSE
+
+                if not self._is_cancelled:
+                    self.ten_env.log_debug(
+                        f"CambTTS: sending EVENT_TTS_END of request_id: {request_id}."
+                    )
+                    yield None, TTS2HttpResponseEventType.END
+
+        except Exception as e:
+            # Some exceptions stringify to ""; fall back to repr so the log
+            # and the error event still name the failure.
+            error_message = str(e) or repr(e)
+            self.ten_env.log_error(
+                f"vendor_error: {error_message} of request_id: {request_id}.",
+                category=LOG_CATEGORY_VENDOR,
+            )
+            if "401" in error_message:
+                yield error_message.encode(
+                    "utf-8"
+                ), TTS2HttpResponseEventType.INVALID_KEY_ERROR
+            else:
+                yield error_message.encode(
+                    "utf-8"
+                ), TTS2HttpResponseEventType.ERROR
+
+    async def clean(self):
+        """Close the HTTP client and release its pooled connections."""
+        self.ten_env.log_debug("CambTTS: clean() called.")
+        await self.client.aclose()
+
+    def get_extra_metadata(self) -> dict[str, Any]:
+        """Return extra metadata for TTFB metrics."""
+        return {
+            "voice_id": self.config.params.get("voice_id", 2681),
+            "speech_model": self.config.params.get("speech_model", "mars-8-flash"),
+        }
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/config.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/config.py
new file mode 100644
index 0000000000..1f8b2a9d99
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/config.py
@@ -0,0 +1,57 @@
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
+from typing import Any
+import copy
+from pathlib import Path
+from ten_ai_base import utils
+from ten_ai_base.tts2_http import AsyncTTS2HttpConfig
+
+from pydantic import Field
+
+
+class CambTTSConfig(AsyncTTS2HttpConfig):
+    """Camb.ai TTS Config"""
+
+    # Debug and logging
+    dump: bool = Field(default=False, description="Camb TTS dump")
+    dump_path: str = Field(
+        default_factory=lambda: str(Path(__file__).parent / "camb_tts_in.pcm"),
+        description="Camb TTS dump path",
+    )
+    params: dict[str, Any] = Field(
+        default_factory=dict, description="Camb TTS params"
+    )
+
+    def update_params(self) -> None:
+        """Drop per-request keys that must not persist in ``params``.
+
+        "endpoint" is kept on purpose: the client reads its URL override from
+        ``params``, so deleting it here would discard that override.
+        """
+        # "text" is supplied per request, never through configuration.
+        for key in ("text",):
+            self.params.pop(key, None)
+
+    def to_str(self, sensitive_handling: bool = True) -> str:
+        """Convert config to string with optional sensitive data handling."""
+        if not sensitive_handling:
+            return f"{self}"
+
+        config = copy.deepcopy(self)
+
+        # Encrypt sensitive fields in params
+        if config.params and "api_key" in config.params:
+            config.params["api_key"] = utils.encrypt(config.params["api_key"])
+
+        return f"{config}"
+
+    def validate(self) -> None:
+        """Validate Camb-specific configuration.
+
+        Raises ValueError when ``params["api_key"]`` is missing or empty.
+        """
+        if not self.params.get("api_key"):
+            raise ValueError("API key is required for Camb TTS")
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/extension.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/extension.py
new file mode 100644
index 0000000000..c9906df16d
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/extension.py
@@ -0,0 +1,61 @@
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
+"""
+Camb.ai TTS Extension
+
+This extension implements text-to-speech using the Camb.ai MARS-8 TTS API.
+It extends the AsyncTTS2HttpExtension for HTTP-based TTS services.
+"""
+
+from ten_ai_base.tts2_http import (
+    AsyncTTS2HttpExtension,
+    AsyncTTS2HttpConfig,
+    AsyncTTS2HttpClient,
+)
+from ten_runtime import AsyncTenEnv
+
+from .config import CambTTSConfig
+from .camb_tts import CambTTSClient, SAMPLE_RATE
+
+
+class CambTTSExtension(AsyncTTS2HttpExtension):
+    """
+    Camb.ai TTS Extension implementation.
+
+    Provides text-to-speech synthesis using Camb.ai's MARS-8 HTTP API.
+    Inherits all common HTTP TTS functionality from AsyncTTS2HttpExtension.
+    """
+
+    def __init__(self, name: str) -> None:
+        super().__init__(name)
+        # Type hints for better IDE support; both stay None until assigned.
+        self.config: CambTTSConfig | None = None
+        self.client: CambTTSClient | None = None
+
+    # ============================================================
+    # Required method implementations
+    # ============================================================
+
+    async def create_config(self, config_json_str: str) -> AsyncTTS2HttpConfig:
+        """Create Camb TTS configuration from JSON string."""
+        return CambTTSConfig.model_validate_json(config_json_str)
+
+    async def create_client(
+        self, config: AsyncTTS2HttpConfig, ten_env: AsyncTenEnv
+    ) -> AsyncTTS2HttpClient:
+        """Create Camb TTS client."""
+        return CambTTSClient(config=config, ten_env=ten_env)
+
+    def vendor(self) -> str:
+        """Return vendor name."""
+        return "camb"
+
+    def synthesize_audio_sample_rate(self) -> int:
+        """Return the sample rate for synthesized audio.
+
+        Camb.ai outputs 24kHz audio; the value is shared with the client module.
+        """
+        return SAMPLE_RATE
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/manifest.json b/ai_agents/agents/ten_packages/extension/camb_tts_python/manifest.json
new file mode 100644
index 0000000000..8e07ac33da
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/manifest.json
@@ -0,0 +1,64 @@
+{
+  "type": "extension",
+  "name": "camb_tts_python",
+  "version": "0.1.0",
+  "dependencies": [
+    {
+      "type": "system",
+      "name": "ten_runtime_python",
+      "version": "0.11"
+    },
+    {
+      "type": "system",
+      "name": "ten_ai_base",
+      "version": "0.7"
+    }
+  ],
+  "package": {
+    "include": [
+      "manifest.json",
+      "property.json",
+      "**.tent",
+      "**.py",
+      "README.md",
+      "requirements.txt"
+    ]
+  },
+  "api": {
+    "interface": [
+      {
+        "import_uri": "../../system/ten_ai_base/api/tts-interface.json"
+      }
+    ],
+    "property": {
+      "properties": {
+        "params": {
+          "type": "object",
+          "properties": {
+            "api_key": {
+              "type": "string"
+            },
+            "voice_id": {
+              "type": "int32"
+            },
+            "language": {
+              "type": "string"
+            },
+            "speech_model": {
+              "type": "string"
+            },
+            "speed": {
+              "type": "float64"
+            },
+            "format": {
+              "type": "string"
+            },
+            "endpoint": {
+              "type": "string"
+            }
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/property.json b/ai_agents/agents/ten_packages/extension/camb_tts_python/property.json
new file mode 100644
index 0000000000..6548168488
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/property.json
@@ -0,0 +1,12 @@
+{
+  "dump": false,
+  "dump_path": "./",
+  "params": {
+    "api_key": "${env:CAMB_API_KEY|}",
+    "voice_id": 2681,
+    "language": "en-us",
+    "speech_model": "mars-8-flash",
+    "speed": 1.0,
+    "format": "pcm_s16le"
+  }
+}
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/requirements.txt b/ai_agents/agents/ten_packages/extension/camb_tts_python/requirements.txt
new file mode 100644
index 0000000000..0654e854dd
--- /dev/null
+++
b/ai_agents/agents/ten_packages/extension/camb_tts_python/requirements.txt
@@ -0,0 +1,2 @@
+httpx[http2]
+pydantic>=2.0.0
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/__init__.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/__init__.py
new file mode 100644
index 0000000000..da402faf43
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/__init__.py
@@ -0,0 +1,5 @@
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/bin/start b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/bin/start
new file mode 100755
index 0000000000..b736ea0de1
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/bin/start
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+set -e
+
+cd "$(dirname "${BASH_SOURCE[0]}")/../.."
+
+export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:.ten/app/ten_packages/system/ten_runtime_python/interface:.ten/app/ten_packages/system/ten_ai_base/interface:$PYTHONPATH
+
+# If the Python app imports some modules that are compiled with a different
+# version of libstdc++ (ex: PyTorch), the Python app may encounter confusing
+# errors. To solve this problem, we can preload the correct version of
+# libstdc++.
+#
+# export LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6
+#
+# Another solution is to make sure the module 'ten_runtime_python' is imported
+# _after_ the module that requires another version of libstdc++ is imported.
+#
+# Refer to https://github.com/pytorch/pytorch/issues/102360?from_wecom=1#issuecomment-1708989096
+
+pytest -s tests/ "$@"
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/conftest.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/conftest.py
new file mode 100644
index 0000000000..9cf1c0353e
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/conftest.py
@@ -0,0 +1,98 @@
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
+import json
+import threading
+from typing_extensions import override
+import pytest
+from ten_runtime import (
+    App,
+    TenEnv,
+)
+
+
+class FakeApp(App):
+    def __init__(self):
+        super().__init__()
+        self.event: threading.Event | None = None
+
+    # In the case of a fake app, we use `on_init` to allow the blocked testing
+    # fixture to continue execution, rather than using `on_configure`. The
+    # reason is that in the TEN runtime C core, the relationship between the
+    # addon manager and the (fake) app is bound after `on_configure_done` is
+    # called. So we only need to let the testing fixture continue execution
+    # after this action in the TEN runtime C core, and at the upper layer
+    # timing, the earliest point is within the `on_init()` function of the upper
+    # TEN app. Therefore, we release the testing fixture lock within the user
+    # layer's `on_init()` of the TEN app.
+    def on_init(self, ten_env: TenEnv) -> None:
+        assert self.event
+        self.event.set()
+
+        ten_env.on_init_done()
+
+    @override
+    def on_configure(self, ten_env: TenEnv) -> None:
+        ten_env.init_property_from_json(
+            json.dumps(
+                {
+                    "ten": {
+                        "log": {
+                            "handlers": [
+                                {
+                                    "matchers": [{"level": "debug"}],
+                                    "formatter": {
+                                        "type": "plain",
+                                        "colored": True,
+                                    },
+                                    "emitter": {
+                                        "type": "console",
+                                        "config": {"stream": "stdout"},
+                                    },
+                                }
+                            ]
+                        }
+                    }
+                }
+            ),
+        )
+
+        ten_env.on_configure_done()
+
+
+class FakeAppCtx:
+    def __init__(self, event: threading.Event):
+        self.fake_app: FakeApp | None = None
+        self.event = event
+
+
+def run_fake_app(fake_app_ctx: FakeAppCtx):
+    app = FakeApp()
+    app.event = fake_app_ctx.event
+    fake_app_ctx.fake_app = app
+    app.run(False)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def global_setup_and_teardown():
+    event = threading.Event()
+    fake_app_ctx = FakeAppCtx(event)
+
+    fake_app_thread = threading.Thread(
+        target=run_fake_app, args=(fake_app_ctx,)
+    )
+    fake_app_thread.start()
+
+    event.wait()
+
+    assert fake_app_ctx.fake_app is not None
+
+    # Yield control to the test; after the test execution is complete, continue
+    # with the teardown process.
+    yield
+
+    # Teardown part.
+    fake_app_ctx.fake_app.close()
+    fake_app_thread.join()
diff --git a/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/test_basic.py b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/test_basic.py
new file mode 100644
index 0000000000..4ff7a73e07
--- /dev/null
+++ b/ai_agents/agents/ten_packages/extension/camb_tts_python/tests/test_basic.py
@@ -0,0 +1,325 @@
+import sys
+from pathlib import Path
+
+# Add project root to sys.path to allow running tests from this directory
+# The project root is 6 levels up from the parent directory of this file.
+project_root = str(Path(__file__).resolve().parents[6])
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+#
+# This file is part of TEN Framework, an open source project.
+# Licensed under the Apache License, Version 2.0.
+# See the LICENSE file for more information.
+#
+import json
+from unittest.mock import patch, AsyncMock
+import os
+import asyncio
+import filecmp
+import shutil
+import threading
+
+from ten_runtime import (
+    ExtensionTester,
+    TenEnvTester,
+    Data,
+)
+from ten_ai_base.struct import TTSTextInput, TTSFlush, TTS2HttpResponseEventType
+
+
+# ================ test dump file functionality ================
+class ExtensionTesterDump(ExtensionTester):
+    def __init__(self):
+        super().__init__()
+        # Use a fixed path as requested by the user.
+        self.dump_dir = "./dump/"
+        # Use a unique name for the file generated by the test to avoid collision
+        # with the file generated by the extension.
+        self.test_dump_file_path = os.path.join(
+            self.dump_dir, "test_manual_dump.pcm"
+        )
+        self.audio_end_received = False
+        self.received_audio_chunks = []
+
+    def on_start(self, ten_env_tester: TenEnvTester) -> None:
+        """Called when test starts, sends a TTS request."""
+        ten_env_tester.log_info("Dump test started, sending TTS request.")
+
+        tts_input = TTSTextInput(
+            request_id="tts_request_1",
+            text="Hello from Camb AI, this is a test",
+            text_input_end=True,
+        )
+        data = Data.create("tts_text_input")
+        data.set_property_from_json(None, tts_input.model_dump_json())
+        ten_env_tester.send_data(data)
+        ten_env_tester.on_start_done()
+
+    def on_data(self, ten_env: TenEnvTester, data) -> None:
+        name = data.get_name()
+        if name == "tts_audio_end":
+            ten_env.log_info("Received tts_audio_end, stopping test.")
+            self.audio_end_received = True
+            ten_env.stop_test()
+
+    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
+        """Receives audio frames and collects their data using the lock/unlock pattern."""
+        buf = audio_frame.lock_buf()
+        try:
+            copied_data = bytes(buf)
+            self.received_audio_chunks.append(copied_data)
+        finally:
+            audio_frame.unlock_buf(buf)
+
+    def write_test_dump_file(self):
+        """Writes the collected audio chunks to a file."""
+        with open(self.test_dump_file_path, "wb") as f:
+            for chunk in self.received_audio_chunks:
+                f.write(chunk)
+
+    def find_tts_dump_file(self) -> str | None:
+        """Find the dump file created by the TTS extension in the fixed dump directory."""
+        if not os.path.exists(self.dump_dir):
+            return None
+        for filename in os.listdir(self.dump_dir):
+            if filename.endswith(".pcm") and filename != os.path.basename(
+                self.test_dump_file_path
+            ):
+                return os.path.join(self.dump_dir, filename)
+        return None
+
+
+@patch("camb_tts_python.extension.CambTTSClient")
+def test_dump_functionality(MockCambTTSClient):
+    """Tests that the dump file from the TTS extension matches the audio received by the test extension."""
+    print("Starting test_dump_functionality with mock...")
+
+    # --- Directory Setup ---
+    DUMP_PATH = "./dump/"
+
+    # Clean up directory before the test, in case of previous failed runs.
+    if os.path.exists(DUMP_PATH):
+        shutil.rmtree(DUMP_PATH)
+    os.makedirs(DUMP_PATH)
+
+    # --- Mock Configuration ---
+    mock_instance = MockCambTTSClient.return_value
+    mock_instance.clean = AsyncMock()
+
+    # Create some fake audio data to be streamed
+    fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20
+    fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20
+
+    # This async generator simulates the TTS client's get() method
+    async def mock_get_audio_stream(text: str, request_id: str | None = None):
+        yield (fake_audio_chunk_1, TTS2HttpResponseEventType.RESPONSE)
+        await asyncio.sleep(0.01)
+        yield (fake_audio_chunk_2, TTS2HttpResponseEventType.RESPONSE)
+        await asyncio.sleep(0.01)
+        yield (None, TTS2HttpResponseEventType.END)
+
+    mock_instance.get.side_effect = mock_get_audio_stream
+
+    # --- Test Setup ---
+    tester = ExtensionTesterDump()
+
+    dump_config = {
+        "dump": True,
+        "dump_path": DUMP_PATH,
+        "params": {
+            "api_key": "test_api_key",
+        },
+    }
+
+    tester.set_test_mode_single("camb_tts_python", json.dumps(dump_config))
+
+    print("Running dump test...")
+    tester.run()
+    print("Dump test completed.")
+
+    # --- Verification ---
+    assert tester.audio_end_received, "Expected to receive tts_audio_end"
+    assert (
+        len(tester.received_audio_chunks) > 0
+    ), "Expected to receive audio chunks"
+
+    tester.write_test_dump_file()
+
+    tts_dump_file = tester.find_tts_dump_file()
+    assert (
+        tts_dump_file is not None
+    ), f"Expected to find a TTS dump file in {DUMP_PATH}"
+    assert os.path.exists(
+        tts_dump_file
+    ), f"TTS dump file should exist: {tts_dump_file}"
+
+    print(
+        f"Comparing test file {tester.test_dump_file_path} with TTS dump file {tts_dump_file}"
+    )
+    assert filecmp.cmp(
+        tester.test_dump_file_path, tts_dump_file, shallow=False
+    ), "Test dump file and TTS dump file should have the same content"
+
+    print(
+        f"Dump functionality test passed: received {len(tester.received_audio_chunks)} audio chunks"
+    )
+    print(f" Test file: {tester.test_dump_file_path}")
+    print(f" TTS dump file: {tts_dump_file}")
+
+    # --- Cleanup ---
+    if os.path.exists(DUMP_PATH):
+        shutil.rmtree(DUMP_PATH)
+
+
+# ================ test flush logic ================
+class ExtensionTesterFlush(ExtensionTester):
+    def __init__(self):
+        super().__init__()
+        self.ten_env: TenEnvTester | None = None
+        self.audio_start_received = False
+        self.first_audio_frame_received = False
+        self.flush_start_received = False
+        self.audio_end_received = False
+        self.flush_end_received = False
+        self.audio_end_reason = ""
+        self.total_audio_duration_from_event = 0
+        self.received_audio_bytes = 0
+        self.sample_rate = 24000  # Camb TTS sample rate
+        self.bytes_per_sample = 2  # 16-bit
+        self.channels = 1
+        self.audio_received_after_flush_end = False
+
+    def on_start(self, ten_env_tester: TenEnvTester) -> None:
+        self.ten_env = ten_env_tester
+        ten_env_tester.log_info("Flush test started, sending long TTS request.")
+        tts_input = TTSTextInput(
+            request_id="tts_request_for_flush",
+            text="This is a very long text designed to generate a continuous stream of audio, providing enough time to send a flush command.",
+        )
+        data = Data.create("tts_text_input")
+        data.set_property_from_json(None, tts_input.model_dump_json())
+        ten_env_tester.send_data(data)
+        ten_env_tester.on_start_done()
+
+    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
+        if self.flush_end_received:
+            ten_env.log_error("Received audio frame after tts_flush_end!")
+            self.audio_received_after_flush_end = True
+
+        if not self.first_audio_frame_received:
+            self.first_audio_frame_received = True
+            ten_env.log_info("First audio frame received, sending flush data.")
+            flush_data = Data.create("tts_flush")
+            flush_data.set_property_from_json(
+                None,
+                TTSFlush(flush_id="tts_request_for_flush").model_dump_json(),
+            )
+            ten_env.send_data(flush_data)
+
+        buf = audio_frame.lock_buf()
+        try:
+            self.received_audio_bytes += len(buf)
+        finally:
+            audio_frame.unlock_buf(buf)
+
+    def on_data(self, ten_env: TenEnvTester, data) -> None:
+        name = data.get_name()
+        ten_env.log_info(f"on_data name: {name}")
+
+        if name == "tts_audio_start":
+            self.audio_start_received = True
+            return
+
+        json_str, _ = data.get_property_to_json(None)
+        if not json_str:
+            return
+        payload = json.loads(json_str)
+        ten_env.log_info(f"on_data payload: {payload}")
+
+        if name == "tts_flush_start":
+            self.flush_start_received = True
+            return
+
+        if name == "tts_audio_end":
+            self.audio_end_received = True
+            self.audio_end_reason = payload.get("reason")
+            self.total_audio_duration_from_event = payload.get(
+                "request_total_audio_duration_ms", 0
+            )
+
+        elif name == "tts_flush_end":
+            self.flush_end_received = True
+
+            def stop_test_later():
+                ten_env.log_info("Waited after flush_end, stopping test now.")
+                ten_env.stop_test()
+
+            timer = threading.Timer(0.5, stop_test_later)
+            timer.daemon = True  # never keep the interpreter alive on its own
+            timer.start()
+
+    def get_calculated_audio_duration_ms(self) -> int:
+        duration_sec = self.received_audio_bytes / (
+            self.sample_rate * self.bytes_per_sample * self.channels
+        )
+        return int(duration_sec * 1000)
+
+
+@patch("camb_tts_python.extension.CambTTSClient")
+def test_flush_logic(MockCambTTSClient):
+    """
+    Tests that sending a flush command during TTS streaming correctly stops
+    the audio and sends the appropriate events.
+    """
+    print("Starting test_flush_logic with mock...")
+
+    mock_instance = MockCambTTSClient.return_value
+    mock_instance.clean = AsyncMock()
+    mock_instance.cancel = AsyncMock()
+
+    async def mock_get_long_audio_stream(
+        text: str, request_id: str | None = None
+    ):
+        for _ in range(20):
+            if mock_instance.cancel.called:
+                print("Mock detected cancel call, sending EVENT_TTS_FLUSH.")
+                yield (None, TTS2HttpResponseEventType.FLUSH)
+                return
+            yield (b"\x11\x22\x33" * 100, TTS2HttpResponseEventType.RESPONSE)
+            await asyncio.sleep(0.1)
+
+        yield (None, TTS2HttpResponseEventType.END)
+
+    mock_instance.get.side_effect = mock_get_long_audio_stream
+
+    config = {
+        "params": {
+            "api_key": "test_api_key",
+        },
+    }
+    tester = ExtensionTesterFlush()
+    tester.set_test_mode_single("camb_tts_python", json.dumps(config))
+
+    print("Running flush logic test...")
+    tester.run()
+    print("Flush logic test completed.")
+
+    assert tester.audio_start_received, "Did not receive tts_audio_start."
+    assert tester.first_audio_frame_received, "Did not receive any audio frame."
+    assert tester.audio_end_received, "Did not receive tts_audio_end."
+    assert tester.flush_end_received, "Did not receive tts_flush_end."
+    assert (
+        not tester.audio_received_after_flush_end
+    ), "Received audio after tts_flush_end."
+
+    calculated_duration = tester.get_calculated_audio_duration_ms()
+    event_duration = tester.total_audio_duration_from_event
+    print(
+        f"calculated_duration: {calculated_duration}, event_duration: {event_duration}"
+    )
+    assert (
+        abs(calculated_duration - event_duration) < 10
+    ), f"Mismatch in audio duration. Calculated: {calculated_duration}ms, From event: {event_duration}ms"
+
+    print("Flush logic test passed successfully.")