1+ import asyncio
12import sys
23import unittest
34from pathlib import Path
45from unittest .mock import AsyncMock , MagicMock , patch
56
67import openai
7- import pytest
88from langchain_core .messages import HumanMessage
99
1010sys .path .append (str (Path (__file__ ).parent .parent .parent .parent .parent ))
@@ -32,8 +32,10 @@ def setUp(self):
3232 "name" : "test_openai_model" ,
3333 "model_type" : "azure_openai" ,
3434 "url" : "https://test-openai-endpoint.com" ,
35- "api_key " : "test_api_key " ,
35+ "auth_token " : "test_key_123 " ,
3636 "parameters" : {"temperature" : 0.7 , "max_tokens" : 500 },
37+ "model" : "gpt-4o" ,
38+ "api_version" : "2023-05-15" ,
3739 }
3840
3941 # Save original constants
@@ -44,26 +46,32 @@ def tearDown(self):
4446 # Restore original constants
4547 constants .ERROR_PREFIX = self .original_error_prefix
4648
49+ @patch ("sygra.core.models.custom_models.BaseCustomModel._set_client" )
4750 @patch ("sygra.core.models.langgraph.openai_chat_model.logger" )
48- @pytest .mark .asyncio
49- async def test_generate_response_success (self , mock_logger ):
51+ def test_generate_response_success (self , mock_logger , mock_set_client ):
52+ asyncio .run (self ._run_generate_response_success (mock_logger , mock_set_client ))
53+
54+ async def _run_generate_response_success (self , mock_logger , mock_set_client ):
5055 """Test successful response generation."""
51- model = CustomOpenAIChatModel (self .base_config )
5256
53- # Mock the client
57+ # Setup mock client
5458 mock_client = MagicMock ()
55- model ._client = mock_client
59+ mock_client .build_request .return_value = {
60+ "messages" : [{"role" : "user" , "content" : "Hello" }]
61+ }
5662
57- # Set up the mock client's build_request and send_request methods
58- mock_client .build_request = MagicMock (
59- return_value = {"messages" : [{"role" : "user" , "content" : "Hello" }]}
60- )
61- mock_client .send_request = AsyncMock (
62- return_value = {
63- "id" : "test-id" ,
64- "choices" : [{"message" : {"content" : "Test response" }}],
65- }
66- )
63+ # Setup mock completion response
64+ mock_choice = MagicMock ()
65+ mock_choice .model_dump .return_value = {
66+ "message" : {"content" : "Hello! I'm doing well, thank you!" }
67+ }
68+ mock_completion = MagicMock ()
69+ mock_completion .choices = [mock_choice ]
70+
71+ mock_client .send_request = AsyncMock (return_value = mock_completion )
72+
73+ model = CustomOpenAIChatModel (self .base_config )
74+ model ._client = mock_client
6775
6876 # Create test messages
6977 messages = [HumanMessage (content = "Hello" )]
@@ -77,24 +85,26 @@ async def test_generate_response_success(self, mock_logger):
7785 # Verify the response
7886 self .assertEqual (status_code , 200 )
7987 self .assertEqual (
80- response ,
81- { "id" : "test-id" , "choices" : [{ "message" : { "content" : "Test response" }}]} ,
88+ response . choices [ 0 ]. model_dump . return_value [ "message" ][ "content" ] ,
89+ "Hello! I'm doing well, thank you!" ,
8290 )
8391
8492 # Verify the client methods were called correctly
8593 mock_client .build_request .assert_called_once_with (messages = messages )
8694 mock_client .send_request .assert_called_once_with (
8795 {"messages" : [{"role" : "user" , "content" : "Hello" }]},
88- self . base_config . get ( "model" ) ,
96+ "gpt-4o" ,
8997 self .base_config .get ("parameters" ),
9098 )
9199
92100 # Verify no errors were logged
93101 mock_logger .error .assert_not_called ()
94102
95103 @patch ("sygra.core.models.langgraph.openai_chat_model.logger" )
96- @pytest .mark .asyncio
97- async def test_generate_response_rate_limit_error (self , mock_logger ):
104+ def test_generate_response_rate_limit_error (self , mock_logger ):
105+ asyncio .run (self ._run_generate_response_rate_limit_error (mock_logger ))
106+
107+ async def _run_generate_response_rate_limit_error (self , mock_logger ):
98108 """Test handling of rate limit errors."""
99109 model = CustomOpenAIChatModel (self .base_config )
100110
@@ -127,7 +137,8 @@ async def test_generate_response_rate_limit_error(self, mock_logger):
127137 # Patch _set_client to avoid actual client creation through ClientFactory
128138 with patch .object (model , "_set_client" ):
129139 # Call the method
130- response , status_code = await model ._generate_response (messages )
140+ model_params = ModelParams (url = "http://test-url" , auth_token = "test-token" )
141+ response , status_code = await model ._generate_response (messages , model_params )
131142
132143 # Verify the response
133144 self .assertEqual (status_code , 429 )
@@ -140,8 +151,10 @@ async def test_generate_response_rate_limit_error(self, mock_logger):
140151 self .assertIn ("exceeded rate limit" , warn_message )
141152
142153 @patch ("sygra.core.models.langgraph.openai_chat_model.logger" )
143- @pytest .mark .asyncio
144- async def test_generate_response_generic_exception (self , mock_logger ):
154+ def test_generate_response_generic_exception (self , mock_logger ):
155+ asyncio .run (self ._run_generate_response_generic_exception (mock_logger ))
156+
157+ async def _run_generate_response_generic_exception (self , mock_logger ):
145158 """Test handling of generic exceptions."""
146159 model = CustomOpenAIChatModel (self .base_config )
147160
@@ -167,7 +180,8 @@ async def test_generate_response_generic_exception(self, mock_logger):
167180 # Patch _set_client to avoid actual client creation through ClientFactory
168181 with patch .object (model , "_set_client" ):
169182 # Call the method
170- response , status_code = await model ._generate_response (messages )
183+ model_params = ModelParams (url = "http://test-url" , auth_token = "test-token" )
184+ response , status_code = await model ._generate_response (messages , model_params )
171185
172186 # Verify the response
173187 self .assertEqual (status_code , 500 )
@@ -180,8 +194,10 @@ async def test_generate_response_generic_exception(self, mock_logger):
180194 self .assertIn ("Http request failed" , error_message )
181195
182196 @patch ("sygra.core.models.langgraph.openai_chat_model.logger" )
183- @pytest .mark .asyncio
184- async def test_generate_response_status_not_found (self , mock_logger ):
197+ def test_generate_response_status_not_found (self , mock_logger ):
198+ asyncio .run (self ._run_generate_response_status_not_found (mock_logger ))
199+
200+ async def _run_generate_response_status_not_found (self , mock_logger ):
185201 """Test handling of exceptions where status code cannot be extracted."""
186202 model = CustomOpenAIChatModel (self .base_config )
187203
@@ -207,7 +223,8 @@ async def test_generate_response_status_not_found(self, mock_logger):
207223 # Patch _set_client to avoid actual client creation through ClientFactory
208224 with patch .object (model , "_set_client" ):
209225 # Call the method
210- response , status_code = await model ._generate_response (messages )
226+ model_params = ModelParams (url = "http://test-url" , auth_token = "test-token" )
227+ response , status_code = await model ._generate_response (messages , model_params )
211228
212229 # Verify the response - should use default status code 999
213230 self .assertEqual (status_code , 999 )
@@ -219,8 +236,10 @@ async def test_generate_response_status_not_found(self, mock_logger):
219236
220237 @patch ("sygra.core.models.langgraph.openai_chat_model.logger" )
221238 @patch ("sygra.core.models.langgraph.openai_chat_model.SygraBaseChatModel._set_client" )
222- @pytest .mark .asyncio
223- async def test_generate_response_with_client_factory (self , mock_set_client , mock_logger ):
239+ def test_generate_response_with_client_factory (self , mock_set_client , mock_logger ):
240+ asyncio .run (self ._run_generate_response_with_client_factory (mock_set_client , mock_logger ))
241+
242+ async def _run_generate_response_with_client_factory (self , mock_set_client , mock_logger ):
224243 """
225244 Test response generation with proper _set_client integration.
226245
@@ -243,7 +262,7 @@ async def test_generate_response_with_client_factory(self, mock_set_client, mock
243262 )
244263
245264 # Have _set_client correctly set the client
246- def mock_set_client_implementation (async_client = True ):
265+ def mock_set_client_implementation (url = "http://test-url" , auth_token = "test-token" ):
247266 model ._client = mock_client
248267
249268 mock_set_client .side_effect = mock_set_client_implementation
@@ -252,21 +271,17 @@ def mock_set_client_implementation(async_client=True):
252271 messages = [HumanMessage (content = "Hello" )]
253272
254273 # Call the method
255- response , status_code = await model ._generate_response (messages )
274+ model_params = ModelParams (url = "http://test-url" , auth_token = "test-token" )
275+ response , status_code = await model ._generate_response (messages , model_params )
256276
257277 # Verify _set_client was called
258278 mock_set_client .assert_called_once ()
259279
260- # Verify the response
261- self .assertEqual (status_code , 200 )
262- self .assertEqual (
263- response ,
264- {"id" : "test-id" , "choices" : [{"message" : {"content" : "Test response" }}]},
265- )
266-
267280 @patch ("sygra.core.models.langgraph.vllm_chat_model.logger" )
268- @pytest .mark .asyncio
269- async def test_generate_response_with_additional_kwargs (self , mock_logger ):
281+ def test_generate_response_with_additional_kwargs (self , mock_logger ):
282+ asyncio .run (self ._run_generate_response_with_additional_kwargs (mock_logger ))
283+
284+ async def _run_generate_response_with_additional_kwargs (self , mock_logger ):
270285 """Test synchronous response generation with additional kwargs passed."""
271286 model = CustomOpenAIChatModel (self .base_config )
272287
@@ -278,7 +293,7 @@ async def test_generate_response_with_additional_kwargs(self, mock_logger):
278293 mock_client .build_request = MagicMock (
279294 return_value = {"messages" : [{"role" : "user" , "content" : "Hello" }]}
280295 )
281- mock_client .send_request = MagicMock (
296+ mock_client .send_request = AsyncMock (
282297 return_value = {
283298 "id" : "test-id" ,
284299 "choices" : [{"message" : {"content" : "Test response" }}],
@@ -290,7 +305,6 @@ async def test_generate_response_with_additional_kwargs(self, mock_logger):
290305
291306 # Additional kwargs to pass
292307 additional_kwargs = {
293- "stream" : True ,
294308 "tools" : [
295309 {
296310 "type" : "function" ,
@@ -305,15 +319,17 @@ async def test_generate_response_with_additional_kwargs(self, mock_logger):
305319 # Patch _set_client to avoid actual client creation through ClientFactory
306320 with patch .object (model , "_set_client" ):
307321 # Call the method with additional kwargs
308- response , status_code = await model ._generate_response (messages , ** additional_kwargs )
322+ model_params = ModelParams (url = "http://test-url" , auth_token = "test-token" )
323+ response , status_code = await model ._generate_response (
324+ messages , model_params , ** additional_kwargs
325+ )
309326
310327 # Verify the response
311328 self .assertEqual (status_code , 200 )
312329
313330 # Verify the client methods were called correctly with the additional kwargs
314331 mock_client .build_request .assert_called_once_with (
315332 messages = messages ,
316- stream = True ,
317333 tools = [
318334 {
319335 "type" : "function" ,
0 commit comments