Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sygra/utils/audio_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def is_data_url(val: Any) -> bool:
Returns:
bool: True if the value is a data URL, False otherwise.
"""
return isinstance(val, str) and val.startswith("data:")
return isinstance(val, str) and val.startswith("data:audio/")


def is_hf_audio_dict(val: Any) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion sygra/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
COMPLETION_ONLY_MODELS: list[str] = []

# constants for template based payload
PAYLOAD_CFG_FILE = "config/payload_cfg.json"
PAYLOAD_CFG_FILE = "sygra/config/payload_cfg.json"
PAYLOAD_JSON = "payload_json"
TEST_PAYLOAD = "test_payload"
RESPONSE_KEY = "response_key"
Expand Down
39 changes: 23 additions & 16 deletions tests/core/models/client/test_http_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import sys
import unittest
from pathlib import Path
Expand Down Expand Up @@ -160,37 +161,40 @@ def test_send_request_exception_handling(self, mock_request):
# Verify empty response is returned on exception
self.assertEqual(response, "")

@pytest.mark.asyncio
@patch("aiohttp.ClientSession.post")
async def test_async_send_request(self, mock_post):
def test_async_send_request(self, mock_post):
asyncio.run(self._run_async_send_request(mock_post))

async def _run_async_send_request(self, mock_post):
"""Test async_send_request method sends HTTP requests correctly"""
# Setup mock response

# --- Setup mock response ---
mock_response = AsyncMock()
mock_response.text = AsyncMock(return_value='{"result": "Success"}')
mock_response.status = 200
mock_response.__aenter__.return_value = mock_response
mock_response.headers = {"x-test": "true"}
mock_response.__aenter__.return_value = mock_response # async context manager
mock_response.__aexit__.return_value = None

mock_post.return_value = mock_response

# Test basic async request
# --- Run the actual client method ---
payload = {"prompt": "Test prompt"}
response = await self.client.async_send_request(payload)
await self.client.async_send_request(payload)

# Verify request was made with correct parameters
# --- Verify the request ---
mock_post.assert_called_once()
call_kwargs = mock_post.call_args.kwargs
self.assertEqual(call_kwargs["url"], self.base_url)
self.assertEqual(call_kwargs["headers"], self.headers)
self.assertEqual(call_kwargs["timeout"], 30)
self.assertEqual(call_kwargs["ssl"], False)
self.assertEqual(call_kwargs["ssl"], True)
self.assertEqual(json.loads(call_kwargs["data"].decode()), payload)

# Check that response is properly processed
self.assertEqual(response.text, '{"result": "Success"}')
self.assertEqual(response.status_code, 200)

@pytest.mark.asyncio
@patch("aiohttp.ClientSession.post")
async def test_async_send_request_with_generation_params(self, mock_post):
def test_async_send_request_with_generation_params(self, mock_post):
asyncio.run(self._run_async_send_request(mock_post))

async def _run_async_send_request_with_generation_params(self, mock_post):
"""Test async_send_request with generation parameters"""
# Setup mock response
mock_response = AsyncMock()
Expand All @@ -216,7 +220,10 @@ async def test_async_send_request_with_generation_params(self, mock_post):

@pytest.mark.asyncio
@patch("aiohttp.ClientSession.post")
async def test_async_send_request_exception_handling(self, mock_post):
def test_async_send_request_exception_handling(self, mock_post):
asyncio.run(self._run_async_send_request_exception_handling(mock_post))

async def _run_async_send_request_exception_handling(self, mock_post):
"""Test async_send_request handles exceptions correctly"""
# Setup mock to raise exception
mock_post.side_effect = Exception("Network error")
Expand Down
108 changes: 62 additions & 46 deletions tests/core/models/langgraph/test_openai_chat_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import asyncio
import sys
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import openai
import pytest
from langchain_core.messages import HumanMessage

sys.path.append(str(Path(__file__).parent.parent.parent.parent.parent))
Expand Down Expand Up @@ -32,8 +32,10 @@ def setUp(self):
"name": "test_openai_model",
"model_type": "azure_openai",
"url": "https://test-openai-endpoint.com",
"api_key": "test_api_key",
"auth_token": "test_key_123",
"parameters": {"temperature": 0.7, "max_tokens": 500},
"model": "gpt-4o",
"api_version": "2023-05-15",
}

# Save original constants
Expand All @@ -44,26 +46,32 @@ def tearDown(self):
# Restore original constants
constants.ERROR_PREFIX = self.original_error_prefix

@patch("sygra.core.models.custom_models.BaseCustomModel._set_client")
@patch("sygra.core.models.langgraph.openai_chat_model.logger")
@pytest.mark.asyncio
async def test_generate_response_success(self, mock_logger):
def test_generate_response_success(self, mock_logger, mock_set_client):
asyncio.run(self._run_generate_response_success(mock_logger, mock_set_client))

async def _run_generate_response_success(self, mock_logger, mock_set_client):
"""Test successful response generation."""
model = CustomOpenAIChatModel(self.base_config)

# Mock the client
# Setup mock client
mock_client = MagicMock()
model._client = mock_client
mock_client.build_request.return_value = {
"messages": [{"role": "user", "content": "Hello"}]
}

# Set up the mock client's build_request and send_request methods
mock_client.build_request = MagicMock(
return_value={"messages": [{"role": "user", "content": "Hello"}]}
)
mock_client.send_request = AsyncMock(
return_value={
"id": "test-id",
"choices": [{"message": {"content": "Test response"}}],
}
)
# Setup mock completion response
mock_choice = MagicMock()
mock_choice.model_dump.return_value = {
"message": {"content": "Hello! I'm doing well, thank you!"}
}
mock_completion = MagicMock()
mock_completion.choices = [mock_choice]

mock_client.send_request = AsyncMock(return_value=mock_completion)

model = CustomOpenAIChatModel(self.base_config)
model._client = mock_client

# Create test messages
messages = [HumanMessage(content="Hello")]
Expand All @@ -77,24 +85,26 @@ async def test_generate_response_success(self, mock_logger):
# Verify the response
self.assertEqual(status_code, 200)
self.assertEqual(
response,
{"id": "test-id", "choices": [{"message": {"content": "Test response"}}]},
response.choices[0].model_dump.return_value["message"]["content"],
"Hello! I'm doing well, thank you!",
)

# Verify the client methods were called correctly
mock_client.build_request.assert_called_once_with(messages=messages)
mock_client.send_request.assert_called_once_with(
{"messages": [{"role": "user", "content": "Hello"}]},
self.base_config.get("model"),
"gpt-4o",
self.base_config.get("parameters"),
)

# Verify no errors were logged
mock_logger.error.assert_not_called()

@patch("sygra.core.models.langgraph.openai_chat_model.logger")
@pytest.mark.asyncio
async def test_generate_response_rate_limit_error(self, mock_logger):
def test_generate_response_rate_limit_error(self, mock_logger):
asyncio.run(self._run_generate_response_rate_limit_error(mock_logger))

async def _run_generate_response_rate_limit_error(self, mock_logger):
"""Test handling of rate limit errors."""
model = CustomOpenAIChatModel(self.base_config)

Expand Down Expand Up @@ -127,7 +137,8 @@ async def test_generate_response_rate_limit_error(self, mock_logger):
# Patch _set_client to avoid actual client creation through ClientFactory
with patch.object(model, "_set_client"):
# Call the method
response, status_code = await model._generate_response(messages)
model_params = ModelParams(url="http://test-url", auth_token="test-token")
response, status_code = await model._generate_response(messages, model_params)

# Verify the response
self.assertEqual(status_code, 429)
Expand All @@ -140,8 +151,10 @@ async def test_generate_response_rate_limit_error(self, mock_logger):
self.assertIn("exceeded rate limit", warn_message)

@patch("sygra.core.models.langgraph.openai_chat_model.logger")
@pytest.mark.asyncio
async def test_generate_response_generic_exception(self, mock_logger):
def test_generate_response_generic_exception(self, mock_logger):
asyncio.run(self._run_generate_response_generic_exception(mock_logger))

async def _run_generate_response_generic_exception(self, mock_logger):
"""Test handling of generic exceptions."""
model = CustomOpenAIChatModel(self.base_config)

Expand All @@ -167,7 +180,8 @@ async def test_generate_response_generic_exception(self, mock_logger):
# Patch _set_client to avoid actual client creation through ClientFactory
with patch.object(model, "_set_client"):
# Call the method
response, status_code = await model._generate_response(messages)
model_params = ModelParams(url="http://test-url", auth_token="test-token")
response, status_code = await model._generate_response(messages, model_params)

# Verify the response
self.assertEqual(status_code, 500)
Expand All @@ -180,8 +194,10 @@ async def test_generate_response_generic_exception(self, mock_logger):
self.assertIn("Http request failed", error_message)

@patch("sygra.core.models.langgraph.openai_chat_model.logger")
@pytest.mark.asyncio
async def test_generate_response_status_not_found(self, mock_logger):
def test_generate_response_status_not_found(self, mock_logger):
asyncio.run(self._run_generate_response_status_not_found(mock_logger))

async def _run_generate_response_status_not_found(self, mock_logger):
"""Test handling of exceptions where status code cannot be extracted."""
model = CustomOpenAIChatModel(self.base_config)

Expand All @@ -207,7 +223,8 @@ async def test_generate_response_status_not_found(self, mock_logger):
# Patch _set_client to avoid actual client creation through ClientFactory
with patch.object(model, "_set_client"):
# Call the method
response, status_code = await model._generate_response(messages)
model_params = ModelParams(url="http://test-url", auth_token="test-token")
response, status_code = await model._generate_response(messages, model_params)

# Verify the response - should use default status code 999
self.assertEqual(status_code, 999)
Expand All @@ -219,8 +236,10 @@ async def test_generate_response_status_not_found(self, mock_logger):

@patch("sygra.core.models.langgraph.openai_chat_model.logger")
@patch("sygra.core.models.langgraph.openai_chat_model.SygraBaseChatModel._set_client")
@pytest.mark.asyncio
async def test_generate_response_with_client_factory(self, mock_set_client, mock_logger):
def test_generate_response_with_client_factory(self, mock_set_client, mock_logger):
asyncio.run(self._run_generate_response_with_client_factory(mock_set_client, mock_logger))

async def _run_generate_response_with_client_factory(self, mock_set_client, mock_logger):
"""
Test response generation with proper _set_client integration.

