diff --git a/docs/concepts/multimodal/text_to_speech.md b/docs/concepts/multimodal/text_to_speech.md
new file mode 100644
index 00000000..5416d5e9
--- /dev/null
+++ b/docs/concepts/multimodal/text_to_speech.md
@@ -0,0 +1,157 @@
+# Text to Speech Data Generation
+
+This module introduces support for multimodal data generation pipelines that accept **text** as input and produce **audio outputs** using text-to-speech (TTS) models. It expands traditional text-only pipelines to support audio generation tasks such as audiobook creation, voice narration, and multi-voice dialogue generation.
+
+## Key Features
+
+- Supports **text-to-audio** generation using OpenAI TTS models.
+- Converts text inputs into **base64-encoded audio data URLs** compatible with standard audio formats.
+- Compatible with HuggingFace datasets, streaming, and on-disk formats.
+- Supports multiple voice options and audio formats.
+- Variable **speed control** (0.25x to 4.0x).
+- Automatic handling of multimodal outputs with file-saving capabilities.
+
+## Supported Models
+
+**Currently, we only support OpenAI TTS models:**
+
+- `tts-1` - Standard quality, optimized for speed
+- `tts-1-hd` - High-definition quality, optimized for quality
+- `gpt-4o-mini-tts` - OpenAI's newest and most reliable text-to-speech model
+
+All three models support every voice option and audio format listed below.
+
+## Input Requirements
+
+### Text Input
+
+Each text field to be converted to speech must:
+
+- Be a string containing the text to synthesize
+- Not exceed **4096 characters** (the OpenAI TTS limit)
+- Be specified in the model configuration
+- Come from a local dataset or a HuggingFace dataset
+
+### Voice Options
+You can choose from the following voices: https://platform.openai.com/docs/guides/text-to-speech#voice-options
+
+### Audio Formats
+You can choose from the following audio formats: https://platform.openai.com/docs/guides/text-to-speech#supported-output-formats
+
+### Supported Languages
+The TTS models support a wide range of languages; see the full list here: https://platform.openai.com/docs/guides/text-to-speech#supported-languages
+
+## How Text-to-Speech Generation Works
+
+1. Text input is extracted from the specified field in each record.
+2. The TTS model generates audio from the text.
+3. Audio is returned as a **base64-encoded data URL** (e.g., `data:audio/mpeg;base64,...`).
+4. The data URL is converted to a file and saved to disk (a simplified sketch of this step is shown below).
+5. The output JSON/JSONL contains the absolute path to the saved audio file.
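+
+To make steps 3 and 4 concrete, here is a minimal, illustrative sketch of how a base64 audio data URL can be decoded and written to disk. The pipeline performs this conversion automatically through its multimodal processing utilities; the helper name and paths below are assumptions for illustration only.
+
+```python
+import base64
+import re
+from pathlib import Path
+
+
+def data_url_to_file(data_url: str, output_stem: Path) -> Path:
+    """Decode a `data:audio/...;base64,...` URL and write the audio bytes to disk."""
+    match = re.match(r"^data:([^;]+);base64,(.+)$", data_url)
+    if not match:
+        raise ValueError("Not a base64-encoded audio data URL")
+    mime_type, payload = match.groups()
+    # Derive a file extension from the MIME type, e.g. audio/mpeg -> mp3
+    extension = {"audio/mpeg": "mp3", "audio/wav": "wav"}.get(mime_type, mime_type.split("/")[-1])
+    target = output_stem.with_suffix(f".{extension}")
+    target.write_bytes(base64.b64decode(payload))
+    return target
+```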
+
+## Model Configuration
+
+The model configuration for TTS generation must specify `output_type: audio` and include TTS-specific parameters:
+
+```yaml
+tts_openai:
+  model: tts
+  output_type: audio
+  model_type: azure_openai
+  api_version: 2025-03-01-preview
+  parameters:
+    voice: "alloy"
+    response_format: "wav"
+```
+
+## Example Configuration: Audiobook Generation
+
+```yaml
+data_config:
+  source:
+    type: "disk"
+    file_path: "data/chapters.json"
+
+graph_config:
+  nodes:
+    generate_chapter_audio:
+      node_type: llm
+      output_keys: audio
+      prompt:
+        - user: |
+            "{chapter_text}"
+      model:
+        parameters:
+          voice: nova
+          response_format: mp3
+          speed: 1.0
+
+  edges:
+    - from: START
+      to: generate_chapter_audio
+    - from: generate_chapter_audio
+      to: END
+
+output_config:
+  output_map:
+    id:
+      from: "id"
+    chapter_number:
+      from: "chapter_number"
+    chapter_text:
+      from: "chapter_text"
+    audio:
+      from: "audio"
+```
+
+### Input Data (`data/chapters.json`)
+
+```json
+[
+  {
+    "id": "1",
+    "chapter_number": 1,
+    "chapter_text": "Chapter One: The Beginning. It was a dark and stormy night..."
+  },
+  {
+    "id": "2",
+    "chapter_number": 2,
+    "chapter_text": "Chapter Two: The Journey. The next morning brought clear skies..."
+  }
+]
+```
+
+### Output
+
+```json
+[
+  {
+    "id": "1",
+    "chapter_number": 1,
+    "chapter_text": "Chapter One: The Beginning. It was a dark and stormy night...",
+    "audio": "/path/to/multimodal_output/audio/1_audio_0.mp3"
+  },
+  {
+    "id": "2",
+    "chapter_number": 2,
+    "chapter_text": "Chapter Two: The Journey. The next morning brought clear skies...",
+    "audio": "/path/to/multimodal_output/audio/2_audio_0.mp3"
+  }
+]
+```
+
+---
+
+## Notes
+
+- **Text-to-speech generation is currently only supported for OpenAI TTS models.** Support for additional providers may be added in future releases.
+- The `output_type` field in the model configuration must be set to `audio` to enable TTS generation.
+- Audio files are automatically saved and managed.
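+
+As a quick sanity check, the records in the output file can be loaded and each referenced audio file verified on disk. This is a minimal sketch assuming the records shown above were written to `output.json`; the file name is an assumption.
+
+```python
+import json
+from pathlib import Path
+
+# Load the generated records and confirm each referenced audio file exists.
+records = json.loads(Path("output.json").read_text())
+for record in records:
+    audio_path = Path(record["audio"])
+    assert audio_path.exists(), f"Missing audio file: {audio_path}"
+    print(record["chapter_number"], audio_path.name, audio_path.stat().st_size, "bytes")
+```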
+
+---
+
+## See Also
+
+- [Audio to Text](./audio_to_text.md) - For speech recognition and audio transcription
+- [Image to Text](./image_to_text.md) - For vision-based multimodal pipelines
+- [OpenAI TTS Documentation](https://platform.openai.com/docs/guides/text-to-speech) - Official OpenAI TTS API reference
diff --git a/mkdocs.yml b/mkdocs.yml
index 597f99b5..230386dc 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -36,6 +36,7 @@ nav:
   - Multimodal:
     - Audio to Text: concepts/multimodal/audio_to_text.md
    - Image to Text: concepts/multimodal/image_to_text.md
+    - Text to Speech: concepts/multimodal/text_to_speech.md
   - Nodes:
     - Agent Node: concepts/nodes/agent_node.md
     - Lambda Node: concepts/nodes/lambda_node.md
diff --git a/sygra/config/models.yaml b/sygra/config/models.yaml
index 5cc5cc61..7cd5fb27 100644
--- a/sygra/config/models.yaml
+++ b/sygra/config/models.yaml
@@ -78,3 +78,14 @@ qwen3_1.7b:
   post_process: sygra.core.models.model_postprocessor.RemoveThinkData
   parameters:
     temperature: 0.8
+
+# TTS OpenAI model
+tts_openai:
+  model: tts
+  output_type: audio  # This triggers TTS functionality
+  model_type: azure_openai  # Use the azure_openai or openai model type
+  api_version: 2025-03-01-preview
+  # URL and api_key should be defined in the .env file as SYGRA_TTS_OPENAI_URL and SYGRA_TTS_OPENAI_TOKEN
+  parameters:
+    voice: "alloy"
+    response_format: "wav"
diff --git a/sygra/core/dataset/dataset_processor.py b/sygra/core/dataset/dataset_processor.py
index c37275b1..6aa2ad78 100644
--- a/sygra/core/dataset/dataset_processor.py
+++ b/sygra/core/dataset/dataset_processor.py
@@ -3,6 +3,7 @@
 import signal
 import time
 import uuid
+from pathlib import Path
 from typing import Any, Callable, Optional, Union, cast
 
 import datasets  # type: ignore[import-untyped]
@@ -13,7 +14,7 @@
 from sygra.core.resumable_execution import ResumableExecutionManager
 from sygra.data_mapper.mapper import DataMapper
 from sygra.logger.logger_config import logger
-from sygra.utils import constants, graph_utils, utils
+from sygra.utils import constants, graph_utils, multimodal_processor, utils
 from sygra.validators.schema_validator_base import SchemaValidator
 
@@ -286,6 +287,17 @@ async def _write_checkpoint(self, is_oasst_mapper_required: bool) -> None:
             self.graph_results, self.output_record_generator
         )
 
+        # Process multimodal data: save base64 data URLs to files and replace with file paths
+        try:
+            multimodal_output_dir = ".".join(self.output_file.split(".")[:-1])
+            output_records = multimodal_processor.process_batch_multimodal_data(
+                output_records, Path(multimodal_output_dir)
+            )
+        except Exception as e:
+            logger.warning(
+                f"Failed to process multimodal data: {e}. Continuing with original records."
+            )
+
         # Handle intermediate writing if needed
         if (
             is_oasst_mapper_required
diff --git a/sygra/core/models/client/openai_azure_client.py b/sygra/core/models/client/openai_azure_client.py
index d13d5182..23d1ddfb 100644
--- a/sygra/core/models/client/openai_azure_client.py
+++ b/sygra/core/models/client/openai_azure_client.py
@@ -160,3 +160,41 @@ def send_request(
             return client.chat.completions.create(**payload, model=model_name, **generation_params)
         else:
             return client.completions.create(**payload, model=model_name, **generation_params)
+
+    async def create_speech(
+        self,
+        model: str,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ) -> Any:
+        """
+        Create speech audio from text using Azure OpenAI's text-to-speech API.
+
+        Args:
+            model (str): The TTS model deployment name (e.g., 'tts-1', 'tts-1-hd')
+            input (str): The text to convert to speech
+            voice (str): The voice to use (e.g., alloy, echo, fable, onyx, nova, shimmer)
+            response_format (str, optional): The audio format (e.g., mp3, opus, aac, flac, wav, pcm). Defaults to 'mp3'
+            speed (float, optional): The speed of the audio (0.25 to 4.0). Defaults to 1.0
+
+        Returns:
+            Any: The audio response from the API
+
+        Raises:
+            ValueError: If async_client is False (TTS requires async client)
+        """
+        if not self.async_client:
+            raise ValueError(
+                "TTS API requires async client. Please initialize with async_client=True"
+            )
+
+        client = cast(Any, self.client)
+        return await client.audio.speech.create(
+            model=model,
+            input=input,
+            voice=voice,
+            response_format=response_format,
+            speed=speed,
+        )
diff --git a/sygra/core/models/client/openai_client.py b/sygra/core/models/client/openai_client.py
index 23e160d5..7a704069 100644
--- a/sygra/core/models/client/openai_client.py
+++ b/sygra/core/models/client/openai_client.py
@@ -182,3 +182,41 @@ def send_request(
             extra_body=additional_params,
             **standard_params,
         )
+
+    async def create_speech(
+        self,
+        model: str,
+        input: str,
+        voice: str,
+        response_format: str = "mp3",
+        speed: float = 1.0,
+    ) -> Any:
+        """
+        Create speech audio from text using OpenAI's text-to-speech API.
+
+        Args:
+            model (str): The TTS model to use (e.g., 'tts-1', 'tts-1-hd')
+            input (str): The text to convert to speech
+            voice (str): The voice to use (e.g., alloy, echo, fable, onyx, nova, shimmer)
+            response_format (str, optional): The audio format (e.g., mp3, opus, aac, flac, wav, pcm). Defaults to 'mp3'
+            speed (float, optional): The speed of the audio (0.25 to 4.0). Defaults to 1.0
+
+        Returns:
+            Any: The audio response from the API
+
+        Raises:
+            ValueError: If async_client is False (TTS requires async client)
+        """
+        if not self.async_client:
+            raise ValueError(
+                "TTS API requires async client. 
Please initialize with async_client=True" + ) + + client = cast(Any, self.client) + return await client.audio.speech.create( + model=model, + input=input, + voice=voice, + response_format=response_format, + speed=speed, + ) diff --git a/sygra/core/models/custom_models.py b/sygra/core/models/custom_models.py index bc22191b..515f8188 100644 --- a/sygra/core/models/custom_models.py +++ b/sygra/core/models/custom_models.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import base64 import collections import json import os @@ -236,7 +237,7 @@ async def _generate_fallback_structured_output( modified_input = ChatPromptValue(messages=modified_messages) # Generate the text with retry (uses our centralized retry logic) - resp_text, resp_status = await self._generate_text_with_retry( + resp_text, resp_status = await self._generate_response_with_retry( modified_input, model_params, **kwargs ) @@ -364,7 +365,7 @@ def _update_model_stats(self, resp_text: str, resp_status: int) -> None: logger.info(f"[{self.name()}] Model Stats: {temp_model_stats}") @abstractmethod - async def _generate_text( + async def _generate_response( self, input: ChatPromptValue, model_params: ModelParams ) -> Tuple[str, int]: pass @@ -380,7 +381,7 @@ def _ping_model(self, url, auth_token) -> int: msg = utils.backend_factory.get_test_message() # build parameters model_param = ModelParams(url=url, auth_token=auth_token) - _, status = asyncio.run(self._generate_text(msg, model_param)) + _, status = asyncio.run(self._generate_response(msg, model_param)) return status def ping(self) -> int: @@ -491,12 +492,12 @@ async def _call_with_retry( logger.info( "Structured output not configured, falling back to regular generation" ) - result = await self._generate_text(input, model_params, **kwargs) + result = await self._generate_response(input, model_params, **kwargs) else: result = so_result else: # Regular text generation - result = await self._generate_text(input, model_params, **kwargs) + result = await self._generate_response(input, model_params, **kwargs) # Apply post-processing if defined post_proc = self._get_post_processor() @@ -514,7 +515,7 @@ async def _call_with_retry( logger.error(f"[{self.name()}] Request failed after {self.retry_attempts} attempts") return result - async def _generate_text_with_retry( + async def _generate_response_with_retry( self, input: ChatPromptValue, model_params: ModelParams, **kwargs: Any ) -> Tuple[str, int]: """ @@ -694,7 +695,7 @@ async def _generate_native_structured_output( input, model_params, pydantic_model, **kwargs ) - async def _generate_text( + async def _generate_response( self, input: ChatPromptValue, model_params: ModelParams ) -> Tuple[str, int]: try: @@ -754,7 +755,7 @@ def __init__(self, model_config: dict[str, Any]) -> None: else: raise ValueError("auth_token must be a string or non-empty list of strings") - async def _generate_text( + async def _generate_response( self, input: ChatPromptValue, model_params: ModelParams ) -> Tuple[str, int]: model_url = model_params.url @@ -797,7 +798,7 @@ class CustomMistralAPI(BaseCustomModel): def __init__(self, model_config: dict[str, Any]) -> None: super().__init__(model_config) - async def _generate_text( + async def _generate_response( self, input: ChatPromptValue, model_params: ModelParams ) -> Tuple[str, int]: ret_code = 200 @@ -922,7 +923,7 @@ async def _generate_native_structured_output( input, model_params, pydantic_model, **kwargs ) - async def _generate_text( + async def _generate_response( self, input: 
ChatPromptValue, model_params: ModelParams
     ) -> Tuple[str, int]:
         ret_code = 200
@@ -967,6 +968,7 @@ async def _generate_text(
 
 class CustomOpenAI(BaseCustomModel):
+
     def __init__(self, model_config: dict[str, Any]) -> None:
         super().__init__(model_config)
         utils.validate_required_keys(
@@ -1056,9 +1058,29 @@ async def _generate_native_structured_output(
             input, model_params, pydantic_model, **kwargs
         )
 
+    async def _generate_response(
+        self, input: ChatPromptValue, model_params: ModelParams
+    ) -> Tuple[str, int]:
+        # Check if this is a TTS request based on output_type
+        if self.model_config.get("output_type") == "audio":
+            return await self._generate_speech(input, model_params)
+        else:
+            return await self._generate_text(input, model_params)
+
     async def _generate_text(
         self, input: ChatPromptValue, model_params: ModelParams
     ) -> Tuple[str, int]:
+        """
+        Generate text using the OpenAI/Azure OpenAI Chat or Completions API.
+
+        This method is called when output_type is 'text' or not specified in the model config.
+
+        Args:
+            input: ChatPromptValue containing the messages for chat completion
+            model_params: Model parameters including URL and auth token
+
+        Returns:
+            Tuple of (response_text, status_code)
+            - On success: returns generated text and 200
+            - On error: returns error message and error code
+        """
         ret_code = 200
         model_url = model_params.url
         try:
@@ -1090,6 +1112,107 @@ async def _generate_text
                 ret_code = rcode if rcode else 999
         return resp_text, ret_code
 
+    async def _generate_speech(
+        self, input: ChatPromptValue, model_params: ModelParams
+    ) -> Tuple[str, int]:
+        """
+        Generate speech from text using the OpenAI/Azure OpenAI TTS API.
+        This method is called when output_type is 'audio' in the model config.
+
+        Args:
+            input: ChatPromptValue containing the text to convert to speech
+            model_params: Model parameters including URL and auth token
+
+        Returns:
+            Tuple of (response_text, status_code)
+            - On success: returns base64 encoded audio and 200
+            - On error: returns error message and error code
+        """
+        ret_code = 200
+        model_url = model_params.url
+
+        try:
+            # Extract text from messages
+            text_to_speak = ""
+            for message in input.messages:
+                if hasattr(message, "content"):
+                    text_to_speak += str(message.content) + " "
+            text_to_speak = text_to_speak.strip()
+
+            if not text_to_speak:
+                logger.error(f"[{self.name()}] No text provided for TTS conversion")
+                return f"{constants.ERROR_PREFIX} No text provided for TTS conversion", 400
+
+            # Validate text length (the OpenAI TTS limit is 4096 characters)
+            if len(text_to_speak) > 4096:
+                logger.warning(
+                    f"[{self.name()}] Text exceeds 4096 character limit: {len(text_to_speak)} characters"
+                )
+
+            # Set up the OpenAI client
+            self._set_client(model_url, model_params.auth_token)
+
+            # Get TTS-specific parameters from generation_params or model_config
+            voice = self.generation_params.get("voice", self.model_config.get("voice", None))
+            response_format = self.generation_params.get(
+                "response_format", self.model_config.get("response_format", "wav")
+            )
+            speed = self.generation_params.get("speed", self.model_config.get("speed", 1.0))
+
+            # Clamp speed to the supported range (0.25 to 4.0)
+            speed = max(0.25, min(4.0, float(speed)))
+
+            logger.debug(
+                f"[{self.name()}] TTS parameters - voice: {voice}, format: {response_format}, speed: {speed}"
+            )
+
+            # Prepare TTS request parameters
+            tts_params = {
+                "input": text_to_speak,
+                "voice": voice,
+                "response_format": response_format,
+                "speed": speed,
+            }
+
+            # Make the TTS API call
+            # Cast to OpenAIClient since BaseClient doesn't have create_speech
+
openai_client = cast(OpenAIClient, self._client) + audio_response = await openai_client.create_speech( + model=str(self.model_config.get("model")), **tts_params + ) + + # Map response format to MIME type + mime_types = { + "mp3": "audio/mpeg", + "opus": "audio/opus", + "aac": "audio/aac", + "flac": "audio/flac", + "wav": "audio/wav", + "pcm": "audio/pcm", + } + mime_type = mime_types.get(response_format, "audio/wav") + + # Create base64 encoded data URL + audio_base64 = base64.b64encode(audio_response.content).decode("utf-8") + data_url = f"data:{mime_type};base64,{audio_base64}" + resp_text = data_url + + except openai.RateLimitError as e: + logger.warning(f"[{self.name()}] OpenAI TTS API request exceeded rate limit: {e}") + resp_text = f"{constants.ERROR_PREFIX} Rate limit exceeded: {e}" + ret_code = 429 + except openai.APIError as e: + logger.error(f"[{self.name()}] OpenAI TTS API error: {e}") + resp_text = f"{constants.ERROR_PREFIX} API error: {e}" + ret_code = getattr(e, "status_code", 500) + except Exception as x: + resp_text = f"{constants.ERROR_PREFIX} TTS request failed: {x}" + logger.error(f"[{self.name()}] {resp_text}") + rcode = self._get_status_from_body(x) + ret_code = rcode if rcode else 999 + + return resp_text, ret_code + class CustomOllama(BaseCustomModel): def __init__(self, model_config: dict[str, Any]) -> None: @@ -1177,7 +1300,7 @@ async def _generate_native_structured_output( input, model_params, pydantic_model, **kwargs ) - async def _generate_text( + async def _generate_response( self, input: ChatPromptValue, model_params: ModelParams ) -> Tuple[str, int]: ret_code = 200 @@ -1314,7 +1437,7 @@ def __init__(self, model_config: dict[str, Any]) -> None: self.model_config = model_config self.auth_token = str(model_config.get("auth_token")).replace("Bearer ", "") - async def _generate_text( + async def _generate_response( self, input: ChatPromptValue, model_params: ModelParams ) -> Tuple[str, int]: ret_code = 200 diff --git a/sygra/utils/audio_utils.py b/sygra/utils/audio_utils.py index bbb3f112..c4486bfe 100644 --- a/sygra/utils/audio_utils.py +++ b/sygra/utils/audio_utils.py @@ -2,7 +2,8 @@ import io import os import re -from typing import Any, Union, cast +from pathlib import Path +from typing import Any, Tuple, Union, cast import numpy as np import requests # type: ignore[import-untyped] @@ -218,3 +219,98 @@ def expand_audio_item(item: dict[str, Any], state: dict[str, Any]) -> list[dict[ else: expanded.append(item) return expanded + + +def parse_audio_data_url(data_url: str) -> Tuple[str, str, bytes]: + """ + Parse an audio data URL and extract MIME type, extension, and decoded content. 
+
+    Args:
+        data_url (str): The data URL string (e.g., "data:audio/wav;base64,...")
+
+    Returns:
+        Tuple[str, str, bytes]: Tuple of (mime_type, file_extension, decoded_bytes)
+
+    Raises:
+        ValueError: If the data URL format is invalid
+    """
+    # Pattern: data:<mime_type>;base64,<data>
+    pattern = r"^data:([^;]+);base64,(.+)$"
+    match = re.match(pattern, data_url)
+
+    if not match:
+        raise ValueError(f"Invalid audio data URL format: {data_url[:50]}...")
+
+    mime_type = match.group(1)
+    base64_data = match.group(2)
+
+    # Decode base64 data
+    try:
+        decoded_bytes = base64.b64decode(base64_data)
+    except Exception as e:
+        raise ValueError(f"Failed to decode base64 data: {e}")
+
+    # Determine file extension from MIME type
+    mime_to_ext = {
+        "audio/mpeg": "mp3",
+        "audio/mp3": "mp3",
+        "audio/opus": "opus",
+        "audio/aac": "aac",
+        "audio/flac": "flac",
+        "audio/wav": "wav",
+        "audio/wave": "wav",
+        "audio/pcm": "pcm",
+        "audio/ogg": "ogg",
+        "audio/m4a": "m4a",
+        "audio/aiff": "aiff",
+    }
+
+    file_extension = mime_to_ext.get(mime_type, mime_type.split("/")[-1])
+
+    return mime_type, file_extension, decoded_bytes
+
+
+def save_audio_data_url(
+    data_url: str, output_dir: Path, record_id: str, field_name: str, index: int = 0
+) -> str:
+    """
+    Save an audio data URL to a file and return the file path.
+
+    Args:
+        data_url (str): The base64 data URL to save
+        output_dir (Path): Directory where the file should be saved
+        record_id (str): ID of the record (for unique filename)
+        field_name (str): Name of the field containing the data
+        index (int): Index if the field contains multiple items (default: 0)
+
+    Returns:
+        str: Absolute path to the saved file
+
+    Raises:
+        ValueError: If the data URL is invalid or saving fails
+    """
+    try:
+        # Parse the data URL
+        mime_type, file_extension, decoded_bytes = parse_audio_data_url(data_url)
+
+        # Create subdirectory for audio
+        audio_dir = output_dir / "audio"
+        audio_dir.mkdir(parents=True, exist_ok=True)
+
+        # Create filename: record_id_fieldname_index.extension
+        filename = f"{record_id}_{field_name}_{index}.{file_extension}"
+        file_path = audio_dir / filename
+
+        # Save the decoded bytes to file
+        with open(file_path, "wb") as f:
+            f.write(decoded_bytes)
+
+        # Absolute path to the saved file
+        full_path = str(file_path.resolve())
+
+        logger.debug(f"Saved audio file: {full_path} ({len(decoded_bytes)} bytes)")
+        return full_path
+
+    except Exception as e:
+        logger.error(f"Failed to save audio data: {e}")
+        raise
diff --git a/sygra/utils/image_utils.py b/sygra/utils/image_utils.py
index 75226f11..8c2af793 100644
--- a/sygra/utils/image_utils.py
+++ b/sygra/utils/image_utils.py
@@ -2,7 +2,8 @@
 import io
 import os
 import re
-from typing import Any, Optional
+from pathlib import Path
+from typing import Any, Optional, Tuple
 
 import requests  # type: ignore[import-untyped]
 from PIL import Image
@@ -182,3 +183,96 @@ def expand_image_item(item: dict[str, Any], state: dict[str, Any]) -> list[dict[
     else:
         expanded.append(item)
     return expanded
+
+
+def parse_image_data_url(data_url: str) -> Tuple[str, str, bytes]:
+    """
+    Parse an image data URL and extract MIME type, extension, and decoded content.
+
+    Args:
+        data_url (str): The data URL string (e.g., "data:image/png;base64,...")
+
+    Returns:
+        Tuple[str, str, bytes]: Tuple of (mime_type, file_extension, decoded_bytes)
+
+    Raises:
+        ValueError: If the data URL format is invalid
+    """
+    # Pattern: data:<mime_type>;base64,<data>
+    pattern = r"^data:([^;]+);base64,(.+)$"
+    match = re.match(pattern, data_url)
+
+    if not match:
+        raise ValueError(f"Invalid image data URL format: {data_url[:50]}...")
+
+    mime_type = match.group(1)
+    base64_data = match.group(2)
+
+    # Decode base64 data
+    try:
+        decoded_bytes = base64.b64decode(base64_data)
+    except Exception as e:
+        raise ValueError(f"Failed to decode base64 data: {e}")
+
+    # Determine file extension from MIME type
+    mime_to_ext = {
+        "image/jpeg": "jpg",
+        "image/jpg": "jpg",
+        "image/png": "png",
+        "image/gif": "gif",
+        "image/bmp": "bmp",
+        "image/tiff": "tiff",
+        "image/tif": "tif",
+        "image/webp": "webp",
+        "image/ico": "ico",
+        "image/apng": "apng",
+    }
+
+    file_extension = mime_to_ext.get(mime_type, mime_type.split("/")[-1])
+
+    return mime_type, file_extension, decoded_bytes
+
+
+def save_image_data_url(
+    data_url: str, output_dir: Path, record_id: str, field_name: str, index: int = 0
+) -> str:
+    """
+    Save an image data URL to a file and return the file path.
+
+    Args:
+        data_url (str): The base64 data URL to save
+        output_dir (Path): Directory where the file should be saved
+        record_id (str): ID of the record (for unique filename)
+        field_name (str): Name of the field containing the data
+        index (int): Index if the field contains multiple items (default: 0)
+
+    Returns:
+        str: Absolute path to the saved file
+
+    Raises:
+        ValueError: If the data URL is invalid or saving fails
+    """
+    try:
+        # Parse the data URL
+        mime_type, file_extension, decoded_bytes = parse_image_data_url(data_url)
+
+        # Create subdirectory for images
+        image_dir = output_dir / "image"
+        image_dir.mkdir(parents=True, exist_ok=True)
+
+        # Create filename: record_id_fieldname_index.extension
+        filename = f"{record_id}_{field_name}_{index}.{file_extension}"
+        file_path = image_dir / filename
+
+        # Save the decoded bytes to file
+        with open(file_path, "wb") as f:
+            f.write(decoded_bytes)
+
+        full_path = str(file_path.resolve())
+
+        logger.debug(f"Saved image file: {full_path} ({len(decoded_bytes)} bytes)")
+        return full_path
+
+    except Exception as e:
+        logger.error(f"Failed to save image data: {e}")
+        raise
diff --git a/sygra/utils/multimodal_processor.py b/sygra/utils/multimodal_processor.py
new file mode 100644
index 00000000..3d50e5ce
--- /dev/null
+++ b/sygra/utils/multimodal_processor.py
@@ -0,0 +1,132 @@
+"""
+Utility for processing multimodal data (audio and images) in records.
+This module orchestrates the use of audio_utils and image_utils to save base64 data URLs to files.
+"""
+
+from pathlib import Path
+from typing import Any, Dict, List
+
+from sygra.logger.logger_config import logger
+from sygra.utils import audio_utils, image_utils
+
+
+def is_multimodal_data_url(value: Any) -> bool:
+    """
+    Check if a value is a base64 encoded data URL for audio or image.
+
+    Args:
+        value: The value to check
+
+    Returns:
+        bool: True if the value is a multimodal data URL, False otherwise
+    """
+    if not isinstance(value, str):
+        return False
+    return image_utils.is_data_url(value) or audio_utils.is_data_url(value)
+
+
+def save_multimodal_data_url(
+    data_url: str, output_dir: Path, record_id: str, field_name: str, index: int = 0
+) -> str:
+    """
+    Save a multimodal data URL (audio or image) to a file and return the file path.
+
+    Args:
+        data_url: The base64 data URL to save
+        output_dir: Directory where the file should be saved
+        record_id: ID of the record (for unique filename)
+        field_name: Name of the field containing the data
+        index: Index if the field contains multiple items (default: 0)
+
+    Returns:
+        str: Absolute path to the saved file
+
+    Raises:
+        ValueError: If the data URL is invalid or saving fails
+    """
+    if image_utils.is_data_url(data_url):
+        return image_utils.save_image_data_url(data_url, output_dir, record_id, field_name, index)
+    elif audio_utils.is_data_url(data_url):
+        return audio_utils.save_audio_data_url(data_url, output_dir, record_id, field_name, index)
+    else:
+        raise ValueError(f"Unsupported data URL type: {data_url[:50]}...")
+
+
+def process_record_multimodal_data(
+    record: Dict[str, Any], output_dir: Path, record_id: str
+) -> Dict[str, Any]:
+    """
+    Process a record and replace all base64 data URLs with file paths.
+
+    This function recursively searches through the record structure (including nested
+    dicts and lists) and replaces any multimodal data URLs with file paths.
+
+    Args:
+        record: The record to process
+        output_dir: Directory where files should be saved
+        record_id: ID of the record (for unique filenames)
+
+    Returns:
+        Dict[str, Any]: The processed record with data URLs replaced by file paths
+    """
+
+    def process_value(value: Any, field_name: str, index: int = 0) -> Any:
+        """Recursively process values in the record."""
+        if is_multimodal_data_url(value):
+            # Save the data URL to file and return the file path
+            try:
+                file_path = save_multimodal_data_url(
+                    value, output_dir, record_id, field_name, index
+                )
+                return file_path
+            except Exception as e:
+                logger.warning(f"Failed to process data URL in field '{field_name}': {e}")
+                return value
+
+        elif isinstance(value, dict):
+            return {k: process_value(v, f"{field_name}_{k}", 0) for k, v in value.items()}
+
+        elif isinstance(value, list):
+            return [process_value(item, field_name, idx) for idx, item in enumerate(value)]
+
+        else:
+            # Return value as-is if it's not a data URL, dict, or list
+            return value
+
+    # Process the entire record
+    processed_record = {}
+    for key, value in record.items():
+        processed_record[key] = process_value(value, key)
+
+    return processed_record
+
+
+def process_batch_multimodal_data(
+    records: list[Dict[str, Any]], output_dir: Path
+) -> list[Dict[str, Any]]:
+    """
+    Process a batch of records and save all multimodal data to files.
+ + Args: + records: List of records to process + output_dir: Directory where multimodal files should be saved + + Returns: + list[Dict[str, Any]]: List of processed records with data URLs replaced by file paths + """ + if not records: + return records + + # Create multimodal output directory + output_dir.mkdir(parents=True, exist_ok=True) + + processed_records: List[Dict[str, Any]] = [] + for record in records: + # Use record ID if available, otherwise use index + record_id = str(record.get("id", f"record_{len(processed_records)}")) + + processed_record = process_record_multimodal_data(record, output_dir, record_id) + processed_records.append(processed_record) + + logger.info(f"Processed {len(records)} records, saved multimodal files to {output_dir}") + return processed_records diff --git a/tests/core/models/test_custom_azure.py b/tests/core/models/test_custom_azure.py new file mode 100644 index 00000000..fb543ea7 --- /dev/null +++ b/tests/core/models/test_custom_azure.py @@ -0,0 +1,351 @@ +import json +import sys +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Add the parent directory to sys.path to import the necessary modules +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompt_values import ChatPromptValue + +from sygra.core.models.custom_models import CustomAzure, ModelParams + + +class TestCustomAzure(unittest.TestCase): + """Unit tests for the CustomAzure class""" + + def setUp(self): + """Set up test fixtures before each test method""" + # Base model configuration with string auth_token + self.base_config = { + "name": "azure_model", + "parameters": {"temperature": 0.7, "max_tokens": 100}, + "url": "https://azure-test.openai.azure.com", + "auth_token": "Bearer test_token_123", + } + + # Configuration with list of auth_tokens + self.multi_token_config = { + "name": "azure_model_multi", + "parameters": {"temperature": 0.7, "max_tokens": 100}, + "url": "https://azure-test.openai.azure.com", + "auth_token": ["Bearer token1", "Bearer token2", "Bearer token3"], + } + + # Mock messages + self.messages = [ + SystemMessage(content="You are a helpful assistant"), + HumanMessage(content="Hello, how are you?"), + ] + self.chat_input = ChatPromptValue(messages=self.messages) + + def test_init_with_string_auth_token(self): + """Test initialization with a single string auth_token""" + custom_azure = CustomAzure(self.base_config) + + # Verify model was properly initialized + self.assertEqual(custom_azure.model_config, self.base_config) + self.assertEqual(custom_azure.generation_params, self.base_config["parameters"]) + self.assertEqual(custom_azure.auth_token, "test_token_123") # Bearer prefix removed + self.assertEqual(custom_azure.name(), "azure_model") + + def test_init_with_list_auth_token(self): + """Test initialization with a list of auth_tokens""" + custom_azure = CustomAzure(self.multi_token_config) + + # Verify model was properly initialized with first token + self.assertEqual(custom_azure.model_config, self.multi_token_config) + self.assertEqual(custom_azure.auth_token, "token1") # First token, Bearer prefix removed + self.assertEqual(custom_azure.name(), "azure_model_multi") + + def test_init_with_empty_list_raises_error(self): + """Test initialization with an empty list raises ValueError""" + config = {**self.base_config, "auth_token": []} + + with self.assertRaises(ValueError) as context: + 
CustomAzure(config) + + self.assertIn("auth_token must be a string or non-empty list", str(context.exception)) + + def test_init_with_invalid_type_raises_error(self): + """Test initialization with invalid auth_token type raises ValueError""" + config = {**self.base_config, "auth_token": 12345} + + with self.assertRaises(ValueError) as context: + CustomAzure(config) + + self.assertIn("auth_token must be a string or non-empty list", str(context.exception)) + + def test_init_with_list_containing_non_string_raises_error(self): + """Test initialization with list containing non-string raises TypeError""" + config = {**self.base_config, "auth_token": [123, "token"]} + + with self.assertRaises(TypeError) as context: + CustomAzure(config) + + self.assertIn("auth_token list must contain strings", str(context.exception)) + + def test_init_missing_url_raises_error(self): + """Test initialization without url raises error""" + config = { + "name": "azure_model", + "parameters": {"temperature": 0.7}, + "auth_token": "test_token", + } + + with self.assertRaises(Exception): + CustomAzure(config) + + def test_init_missing_auth_token_raises_error(self): + """Test initialization without auth_token raises error""" + config = { + "name": "azure_model", + "parameters": {"temperature": 0.7}, + "url": "https://azure-test.openai.azure.com", + } + + with self.assertRaises(Exception): + CustomAzure(config) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_success(self, mock_utils, mock_set_client): + """Test _generate_response method with successful response""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = { + "messages": [{"role": "user", "content": "Hello"}] + } + + # Configure mock response + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps( + { + "choices": [ + { + "message": {"content": "Hello! I'm doing well, thank you!"}, + "finish_reason": "stop", + } + ] + } + ) + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello, how are you?"}, + ] + + # Setup custom model + custom_azure = CustomAzure(self.base_config) + custom_azure._client = mock_client + + # Call _generate_response + model_params = ModelParams( + url="https://azure-test.openai.azure.com", auth_token="test_token" + ) + resp_text, resp_status = await custom_azure._generate_response( + self.chat_input, model_params + ) + + # Verify results + self.assertEqual(resp_text, "Hello! 
I'm doing well, thank you!") + self.assertEqual(resp_status, 200) + + # Verify method calls + mock_set_client.assert_called_once() + mock_client.build_request_with_payload.assert_called_once() + mock_client.async_send_request.assert_awaited_once() + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_content_filter(self, mock_utils, mock_set_client): + """Test _generate_response method with content filter response""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = { + "messages": [{"role": "user", "content": "Hello"}] + } + + # Configure mock response with content filter + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps( + {"choices": [{"message": {"content": ""}, "finish_reason": "content_filter"}]} + ) + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_azure = CustomAzure(self.base_config) + custom_azure._client = mock_client + + # Call _generate_response + model_params = ModelParams( + url="https://azure-test.openai.azure.com", auth_token="test_token" + ) + resp_text, resp_status = await custom_azure._generate_response( + self.chat_input, model_params + ) + + # Verify results - should return content filter message with code 444 + self.assertEqual(resp_text, "Blocked by azure content filter") + self.assertEqual(resp_status, 444) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_http_error(self, mock_utils, mock_set_client, mock_logger): + """Test _generate_response method with HTTP error""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = { + "messages": [{"role": "user", "content": "Hello"}] + } + + # Configure mock response with error + mock_response = MagicMock() + mock_response.status_code = 429 + mock_response.text = "Rate limit exceeded" + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_azure = CustomAzure(self.base_config) + custom_azure._client = mock_client + + # Call _generate_response + model_params = ModelParams( + url="https://azure-test.openai.azure.com", auth_token="test_token" + ) + resp_text, resp_status = await custom_azure._generate_response( + self.chat_input, model_params + ) + + # Verify results - should return empty string with error status + self.assertEqual(resp_text, "") + self.assertEqual(resp_status, 429) + + # Verify error logging + mock_logger.error.assert_called() + self.assertIn("HTTP request failed", str(mock_logger.error.call_args)) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_exception(self, mock_utils, mock_set_client, mock_logger): + """Test _generate_response method with exception""" + # Setup mock client 
to raise exception + mock_client = MagicMock() + mock_client.build_request_with_payload.side_effect = Exception("Connection timeout") + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_azure = CustomAzure(self.base_config) + custom_azure._client = mock_client + custom_azure._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams( + url="https://azure-test.openai.azure.com", auth_token="test_token" + ) + resp_text, resp_status = await custom_azure._generate_response( + self.chat_input, model_params + ) + + # Verify results + self.assertIn("Http request failed", resp_text) + self.assertIn("Connection timeout", resp_text) + self.assertEqual(resp_status, 999) + + # Verify error logging + mock_logger.error.assert_called() + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_with_extracted_status_code(self, mock_utils, mock_set_client): + """Test _generate_response extracts status code from error body""" + # Setup mock client to raise exception + mock_client = MagicMock() + mock_client.build_request_with_payload.side_effect = Exception("Service unavailable") + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_azure = CustomAzure(self.base_config) + custom_azure._client = mock_client + custom_azure._get_status_from_body = MagicMock(return_value=503) + + # Call _generate_response + model_params = ModelParams( + url="https://azure-test.openai.azure.com", auth_token="test_token" + ) + resp_text, resp_status = await custom_azure._generate_response( + self.chat_input, model_params + ) + + # Verify extracted status code is used + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_set_client_called_with_correct_params(self, mock_utils, mock_set_client): + """Test that _set_client is called with correct parameters""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"messages": []} + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps( + {"choices": [{"message": {"content": "Response"}, "finish_reason": "stop"}]} + ) + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [] + + # Setup custom model + custom_azure = CustomAzure(self.base_config) + custom_azure._client = mock_client + + # Call _generate_response + model_params = ModelParams( + url="https://azure-test.openai.azure.com", auth_token="custom_token" + ) + await custom_azure._generate_response(self.chat_input, model_params) + + # Verify _set_client was called with model params + mock_set_client.assert_called_once_with( + "https://azure-test.openai.azure.com", "custom_token" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/models/test_custom_mistralapi.py b/tests/core/models/test_custom_mistralapi.py new file mode 100644 index 00000000..3657c840 --- /dev/null +++ b/tests/core/models/test_custom_mistralapi.py @@ -0,0 +1,357 @@ 
+import sys +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Add the parent directory to sys.path to import the necessary modules +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompt_values import ChatPromptValue + +from sygra.core.models.custom_models import CustomMistralAPI, ModelParams +from sygra.utils import constants + + +class TestCustomMistralAPI(unittest.TestCase): + """Unit tests for the CustomMistralAPI class""" + + def setUp(self): + """Set up test fixtures before each test method""" + # Base model configuration + self.base_config = { + "name": "mistral_model", + "model": "mistral-large-latest", + "parameters": {"temperature": 0.7, "max_tokens": 100}, + } + + # Mock messages + self.messages = [ + SystemMessage(content="You are a helpful assistant"), + HumanMessage(content="Hello, how are you?"), + ] + self.chat_input = ChatPromptValue(messages=self.messages) + + def test_init(self): + """Test initialization of CustomMistralAPI""" + custom_mistral = CustomMistralAPI(self.base_config) + + # Verify model was properly initialized + self.assertEqual(custom_mistral.model_config, self.base_config) + self.assertEqual(custom_mistral.generation_params, self.base_config["parameters"]) + self.assertEqual(custom_mistral.name(), "mistral_model") + + def test_init_with_minimal_config(self): + """Test initialization with minimal configuration""" + minimal_config = { + "name": "mistral_minimal", + "parameters": {}, + } + custom_mistral = CustomMistralAPI(minimal_config) + + self.assertEqual(custom_mistral.name(), "mistral_minimal") + self.assertEqual(custom_mistral.generation_params, {}) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_success(self, mock_utils, mock_set_client): + """Test _generate_response method with successful response""" + # Setup mock client + mock_client = MagicMock() + mock_chat = MagicMock() + mock_client.chat = mock_chat + + # Setup mock response + mock_message = MagicMock() + mock_message.content = "Hello! I'm doing great, thank you for asking!" + + mock_choice = MagicMock() + mock_choice.message = mock_message + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_chat.complete_async = AsyncMock(return_value=mock_response) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "system", "content": "You are a helpful assistant"}, + {"role": "user", "content": "Hello, how are you?"}, + ] + + # Setup custom model + custom_mistral = CustomMistralAPI(self.base_config) + custom_mistral._client = mock_client + + # Call _generate_response + model_params = ModelParams(url="https://api.mistral.ai", auth_token="test_token") + resp_text, resp_status = await custom_mistral._generate_response( + self.chat_input, model_params + ) + + # Verify results + self.assertEqual(resp_text, "Hello! 
I'm doing great, thank you for asking!") + self.assertEqual(resp_status, 200) + + # Verify method calls + mock_set_client.assert_called_once_with("https://api.mistral.ai", "test_token") + mock_chat.complete_async.assert_awaited_once() + + # Verify the messages passed to the API + call_args = mock_chat.complete_async.call_args + self.assertEqual(call_args.kwargs["model"], "mistral-large-latest") + self.assertEqual(len(call_args.kwargs["messages"]), 2) + self.assertEqual(call_args.kwargs["messages"][0]["role"], "system") + self.assertEqual(call_args.kwargs["messages"][1]["role"], "user") + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_rate_limit_error( + self, mock_utils, mock_set_client, mock_logger + ): + """Test _generate_response method with rate limit error""" + # Setup mock client to raise rate limit exception + mock_client = MagicMock() + mock_chat = MagicMock() + mock_client.chat = mock_chat + + error_msg = "Status 429: Too many requests" + mock_chat.complete_async = AsyncMock(side_effect=Exception(error_msg)) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_mistral = CustomMistralAPI(self.base_config) + custom_mistral._client = mock_client + custom_mistral._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams(url="https://api.mistral.ai", auth_token="test_token") + resp_text, resp_status = await custom_mistral._generate_response( + self.chat_input, model_params + ) + + # Verify results - should return 429 for rate limit + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("Http request failed", resp_text) + self.assertEqual(resp_status, 429) + + # Verify error logging + mock_logger.error.assert_called() + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_model_overload_error( + self, mock_utils, mock_set_client, mock_logger + ): + """Test _generate_response method with model overload error""" + # Setup mock client to raise model overload exception + mock_client = MagicMock() + mock_chat = MagicMock() + mock_client.chat = mock_chat + + error_msg = "Model is overloaded" + mock_chat.complete_async = AsyncMock(side_effect=Exception(error_msg)) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_mistral = CustomMistralAPI(self.base_config) + custom_mistral._client = mock_client + custom_mistral._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams(url="https://api.mistral.ai", auth_token="test_token") + resp_text, resp_status = await custom_mistral._generate_response( + self.chat_input, model_params + ) + + # Verify results - should return 429 for model overload + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertEqual(resp_status, 429) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + 
async def test_generate_response_generic_exception( + self, mock_utils, mock_set_client, mock_logger + ): + """Test _generate_response method with generic exception""" + # Setup mock client to raise generic exception + mock_client = MagicMock() + mock_chat = MagicMock() + mock_client.chat = mock_chat + + error_msg = "Connection timeout" + mock_chat.complete_async = AsyncMock(side_effect=Exception(error_msg)) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_mistral = CustomMistralAPI(self.base_config) + custom_mistral._client = mock_client + custom_mistral._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams(url="https://api.mistral.ai", auth_token="test_token") + resp_text, resp_status = await custom_mistral._generate_response( + self.chat_input, model_params + ) + + # Verify results - should return 999 for generic error + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("Connection timeout", resp_text) + self.assertEqual(resp_status, 999) + + # Verify error logging + mock_logger.error.assert_called() + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_with_extracted_status_code( + self, mock_utils, mock_set_client, mock_logger + ): + """Test _generate_response extracts status code from error body""" + # Setup mock client to raise exception + mock_client = MagicMock() + mock_chat = MagicMock() + mock_client.chat = mock_chat + + error_msg = "Service unavailable" + mock_chat.complete_async = AsyncMock(side_effect=Exception(error_msg)) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model + custom_mistral = CustomMistralAPI(self.base_config) + custom_mistral._client = mock_client + custom_mistral._get_status_from_body = MagicMock(return_value=503) + + # Call _generate_response + model_params = ModelParams(url="https://api.mistral.ai", auth_token="test_token") + resp_text, resp_status = await custom_mistral._generate_response( + self.chat_input, model_params + ) + + # Verify extracted status code is used + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_generate_response_with_generation_params(self, mock_utils, mock_set_client): + """Test _generate_response passes generation parameters correctly""" + # Setup mock client + mock_client = MagicMock() + mock_chat = MagicMock() + mock_client.chat = mock_chat + + mock_message = MagicMock() + mock_message.content = "Response" + mock_choice = MagicMock() + mock_choice.message = mock_message + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_chat.complete_async = AsyncMock(return_value=mock_response) + + # Mock utils methods + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "user", "content": "Hello"} + ] + + # Setup custom model with specific generation params + config = { + **self.base_config, + "parameters": { + "temperature": 0.9, + "max_tokens": 500, + "top_p": 0.95, + }, + } + custom_mistral = CustomMistralAPI(config) + 
custom_mistral._client = mock_client + + # Call _generate_response + model_params = ModelParams(url="https://api.mistral.ai", auth_token="test_token") + await custom_mistral._generate_response(self.chat_input, model_params) + + # Verify generation parameters were passed + call_args = mock_chat.complete_async.call_args + self.assertEqual(call_args.kwargs["temperature"], 0.9) + self.assertEqual(call_args.kwargs["max_tokens"], 500) + self.assertEqual(call_args.kwargs["top_p"], 0.95) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.utils") + @pytest.mark.asyncio + async def test_messages_format_conversion(self, mock_utils, mock_set_client): + """Test that messages are properly converted to role/content format""" + # Setup mock client + mock_client = MagicMock() + mock_chat = MagicMock() + mock_client.chat = mock_chat + + mock_message = MagicMock() + mock_message.content = "Response" + mock_choice = MagicMock() + mock_choice.message = mock_message + mock_response = MagicMock() + mock_response.choices = [mock_choice] + + mock_chat.complete_async = AsyncMock(return_value=mock_response) + + # Mock utils to return messages with extra fields + mock_utils.convert_messages_from_langchain_to_chat_format.return_value = [ + {"role": "system", "content": "System prompt", "extra_field": "should_be_removed"}, + {"role": "user", "content": "User message", "another_field": "also_removed"}, + ] + + # Setup custom model + custom_mistral = CustomMistralAPI(self.base_config) + custom_mistral._client = mock_client + + # Call _generate_response + model_params = ModelParams(url="https://api.mistral.ai", auth_token="test_token") + await custom_mistral._generate_response(self.chat_input, model_params) + + # Verify only role and content are passed to API + call_args = mock_chat.complete_async.call_args + messages = call_args.kwargs["messages"] + self.assertEqual(len(messages), 2) + + # Check first message + self.assertEqual(messages[0]["role"], "system") + self.assertEqual(messages[0]["content"], "System prompt") + self.assertNotIn("extra_field", messages[0]) + + # Check second message + self.assertEqual(messages[1]["role"], "user") + self.assertEqual(messages[1]["content"], "User message") + self.assertNotIn("another_field", messages[1]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/models/test_custom_models_completion_api.py b/tests/core/models/test_custom_models_completion_api.py index aa048a6b..23d3fb18 100644 --- a/tests/core/models/test_custom_models_completion_api.py +++ b/tests/core/models/test_custom_models_completion_api.py @@ -72,7 +72,7 @@ def test_base_model_completions_api_not_supported(self, mock_logger, mock_client # Create a test implementation of BaseCustomModel class TestBaseModel(BaseCustomModel): - def _generate_text(self, *args, **kwargs): + def _generate_response(self, *args, **kwargs): pass def name(self): @@ -97,7 +97,7 @@ def test_base_model_completions_api_not_set(self, mock_logger, mock_client_facto # Create a test implementation of BaseCustomModel class TestBaseModel(BaseCustomModel): - def _generate_text(self, *args, **kwargs): + def _generate_response(self, *args, **kwargs): pass def name(self): diff --git a/tests/core/models/test_ollama_custom_model.py b/tests/core/models/test_custom_ollama.py similarity index 92% rename from tests/core/models/test_ollama_custom_model.py rename to tests/core/models/test_custom_ollama.py index eb182b58..5c2f60fe 100644 --- 
a/tests/core/models/test_ollama_custom_model.py +++ b/tests/core/models/test_custom_ollama.py @@ -76,8 +76,8 @@ def test_validate_completions_api_support(self, mock_logger): self.assertTrue(custom_ollama.model_config.get("completions_api")) @pytest.mark.asyncio - async def test_generate_text_chat_completions(self): - """Test _generate_text method with chat completions API""" + async def test_generate_response_chat_completions(self): + """Test _generate_response method with chat completions API""" # Setup mock client mock_client = MagicMock() mock_client.build_request.return_value = { @@ -89,9 +89,11 @@ async def test_generate_text_chat_completions(self): custom_ollama = CustomOllama(self.base_config) custom_ollama._client = mock_client - # Call _generate_text + # Call _generate_response model_params = ModelParams(url="http://localhost:11434") - resp_text, resp_status = await custom_ollama._generate_text(self.chat_input, model_params) + resp_text, resp_status = await custom_ollama._generate_response( + self.chat_input, model_params + ) # Verify results self.assertEqual(resp_text, "Hello there!") @@ -110,10 +112,10 @@ async def test_generate_text_chat_completions(self): @patch("sygra.core.models.custom_models.BaseCustomModel._finalize_response") @patch("sygra.core.models.custom_models.BaseCustomModel.get_chat_formatted_text") @pytest.mark.asyncio - async def test_generate_text_completions_api( + async def test_generate_response_completions_api( self, mock_get_formatted, mock_finalize, mock_set_client, mock_client_factory ): - """Test _generate_text method with completions API""" + """Test _generate_response method with completions API""" # Setup mock client mock_client = MagicMock() mock_client.build_request.return_value = {"prompt": "Hello, how are you?"} @@ -128,9 +130,11 @@ async def test_generate_text_completions_api( # Mock the get_chat_formatted_text method mock_get_formatted.return_value = "Hello, how are you?" 
- # Call _generate_text + # Call _generate_response model_params = ModelParams(url="http://localhost:11434") - resp_text, resp_status = await custom_ollama._generate_text(self.chat_input, model_params) + resp_text, resp_status = await custom_ollama._generate_response( + self.chat_input, model_params + ) # Verify results self.assertEqual(resp_text, "I'm doing well, thank you!") @@ -148,10 +152,10 @@ async def test_generate_text_completions_api( @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") @patch("sygra.core.models.custom_models.BaseCustomModel._finalize_response") @pytest.mark.asyncio - async def test_generate_text_exception( + async def test_generate_response_exception( self, mock_finalize, mock_set_client, mock_client_factory ): - """Test _generate_text method with an exception""" + """Test _generate_response method with an exception""" # Setup mock client to raise an exception mock_client = MagicMock() mock_client.build_request.return_value = { @@ -163,9 +167,11 @@ async def test_generate_text_exception( custom_ollama = CustomOllama(self.base_config) custom_ollama._client = mock_client - # Call _generate_text + # Call _generate_response model_params = ModelParams(url="http://localhost:11434") - resp_text, resp_status = await custom_ollama._generate_text(self.chat_input, model_params) + resp_text, resp_status = await custom_ollama._generate_response( + self.chat_input, model_params + ) # Verify error handling self.assertTrue(resp_text.startswith(f"{constants.ERROR_PREFIX} Ollama request failed")) diff --git a/tests/core/models/test_custom_openai.py b/tests/core/models/test_custom_openai.py new file mode 100644 index 00000000..6afe07c6 --- /dev/null +++ b/tests/core/models/test_custom_openai.py @@ -0,0 +1,496 @@ +import base64 +import sys +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, mock_open, patch + +import openai +import pytest + +# Add the parent directory to sys.path to import the necessary modules +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompt_values import ChatPromptValue + +from sygra.core.models.custom_models import CustomOpenAI, ModelParams +from sygra.utils import constants + + +class TestCustomOpenAI(unittest.TestCase): + """Unit tests for the CustomOpenAI class - model level tests""" + + def setUp(self): + """Set up test fixtures before each test method""" + # Base model configuration for text generation + self.text_config = { + "name": "gpt4_model", + "model": "gpt-4", + "parameters": {"temperature": 0.7, "max_tokens": 100}, + "url": "https://api.openai.com/v1", + "auth_token": "Bearer sk-test_key_123", + "api_version": "2023-05-15", + } + + # Configuration for TTS + self.tts_config = { + "name": "tts_model", + "model": "tts-1", + "output_type": "audio", + "parameters": {}, + "url": "https://api.openai.com/v1", + "auth_token": "Bearer sk-test_key_123", + "api_version": "2023-05-15", + "voice": "alloy", + "response_format": "mp3", + "speed": 1.0, + } + + # Configuration with completions API + self.completions_config = { + **self.text_config, + "completions_api": True, + "hf_chat_template_model_id": "meta-llama/Llama-2-7b-chat-hf", + } + + # Mock messages for text generation + self.messages = [ + SystemMessage(content="You are a helpful assistant"), + HumanMessage(content="Hello, how are you?"), + ] + self.chat_input = ChatPromptValue(messages=self.messages) + + # Mock messages for TTS + 
self.tts_messages = [HumanMessage(content="Hello, this is a test of text to speech.")] + self.tts_input = ChatPromptValue(messages=self.tts_messages) + + def test_init(self): + """Test initialization of CustomOpenAI""" + custom_openai = CustomOpenAI(self.text_config) + + # Verify model was properly initialized + self.assertEqual(custom_openai.model_config, self.text_config) + self.assertEqual(custom_openai.generation_params, self.text_config["parameters"]) + self.assertEqual(custom_openai.name(), "gpt4_model") + + def test_init_missing_required_keys_raises_error(self): + """Test initialization without required keys raises error""" + config = { + "name": "gpt4_model", + "parameters": {"temperature": 0.7}, + } + + with self.assertRaises(Exception): + CustomOpenAI(config) + + @patch("sygra.core.models.custom_models.logger") + def test_init_with_completions_api(self, mock_logger): + """Test initialization with completions API enabled""" + with patch("sygra.core.models.custom_models.AutoTokenizer"): + custom_openai = CustomOpenAI(self.completions_config) + + self.assertTrue(custom_openai.model_config.get("completions_api")) + # Should log that model supports completion API + mock_logger.info.assert_any_call("Model gpt4_model supports completion API.") + + # ============== _generate_text Tests ============== + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_text_chat_api_success(self, mock_set_client): + """Test _generate_text with chat API (non-completions)""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request.return_value = { + "messages": [{"role": "user", "content": "Hello"}] + } + + # Setup mock completion response + mock_choice = MagicMock() + mock_choice.model_dump.return_value = { + "message": {"content": " Hello! I'm doing well, thank you! "} + } + mock_completion = MagicMock() + mock_completion.choices = [mock_choice] + + mock_client.send_request = AsyncMock(return_value=mock_completion) + + # Setup custom model + custom_openai = CustomOpenAI(self.text_config) + custom_openai._client = mock_client + + # Call _generate_text + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_text(self.chat_input, model_params) + + # Verify results (text should be stripped) + self.assertEqual(resp_text, "Hello! 
I'm doing well, thank you!") + self.assertEqual(resp_status, 200) + + # Verify method calls + mock_set_client.assert_called_once() + mock_client.build_request.assert_called_once_with(messages=self.messages) + mock_client.send_request.assert_awaited_once() + + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_text_completions_api_success(self, mock_set_client, mock_tokenizer): + """Test _generate_text with completions API""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request.return_value = {"prompt": "Formatted prompt"} + + # Setup mock completion response for completions API + mock_choice = MagicMock() + mock_choice.model_dump.return_value = {"text": " Response text "} + mock_completion = MagicMock() + mock_completion.choices = [mock_choice] + + mock_client.send_request = AsyncMock(return_value=mock_completion) + + # Setup custom model with completions API + custom_openai = CustomOpenAI(self.completions_config) + custom_openai._client = mock_client + custom_openai.get_chat_formatted_text = MagicMock(return_value="Formatted prompt") + + # Call _generate_text + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_text(self.chat_input, model_params) + + # Verify results (text should be stripped) + self.assertEqual(resp_text, "Response text") + self.assertEqual(resp_status, 200) + + # Verify completions API path was used + custom_openai.get_chat_formatted_text.assert_called_once() + mock_client.build_request.assert_called_once_with(formatted_prompt="Formatted prompt") + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_text_rate_limit_error(self, mock_set_client, mock_logger): + """Test _generate_text with rate limit error""" + # Setup mock client to raise RateLimitError + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + mock_client.send_request = AsyncMock( + side_effect=openai.RateLimitError( + "Rate limit exceeded", response=MagicMock(), body=None + ) + ) + + # Setup custom model + custom_openai = CustomOpenAI(self.text_config) + custom_openai._client = mock_client + + # Call _generate_text + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_text(self.chat_input, model_params) + + # Verify results - should return 429 for rate limit + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertEqual(resp_status, 429) + + # Verify warning logging + mock_logger.warn.assert_called() + self.assertIn("rate limit", str(mock_logger.warn.call_args)) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_text_generic_exception(self, mock_set_client, mock_logger): + """Test _generate_text with generic exception""" + # Setup mock client to raise generic exception + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + mock_client.send_request = AsyncMock(side_effect=Exception("Network timeout")) + + # Setup custom model + custom_openai = CustomOpenAI(self.text_config) + custom_openai._client = mock_client + custom_openai._get_status_from_body = 
MagicMock(return_value=None) + + # Call _generate_text + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_text(self.chat_input, model_params) + + # Verify results - should return 999 for generic error + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("Network timeout", resp_text) + self.assertEqual(resp_status, 999) + + # ============== _generate_speech Tests ============== + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_success_base64(self, mock_set_client): + """Test _generate_speech returns base64 encoded audio when no output_file""" + # Setup mock client + mock_client = MagicMock() + mock_audio_content = b"fake_audio_data" + mock_response = MagicMock() + mock_response.content = mock_audio_content + + mock_client.create_speech = AsyncMock(return_value=mock_response) + + # Setup custom model + custom_openai = CustomOpenAI(self.tts_config) + custom_openai._client = mock_client + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_speech(self.tts_input, model_params) + + # Verify results + expected_base64 = base64.b64encode(mock_audio_content).decode("utf-8") + self.assertEqual(resp_text, expected_base64) + self.assertEqual(resp_status, 200) + + # Verify method calls + mock_set_client.assert_called_once() + mock_client.create_speech.assert_awaited_once() + + @patch("builtins.open", new_callable=mock_open) + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_success_file_output(self, mock_set_client, mock_file): + """Test _generate_speech saves to file when output_file specified""" + # Setup mock client + mock_client = MagicMock() + mock_audio_content = b"fake_audio_data" + mock_response = MagicMock() + mock_response.content = mock_audio_content + + mock_client.create_speech = AsyncMock(return_value=mock_response) + + # Setup custom model with output file + config = {**self.tts_config, "output_file": "/tmp/output.mp3"} + custom_openai = CustomOpenAI(config) + custom_openai._client = mock_client + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_speech(self.tts_input, model_params) + + # Verify results + self.assertIn("Audio successfully saved", resp_text) + self.assertIn("/tmp/output.mp3", resp_text) + self.assertEqual(resp_status, 200) + + # Verify file was written + mock_file.assert_called_once_with("/tmp/output.mp3", "wb") + mock_file().write.assert_called_once_with(mock_audio_content) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_empty_text(self, mock_set_client, mock_logger): + """Test _generate_speech with empty text""" + # Setup custom model + custom_openai = CustomOpenAI(self.tts_config) + + # Create empty input + empty_input = ChatPromptValue(messages=[HumanMessage(content="")]) + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_speech(empty_input, model_params) + + # Verify results + self.assertIn(constants.ERROR_PREFIX, resp_text) + 
self.assertIn("No text provided", resp_text) + self.assertEqual(resp_status, 400) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_text_too_long(self, mock_set_client, mock_logger): + """Test _generate_speech with text exceeding 4096 character limit""" + # Setup custom model + custom_openai = CustomOpenAI(self.tts_config) + + # Create input with text > 4096 characters + long_text = "A" * 5000 + long_input = ChatPromptValue(messages=[HumanMessage(content=long_text)]) + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_speech(long_input, model_params) + + # Verify results + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("exceeds 4096 character limit", resp_text) + self.assertEqual(resp_status, 400) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_invalid_voice(self, mock_set_client, mock_logger): + """Test _generate_speech with invalid voice falls back to default""" + # Setup mock client + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.content = b"audio_data" + mock_client.create_speech = AsyncMock(return_value=mock_response) + + # Setup custom model with invalid voice + config = {**self.tts_config, "voice": "invalid_voice"} + custom_openai = CustomOpenAI(config) + custom_openai._client = mock_client + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + await custom_openai._generate_speech(self.tts_input, model_params) + + # Verify warning was logged and default 'alloy' was used + mock_logger.warning.assert_called() + self.assertIn("Invalid voice", str(mock_logger.warning.call_args)) + + # Verify create_speech was called with 'alloy' + call_args = mock_client.create_speech.call_args + self.assertEqual(call_args.kwargs["voice"], "alloy") + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_invalid_format(self, mock_set_client, mock_logger): + """Test _generate_speech with invalid format falls back to default""" + # Setup mock client + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.content = b"audio_data" + mock_client.create_speech = AsyncMock(return_value=mock_response) + + # Setup custom model with invalid format + config = {**self.tts_config, "response_format": "invalid_format"} + custom_openai = CustomOpenAI(config) + custom_openai._client = mock_client + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + await custom_openai._generate_speech(self.tts_input, model_params) + + # Verify warning was logged and default 'mp3' was used + mock_logger.warning.assert_called() + self.assertIn("Invalid format", str(mock_logger.warning.call_args)) + + # Verify create_speech was called with 'mp3' + call_args = mock_client.create_speech.call_args + self.assertEqual(call_args.kwargs["response_format"], "mp3") + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_speed_clamping(self, mock_set_client): + """Test _generate_speech clamps speed to 
valid range""" + # Setup mock client + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.content = b"audio_data" + mock_client.create_speech = AsyncMock(return_value=mock_response) + + # Test speed too low + config_low = {**self.tts_config, "speed": 0.1} + custom_openai_low = CustomOpenAI(config_low) + custom_openai_low._client = mock_client + + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + await custom_openai_low._generate_speech(self.tts_input, model_params) + + # Verify speed was clamped to 0.25 + call_args = mock_client.create_speech.call_args + self.assertEqual(call_args.kwargs["speed"], 0.25) + + # Test speed too high + config_high = {**self.tts_config, "speed": 5.0} + custom_openai_high = CustomOpenAI(config_high) + custom_openai_high._client = mock_client + + await custom_openai_high._generate_speech(self.tts_input, model_params) + + # Verify speed was clamped to 4.0 + call_args = mock_client.create_speech.call_args + self.assertEqual(call_args.kwargs["speed"], 4.0) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_rate_limit_error(self, mock_set_client, mock_logger): + """Test _generate_speech with rate limit error""" + # Setup mock client to raise RateLimitError + mock_client = MagicMock() + mock_client.create_speech = AsyncMock( + side_effect=openai.RateLimitError( + "Rate limit exceeded", response=MagicMock(), body=None + ) + ) + + # Setup custom model + custom_openai = CustomOpenAI(self.tts_config) + custom_openai._client = mock_client + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_speech(self.tts_input, model_params) + + # Verify results + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("Rate limit exceeded", resp_text) + self.assertEqual(resp_status, 429) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_speech_api_error(self, mock_set_client, mock_logger): + """Test _generate_speech with API error""" + # Setup mock client to raise APIError; openai.APIError is constructed with a request, not a response + api_error = openai.APIError("Internal server error", request=MagicMock(), body=None) + # Plain APIError has no status_code attribute, so attach one for the error handler to read + api_error.status_code = 500 + + mock_client = MagicMock() + mock_client.create_speech = AsyncMock(side_effect=api_error) + + # Setup custom model + custom_openai = CustomOpenAI(self.tts_config) + custom_openai._client = mock_client + + # Call _generate_speech + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_speech(self.tts_input, model_params) + + # Verify results + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("API error", resp_text) + self.assertEqual(resp_status, 500) + + # ============== _generate_response Routing Tests ============== + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_routes_to_speech(self, mock_set_client): + """Test _generate_response routes to _generate_speech for audio output""" + # Setup mock client + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.content = b"audio_data" + mock_client.create_speech = 
AsyncMock(return_value=mock_response) + + # Setup custom model with audio output type + custom_openai = CustomOpenAI(self.tts_config) + custom_openai._client = mock_client + + # Call _generate_response (should route to _generate_speech) + model_params = ModelParams(url="https://api.openai.com/v1", auth_token="sk-test") + resp_text, resp_status = await custom_openai._generate_response( + self.tts_input, model_params + ) + + # Verify it called create_speech (TTS path) + mock_client.create_speech.assert_awaited_once() + self.assertEqual(resp_status, 200) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/models/test_custom_tgi.py b/tests/core/models/test_custom_tgi.py new file mode 100644 index 00000000..605ec704 --- /dev/null +++ b/tests/core/models/test_custom_tgi.py @@ -0,0 +1,435 @@ +import json +import sys +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Add the parent directory to sys.path to import the necessary modules +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompt_values import ChatPromptValue +from pydantic import BaseModel + +from sygra.core.models.custom_models import CustomTGI, ModelParams +from sygra.utils import constants + + +class TestCustomTGI(unittest.TestCase): + """Unit tests for the CustomTGI class""" + + def setUp(self): + """Set up test fixtures before each test method""" + # Base model configuration + self.base_config = { + "name": "tgi_model", + "parameters": {"temperature": 0.7, "max_tokens": 100}, + "url": "http://tgi-test.com", + "auth_token": "Bearer test_token_123", + "hf_chat_template_model_id": "meta-llama/Llama-2-7b-chat-hf", + } + + # Mock messages + self.messages = [ + SystemMessage(content="You are a helpful assistant"), + HumanMessage(content="Hello, how are you?"), + ] + self.chat_input = ChatPromptValue(messages=self.messages) + + def test_init(self): + """Test initialization of CustomTGI""" + with patch("sygra.core.models.custom_models.AutoTokenizer"): + custom_tgi = CustomTGI(self.base_config) + + # Verify model was properly initialized + self.assertEqual(custom_tgi.model_config, self.base_config) + self.assertEqual(custom_tgi.generation_params, self.base_config["parameters"]) + self.assertEqual(custom_tgi.auth_token, "test_token_123") # Bearer prefix removed + self.assertEqual(custom_tgi.name(), "tgi_model") + + def test_init_missing_url_raises_error(self): + """Test initialization without url raises error""" + config = { + "name": "tgi_model", + "parameters": {"temperature": 0.7}, + "auth_token": "test_token", + } + + with self.assertRaises(Exception): + CustomTGI(config) + + def test_init_missing_auth_token_raises_error(self): + """Test initialization without auth_token raises error""" + config = { + "name": "tgi_model", + "parameters": {"temperature": 0.7}, + "url": "http://tgi-test.com", + } + + with self.assertRaises(Exception): + CustomTGI(config) + + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_success(self, mock_set_client, mock_tokenizer): + """Test _generate_response method with successful response""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"inputs": "[INST] Hello [/INST]"} + + # Configure mock response + mock_response = MagicMock() 
+ mock_response.status_code = 200 + mock_response.text = json.dumps({"generated_text": "Hello! I'm doing well, thank you!"}) + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="[INST] Hello [/INST]") + + # Call _generate_response + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_response(self.chat_input, model_params) + + # Verify results + self.assertEqual(resp_text, "Hello! I'm doing well, thank you!") + self.assertEqual(resp_status, 200) + + # Verify method calls + mock_set_client.assert_called_once() + custom_tgi.get_chat_formatted_text.assert_called_once() + mock_client.build_request_with_payload.assert_called_once() + mock_client.async_send_request.assert_awaited_once() + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_http_error(self, mock_set_client, mock_tokenizer, mock_logger): + """Test _generate_response method with HTTP error""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"inputs": "Test input"} + + # Configure mock response with error + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + + # Call _generate_response + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_response(self.chat_input, model_params) + + # Verify results - should have ERROR prefix + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertEqual(resp_status, 500) + + # Verify error logging + mock_logger.error.assert_called() + self.assertIn("HTTP request failed", str(mock_logger.error.call_args)) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_server_down( + self, mock_set_client, mock_tokenizer, mock_logger + ): + """Test _generate_response method with server down error""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"inputs": "Test input"} + + # Configure mock response with server down error + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = f"{constants.ERROR_PREFIX} {constants.ELEMAI_JOB_DOWN}" + mock_response.status = 500 + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + + # Call _generate_response + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_response(self.chat_input, model_params) + + # Verify results - status should 
be set to 503 + self.assertIn(constants.ELEMAI_JOB_DOWN, resp_text) + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_connection_error( + self, mock_set_client, mock_tokenizer, mock_logger + ): + """Test _generate_response method with connection error""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"inputs": "Test input"} + + # Configure mock response with connection error + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = f"{constants.ERROR_PREFIX} {constants.CONNECTION_ERROR}" + mock_response.status = 500 + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + + # Call _generate_response + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_response(self.chat_input, model_params) + + # Verify results - status should be set to 503 + self.assertIn(constants.CONNECTION_ERROR, resp_text) + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_exception(self, mock_set_client, mock_tokenizer, mock_logger): + """Test _generate_response method with exception""" + # Setup mock client to raise exception + mock_client = MagicMock() + mock_client.build_request_with_payload.side_effect = Exception("Connection timeout") + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + custom_tgi._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_response(self.chat_input, model_params) + + # Verify results + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("Connection timeout", resp_text) + self.assertEqual(resp_status, 999) + + # Verify error logging + mock_logger.error.assert_called() + + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_with_extracted_status_code( + self, mock_set_client, mock_tokenizer + ): + """Test _generate_response extracts status code from error body""" + # Setup mock client to raise exception + mock_client = MagicMock() + mock_client.build_request_with_payload.side_effect = Exception("Service unavailable") + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + custom_tgi._get_status_from_body = MagicMock(return_value=503) + + # Call _generate_response + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_response(self.chat_input, 
model_params) + + # Verify extracted status code is used + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_native_structured_output_success( + self, mock_set_client, mock_tokenizer, mock_logger + ): + """Test _generate_native_structured_output with successful response""" + + # Define a simple Pydantic model for testing + class TestPerson(BaseModel): + name: str + age: int + + # Setup mock client + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"inputs": "Test input"} + + # Configure mock response with valid JSON + valid_json = '{"name": "John", "age": 30}' + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps({"generated_text": valid_json}) + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + + # Call _generate_native_structured_output + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_native_structured_output( + self.chat_input, model_params, TestPerson + ) + + # Verify results + self.assertEqual(json.loads(resp_text), {"name": "John", "age": 30}) + self.assertEqual(resp_status, 200) + + # Verify schema was passed in generation params + call_args = mock_client.async_send_request.call_args + generation_params = ( + call_args.args[1] + if len(call_args.args) > 1 + else call_args.kwargs.get("generation_params") + ) + self.assertIn("parameters", generation_params) + self.assertIn("grammar", generation_params["parameters"]) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.BaseCustomModel._generate_fallback_structured_output") + @pytest.mark.asyncio + async def test_generate_native_structured_output_http_error_fallback( + self, mock_fallback, mock_set_client, mock_tokenizer, mock_logger + ): + """Test _generate_native_structured_output falls back on HTTP error""" + + class TestPerson(BaseModel): + name: str + age: int + + # Setup mock client with error response + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"inputs": "Test input"} + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.text = "Internal Server Error" + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Setup fallback mock + mock_fallback.return_value = ('{"name": "Fallback", "age": 25}', 200) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + + # Call _generate_native_structured_output + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_native_structured_output( + self.chat_input, model_params, TestPerson + ) + + # Verify fallback was called + mock_fallback.assert_awaited_once() + + # Verify fallback result is returned + self.assertEqual(resp_text, '{"name": "Fallback", "age": 25}') + 
self.assertEqual(resp_status, 200) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.BaseCustomModel._generate_fallback_structured_output") + @pytest.mark.asyncio + async def test_generate_native_structured_output_validation_error_fallback( + self, mock_fallback, mock_set_client, mock_tokenizer, mock_logger + ): + """Test _generate_native_structured_output falls back on validation error""" + + class TestPerson(BaseModel): + name: str + age: int + + # Setup mock client with invalid response + mock_client = MagicMock() + mock_client.build_request_with_payload.return_value = {"inputs": "Test input"} + + # Response with invalid data (age is string instead of int) + invalid_json = '{"name": "John", "age": "thirty"}' + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = json.dumps({"generated_text": invalid_json}) + mock_client.async_send_request = AsyncMock(return_value=mock_response) + + # Setup fallback mock + mock_fallback.return_value = ('{"name": "Fallback", "age": 25}', 200) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + + # Call _generate_native_structured_output + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_native_structured_output( + self.chat_input, model_params, TestPerson + ) + + # Verify fallback was called due to validation error + mock_fallback.assert_awaited_once() + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @patch("sygra.core.models.custom_models.BaseCustomModel._generate_fallback_structured_output") + @pytest.mark.asyncio + async def test_generate_native_structured_output_exception_fallback( + self, mock_fallback, mock_set_client, mock_tokenizer, mock_logger + ): + """Test _generate_native_structured_output falls back on exception""" + + class TestPerson(BaseModel): + name: str + age: int + + # Setup mock client to raise exception + mock_client = MagicMock() + mock_client.build_request_with_payload.side_effect = Exception("Network error") + + # Setup fallback mock + mock_fallback.return_value = ('{"name": "Fallback", "age": 25}', 200) + + # Setup custom model + custom_tgi = CustomTGI(self.base_config) + custom_tgi._client = mock_client + custom_tgi.get_chat_formatted_text = MagicMock(return_value="Test input") + + # Call _generate_native_structured_output + model_params = ModelParams(url="http://tgi-test.com", auth_token="test_token") + resp_text, resp_status = await custom_tgi._generate_native_structured_output( + self.chat_input, model_params, TestPerson + ) + + # Verify fallback was called + mock_fallback.assert_awaited_once() + + # Verify error logging + mock_logger.error.assert_called() + self.assertIn( + "Native structured output generation failed", str(mock_logger.error.call_args) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/models/test_triton_custom_model.py b/tests/core/models/test_custom_triton.py similarity index 93% rename from tests/core/models/test_triton_custom_model.py rename to tests/core/models/test_custom_triton.py index c844070b..77f4124a 100644 --- 
a/tests/core/models/test_triton_custom_model.py +++ b/tests/core/models/test_custom_triton.py @@ -227,8 +227,8 @@ def test_get_response_text_parse_exception(self, mock_logger): @patch("sygra.core.models.custom_models.utils") @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") @pytest.mark.asyncio - async def test_generate_text_success(self, mock_set_client, mock_utils): - """Test _generate_text method with successful response""" + async def test_generate_response_success(self, mock_set_client, mock_utils): + """Test _generate_response method with successful response""" # Setup mock client mock_client = MagicMock() mock_client.build_request.return_value = {"payload": "test_payload"} @@ -259,9 +259,11 @@ async def test_generate_text_success(self, mock_set_client, mock_utils): custom_triton._get_payload_json_template = MagicMock(return_value=self.mock_payload_json) custom_triton._create_triton_request = MagicMock(return_value=self.mock_payload_json) - # Call _generate_text + # Call _generate_response model_params = ModelParams(url="http://triton-test.com", auth_token="test_token") - resp_text, resp_status = await custom_triton._generate_text(self.chat_input, model_params) + resp_text, resp_status = await custom_triton._generate_response( + self.chat_input, model_params + ) # Verify results self.assertEqual(resp_text, "Hello there!") @@ -280,8 +282,8 @@ async def test_generate_text_success(self, mock_set_client, mock_utils): @patch("sygra.core.models.custom_models.utils") @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") @pytest.mark.asyncio - async def test_generate_text_http_error(self, mock_set_client, mock_utils, mock_logger): - """Test _generate_text method with HTTP error""" + async def test_generate_response_http_error(self, mock_set_client, mock_utils, mock_logger): + """Test _generate_response method with HTTP error""" # Setup mock client mock_client = MagicMock() mock_client.build_request.return_value = {"payload": "test_payload"} @@ -311,9 +313,11 @@ async def test_generate_text_http_error(self, mock_set_client, mock_utils, mock_ custom_triton._get_payload_json_template = MagicMock(return_value=self.mock_payload_json) custom_triton._create_triton_request = MagicMock(return_value=self.mock_payload_json) - # Call _generate_text + # Call _generate_response model_params = ModelParams(url="http://triton-test.com", auth_token="test_token") - resp_text, resp_status = await custom_triton._generate_text(self.chat_input, model_params) + resp_text, resp_status = await custom_triton._generate_response( + self.chat_input, model_params + ) # Verify results self.assertEqual(resp_text, "") @@ -327,8 +331,8 @@ async def test_generate_text_http_error(self, mock_set_client, mock_utils, mock_ @patch("sygra.core.models.custom_models.logger") @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") @pytest.mark.asyncio - async def test_generate_text_exception(self, mock_set_client, mock_logger): - """Test _generate_text method with exception""" + async def test_generate_response_exception(self, mock_set_client, mock_logger): + """Test _generate_response method with exception""" # Setup mock client mock_client = MagicMock() mock_client.build_request.side_effect = Exception("Test error") @@ -338,9 +342,11 @@ async def test_generate_text_exception(self, mock_set_client, mock_logger): custom_triton._client = mock_client custom_triton._get_status_from_body = MagicMock(return_value=None) - # Call _generate_text + # Call _generate_response model_params = 
ModelParams(url="http://triton-test.com", auth_token="test_token") - resp_text, resp_status = await custom_triton._generate_text(self.chat_input, model_params) + resp_text, resp_status = await custom_triton._generate_response( + self.chat_input, model_params + ) # Verify results self.assertEqual(resp_text, "Http request failed Test error") diff --git a/tests/core/models/test_custom_vllm.py b/tests/core/models/test_custom_vllm.py new file mode 100644 index 00000000..72851dfb --- /dev/null +++ b/tests/core/models/test_custom_vllm.py @@ -0,0 +1,395 @@ +import sys +import unittest +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import openai +import pytest + +# Add the parent directory to sys.path to import the necessary modules +sys.path.append(str(Path(__file__).parent.parent.parent.parent)) + +from langchain_core.messages import HumanMessage, SystemMessage +from langchain_core.prompt_values import ChatPromptValue + +from sygra.core.models.custom_models import CustomVLLM, ModelParams +from sygra.utils import constants + + +class TestCustomVLLM(unittest.TestCase): + """Unit tests for the CustomVLLM class""" + + def setUp(self): + """Set up test fixtures before each test method""" + # Base model configuration + self.base_config = { + "name": "vllm_model", + "parameters": {"temperature": 0.7, "max_tokens": 100}, + "url": "http://vllm-test.com", + "auth_token": "Bearer test_token_123", + } + + # Configuration with completions API + self.completions_config = { + **self.base_config, + "completions_api": True, + "hf_chat_template_model_id": "meta-llama/Llama-2-7b-chat-hf", + } + + # Configuration with model serving name + self.serving_name_config = { + **self.base_config, + "model_serving_name": "custom_serving_name", + } + + # Mock messages + self.messages = [ + SystemMessage(content="You are a helpful assistant"), + HumanMessage(content="Hello, how are you?"), + ] + self.chat_input = ChatPromptValue(messages=self.messages) + + def test_init(self): + """Test initialization of CustomVLLM""" + custom_vllm = CustomVLLM(self.base_config) + + # Verify model was properly initialized + self.assertEqual(custom_vllm.model_config, self.base_config) + self.assertEqual(custom_vllm.generation_params, self.base_config["parameters"]) + self.assertEqual(custom_vllm.auth_token, "test_token_123") # Bearer prefix removed + self.assertEqual(custom_vllm.name(), "vllm_model") + self.assertEqual(custom_vllm.model_serving_name, "vllm_model") # Default to name + + def test_init_with_custom_serving_name(self): + """Test initialization with custom model serving name""" + custom_vllm = CustomVLLM(self.serving_name_config) + + self.assertEqual(custom_vllm.model_serving_name, "custom_serving_name") + + @patch("sygra.core.models.custom_models.logger") + def test_init_with_completions_api(self, mock_logger): + """Test initialization with completions API enabled""" + with patch("sygra.core.models.custom_models.AutoTokenizer"): + custom_vllm = CustomVLLM(self.completions_config) + + self.assertTrue(custom_vllm.model_config.get("completions_api")) + # Should log that model supports completion API + mock_logger.info.assert_any_call("Model vllm_model supports completion API.") + + def test_init_missing_url_raises_error(self): + """Test initialization without url raises error""" + config = { + "name": "vllm_model", + "parameters": {"temperature": 0.7}, + "auth_token": "test_token", + } + + with self.assertRaises(Exception): + CustomVLLM(config) + + def 
test_init_missing_auth_token_raises_error(self): + """Test initialization without auth_token raises error""" + config = { + "name": "vllm_model", + "parameters": {"temperature": 0.7}, + "url": "http://vllm-test.com", + } + + with self.assertRaises(Exception): + CustomVLLM(config) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_chat_api_success(self, mock_set_client): + """Test _generate_response with chat API (non-completions)""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request.return_value = { + "messages": [{"role": "user", "content": "Hello"}] + } + + # Setup mock completion response + mock_choice = MagicMock() + mock_choice.model_dump.return_value = { + "message": {"content": "Hello! I'm doing well, thank you!"} + } + mock_completion = MagicMock() + mock_completion.choices = [mock_choice] + + mock_client.send_request = AsyncMock(return_value=mock_completion) + + # Setup custom model + custom_vllm = CustomVLLM(self.base_config) + custom_vllm._client = mock_client + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + resp_text, resp_status = await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify results + self.assertEqual(resp_text, "Hello! I'm doing well, thank you!") + self.assertEqual(resp_status, 200) + + # Verify method calls + mock_set_client.assert_called_once_with("http://vllm-test.com", "test_token") + mock_client.build_request.assert_called_once_with(messages=self.messages) + mock_client.send_request.assert_awaited_once() + + @patch("sygra.core.models.custom_models.AutoTokenizer") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_completions_api_success(self, mock_set_client, mock_tokenizer): + """Test _generate_response with completions API""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request.return_value = {"prompt": "Formatted prompt text"} + + # Setup mock completion response for completions API + mock_choice = MagicMock() + mock_choice.model_dump.return_value = {"text": " Response text "} + mock_completion = MagicMock() + mock_completion.choices = [mock_choice] + + mock_client.send_request = AsyncMock(return_value=mock_completion) + + # Setup custom model with completions API + custom_vllm = CustomVLLM(self.completions_config) + custom_vllm._client = mock_client + custom_vllm.get_chat_formatted_text = MagicMock(return_value="Formatted prompt text") + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + resp_text, resp_status = await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify results (text should be stripped) + self.assertEqual(resp_text, "Response text") + self.assertEqual(resp_status, 200) + + # Verify completions API path was used + custom_vllm.get_chat_formatted_text.assert_called_once() + mock_client.build_request.assert_called_once_with(formatted_prompt="Formatted prompt text") + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_rate_limit_error(self, mock_set_client, mock_logger): + """Test _generate_response with rate limit error""" + # Setup mock client to raise RateLimitError + mock_client = MagicMock() + mock_client.build_request.return_value = 
{"messages": []} + mock_client.send_request = AsyncMock( + side_effect=openai.RateLimitError( + "Rate limit exceeded", response=MagicMock(), body=None + ) + ) + + # Setup custom model + custom_vllm = CustomVLLM(self.base_config) + custom_vllm._client = mock_client + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + resp_text, resp_status = await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify results - should return 429 for rate limit + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("Http request failed", resp_text) + self.assertEqual(resp_status, 429) + + # Verify warning logging + mock_logger.warn.assert_called() + self.assertIn("rate limit", str(mock_logger.warn.call_args)) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_server_down(self, mock_set_client, mock_logger): + """Test _generate_response with server down error""" + # Setup mock client to raise exception with server down message + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + mock_client.send_request = AsyncMock( + side_effect=Exception(f"Connection failed: {constants.ELEMAI_JOB_DOWN}") + ) + + # Setup custom model + custom_vllm = CustomVLLM(self.base_config) + custom_vllm._client = mock_client + custom_vllm._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + resp_text, resp_status = await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify results - should return 503 for server down + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn(constants.ELEMAI_JOB_DOWN, resp_text) + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_connection_error(self, mock_set_client, mock_logger): + """Test _generate_response with connection error""" + # Setup mock client to raise exception with connection error + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + mock_client.send_request = AsyncMock(side_effect=Exception(f"{constants.CONNECTION_ERROR}")) + + # Setup custom model + custom_vllm = CustomVLLM(self.base_config) + custom_vllm._client = mock_client + custom_vllm._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + resp_text, resp_status = await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify results - should return 503 for connection error + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_generic_exception(self, mock_set_client, mock_logger): + """Test _generate_response with generic exception""" + # Setup mock client to raise generic exception + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + mock_client.send_request = AsyncMock(side_effect=Exception("Network timeout")) + + # Setup custom model + custom_vllm = CustomVLLM(self.base_config) + custom_vllm._client = 
mock_client + custom_vllm._get_status_from_body = MagicMock(return_value=None) + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + resp_text, resp_status = await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify results - should return 999 for generic error + self.assertIn(constants.ERROR_PREFIX, resp_text) + self.assertIn("Network timeout", resp_text) + self.assertEqual(resp_status, 999) + + # Verify error logging + mock_logger.error.assert_called() + + @patch("sygra.core.models.custom_models.logger") + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_with_extracted_status_code(self, mock_set_client, mock_logger): + """Test _generate_response extracts status code from error body""" + # Setup mock client to raise exception + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + mock_client.send_request = AsyncMock(side_effect=Exception("Service unavailable")) + + # Setup custom model + custom_vllm = CustomVLLM(self.base_config) + custom_vllm._client = mock_client + custom_vllm._get_status_from_body = MagicMock(return_value=503) + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + resp_text, resp_status = await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify extracted status code is used + self.assertEqual(resp_status, 503) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_with_custom_serving_name(self, mock_set_client): + """Test _generate_response uses custom serving name""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + + mock_choice = MagicMock() + mock_choice.model_dump.return_value = {"message": {"content": "Response"}} + mock_completion = MagicMock() + mock_completion.choices = [mock_choice] + + mock_client.send_request = AsyncMock(return_value=mock_completion) + + # Setup custom model with serving name + custom_vllm = CustomVLLM(self.serving_name_config) + custom_vllm._client = mock_client + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify custom serving name was used + call_args = mock_client.send_request.call_args + self.assertEqual(call_args.args[1], "custom_serving_name") + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_generate_response_passes_generation_params(self, mock_set_client): + """Test _generate_response passes generation parameters correctly""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + + mock_choice = MagicMock() + mock_choice.model_dump.return_value = {"message": {"content": "Response"}} + mock_completion = MagicMock() + mock_completion.choices = [mock_choice] + + mock_client.send_request = AsyncMock(return_value=mock_completion) + + # Setup custom model with specific generation params + config = { + **self.base_config, + "parameters": { + "temperature": 0.9, + "max_tokens": 500, + "top_p": 0.95, + }, + } + custom_vllm = CustomVLLM(config) + custom_vllm._client = mock_client + + # Call _generate_response + model_params = ModelParams(url="http://vllm-test.com", 
auth_token="test_token") + await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify generation parameters were passed + call_args = mock_client.send_request.call_args + passed_params = call_args.args[2] + self.assertEqual(passed_params["temperature"], 0.9) + self.assertEqual(passed_params["max_tokens"], 500) + self.assertEqual(passed_params["top_p"], 0.95) + + @patch("sygra.core.models.custom_models.BaseCustomModel._set_client") + @pytest.mark.asyncio + async def test_client_recreated_per_request(self, mock_set_client): + """Test that client is recreated for every request""" + # Setup mock client + mock_client = MagicMock() + mock_client.build_request.return_value = {"messages": []} + + mock_choice = MagicMock() + mock_choice.model_dump.return_value = {"message": {"content": "Response"}} + mock_completion = MagicMock() + mock_completion.choices = [mock_choice] + + mock_client.send_request = AsyncMock(return_value=mock_completion) + + # Setup custom model + custom_vllm = CustomVLLM(self.base_config) + custom_vllm._client = mock_client + + # Call _generate_response multiple times + model_params = ModelParams(url="http://vllm-test.com", auth_token="test_token") + await custom_vllm._generate_response(self.chat_input, model_params) + await custom_vllm._generate_response(self.chat_input, model_params) + await custom_vllm._generate_response(self.chat_input, model_params) + + # Verify _set_client was called each time + self.assertEqual(mock_set_client.call_count, 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/core/models/test_structured_output_support.py b/tests/core/models/test_structured_output_support.py index b9300eb2..9e431f6c 100644 --- a/tests/core/models/test_structured_output_support.py +++ b/tests/core/models/test_structured_output_support.py @@ -48,7 +48,7 @@ class UserSchema(BaseModel): # Test model that implements abstract method class CustomModel(BaseCustomModel): @pytest.mark.asyncio - async def _generate_text(self, input, model_params): + async def _generate_response(self, input, model_params): return "test response", 200 def _supports_native_structured_output(self): @@ -276,7 +276,7 @@ async def test_fallback_structured_output(self, mock_parser, mock_output_parser, mock_output_parser.return_value = mock_parser_instance model = CustomModel(self.test_config) - model._generate_text_with_retry = AsyncMock(return_value=(self.valid_json, 200)) + model._generate_response_with_retry = AsyncMock(return_value=(self.valid_json, 200)) # Execute resp_text, resp_status = await model._generate_fallback_structured_output( @@ -549,7 +549,7 @@ async def test_fallback_structured_output_parse_error( mock_output_parser.return_value = mock_parser_instance model = CustomModel(self.test_config) - model._generate_text_with_retry = AsyncMock(return_value=("Invalid JSON", 200)) + model._generate_response_with_retry = AsyncMock(return_value=("Invalid JSON", 200)) # Execute resp_text, resp_status = await model._generate_fallback_structured_output( @@ -630,7 +630,7 @@ async def test_ollama_chat_api_response_extraction(self, mock_parser, mock_utils @patch("sygra.utils.utils.validate_required_keys") @patch("sygra.core.models.structured_output.structured_output_config.SchemaConfigParser") - async def test_ollama_generate_text_success(self, mock_parser, mock_utils): + async def test_ollama_generate_response_success(self, mock_parser, mock_utils): """Test Ollama regular text generation success""" model = CustomOllama({**self.test_config, "url": "test", "auth_token": 
"test"}) @@ -641,20 +641,24 @@ async def test_ollama_generate_text_success(self, mock_parser, mock_utils): # Mock _set_client to prevent it from overwriting our mock client with patch.object(model, "_set_client"): - resp_text, resp_status = await model._generate_text(self.test_input, self.test_params) + resp_text, resp_status = await model._generate_response( + self.test_input, self.test_params + ) self.assertEqual(resp_text, "Generated text response") self.assertEqual(resp_status, 200) @patch("sygra.utils.utils.validate_required_keys") @patch("sygra.core.models.structured_output.structured_output_config.SchemaConfigParser") - async def test_ollama_generate_text_exception_handling(self, mock_parser, mock_utils): + async def test_ollama_generate_response_exception_handling(self, mock_parser, mock_utils): """Test Ollama regular text generation exception handling""" model = CustomOllama({**self.test_config, "url": "test", "auth_token": "test"}) # Mock _set_client to raise an exception with patch.object(model, "_set_client", side_effect=Exception("Connection failed")): - resp_text, resp_status = await model._generate_text(self.test_input, self.test_params) + resp_text, resp_status = await model._generate_response( + self.test_input, self.test_params + ) self.assertIn("ERROR", resp_text) self.assertIn("Connection failed", resp_text)