Expand All @@ -243,7 +262,7 @@ async def test_generate_response_with_client_factory(self, mock_set_client, mock
)

# Have _set_client correctly set the client
def mock_set_client_implementation(async_client=True):
def mock_set_client_implementation(url="http://test-url", auth_token="test-token"):
model._client = mock_client

mock_set_client.side_effect = mock_set_client_implementation
Expand All @@ -252,21 +271,17 @@ def mock_set_client_implementation(async_client=True):
messages = [HumanMessage(content="Hello")]

# Call the method
response, status_code = await model._generate_response(messages)
model_params = ModelParams(url="http://test-url", auth_token="test-token")
response, status_code = await model._generate_response(messages, model_params)

# Verify _set_client was called
mock_set_client.assert_called_once()

# Verify the response
self.assertEqual(status_code, 200)
self.assertEqual(
response,
{"id": "test-id", "choices": [{"message": {"content": "Test response"}}]},
)

@patch("sygra.core.models.langgraph.vllm_chat_model.logger")
@pytest.mark.asyncio
async def test_generate_response_with_additional_kwargs(self, mock_logger):
def test_generate_response_with_additional_kwargs(self, mock_logger):
asyncio.run(self._run_generate_response_with_additional_kwargs(mock_logger))

async def _run_generate_response_with_additional_kwargs(self, mock_logger):
"""Test synchronous response generation with additional kwargs passed."""
model = CustomOpenAIChatModel(self.base_config)

Expand All @@ -278,7 +293,7 @@ async def test_generate_response_with_additional_kwargs(self, mock_logger):
mock_client.build_request = MagicMock(
return_value={"messages": [{"role": "user", "content": "Hello"}]}
)
mock_client.send_request = MagicMock(
mock_client.send_request = AsyncMock(
return_value={
"id": "test-id",
"choices": [{"message": {"content": "Test response"}}],
Expand All @@ -290,7 +305,6 @@ async def test_generate_response_with_additional_kwargs(self, mock_logger):

# Additional kwargs to pass
additional_kwargs = {
"stream": True,
"tools": [
{
"type": "function",
Expand All @@ -305,15 +319,17 @@ async def test_generate_response_with_additional_kwargs(self, mock_logger):
# Patch _set_client to avoid actual client creation through ClientFactory
with patch.object(model, "_set_client"):
# Call the method with additional kwargs
response, status_code = await model._generate_response(messages, **additional_kwargs)
model_params = ModelParams(url="http://test-url", auth_token="test-token")
response, status_code = await model._generate_response(
messages, model_params, **additional_kwargs
)

# Verify the response
self.assertEqual(status_code, 200)

# Verify the client methods were called correctly with the additional kwargs
mock_client.build_request.assert_called_once_with(
messages=messages,
stream=True,
tools=[
{
"type": "function",
Expand Down
Loading