@@ -18,7 +18,12 @@ def getenv_side_effect(key):
1818
1919class TestOpenAIClient (unittest .IsolatedAsyncioTestCase ):
2020 def setUp (self ):
21- self .client = OpenAIClient (api_key = "fake_api_key" )
21+ # Default client for LLM tests
22+ self .client = OpenAIClient (api_key = "fake_api_key" , model_type = ModelType .LLM )
23+
24+ # Client for image generation tests
25+ self .image_client = OpenAIClient (api_key = "fake_api_key" , model_type = ModelType .IMAGE_GENERATION )
26+
2227 self .mock_response = {
2328 "id" : "cmpl-3Q8Z5J9Z1Z5z5" ,
2429 "created" : 1635820005 ,
@@ -152,7 +157,6 @@ def test_convert_inputs_to_api_kwargs_with_images(self):
152157 result = self .client .convert_inputs_to_api_kwargs (
153158 input = "Describe this image" ,
154159 model_kwargs = model_kwargs ,
155- model_type = ModelType .LLM ,
156160 )
157161 expected_content = [
158162 {"type" : "text" , "text" : "Describe this image" },
@@ -175,7 +179,6 @@ def test_convert_inputs_to_api_kwargs_with_images(self):
175179 result = self .client .convert_inputs_to_api_kwargs (
176180 input = "Compare these images" ,
177181 model_kwargs = model_kwargs ,
178- model_type = ModelType .LLM ,
179182 )
180183 expected_content = [
181184 {"type" : "text" , "text" : "Compare these images" },
@@ -202,15 +205,13 @@ async def test_acall_llm(self, MockAsyncOpenAI):
202205 MockAsyncOpenAI .return_value = mock_async_client
203206
204207 # Mock the response
205-
206208 mock_async_client .chat .completions .create = AsyncMock (
207209 return_value = self .mock_response
208210 )
209211
210212 # Call the _acall method
211-
212213 result = await self .client .acall (
213- api_kwargs = self .api_kwargs , model_type = ModelType . LLM
214+ api_kwargs = self .api_kwargs ,
214215 )
215216
216217 # Assertions
@@ -236,7 +237,7 @@ def test_call(self, MockSyncOpenAI, mock_init_sync_client):
236237 self .client .sync_client = mock_sync_client
237238
238239 # Call the call method
239- result = self .client .call (api_kwargs = self .api_kwargs , model_type = ModelType . LLM )
240+ result = self .client .call (api_kwargs = self .api_kwargs )
240241
241242 # Assertions
242243 mock_sync_client .chat .completions .create .assert_called_once_with (
@@ -264,7 +265,7 @@ async def test_acall_llm_with_vision(self, MockAsyncOpenAI):
264265
265266 # Call the _acall method with vision model
266267 result = await self .client .acall (
267- api_kwargs = self .vision_api_kwargs , model_type = ModelType . LLM
268+ api_kwargs = self .vision_api_kwargs ,
268269 )
269270
270271 # Assertions
@@ -293,7 +294,7 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client):
293294
294295 # Call the call method with vision model
295296 result = self .client .call (
296- api_kwargs = self .vision_api_kwargs , model_type = ModelType . LLM
297+ api_kwargs = self .vision_api_kwargs ,
297298 )
298299
299300 # Assertions
@@ -314,10 +315,9 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client):
314315
315316 def test_convert_inputs_to_api_kwargs_for_image_generation (self ):
316317 # Test basic image generation
317- result = self .client .convert_inputs_to_api_kwargs (
318+ result = self .image_client .convert_inputs_to_api_kwargs (
318319 input = "a white siamese cat" ,
319320 model_kwargs = {"model" : "dall-e-3" },
320- model_type = ModelType .IMAGE_GENERATION ,
321321 )
322322 self .assertEqual (result ["prompt" ], "a white siamese cat" )
323323 self .assertEqual (result ["model" ], "dall-e-3" )
@@ -335,14 +335,13 @@ def test_convert_inputs_to_api_kwargs_for_image_generation(self):
335335 with open (test_mask , "wb" ) as f :
336336 f .write (b"fake mask content" )
337337
338- result = self .client .convert_inputs_to_api_kwargs (
338+ result = self .image_client .convert_inputs_to_api_kwargs (
339339 input = "a white siamese cat" ,
340340 model_kwargs = {
341341 "model" : "dall-e-2" ,
342342 "image" : test_image ,
343343 "mask" : test_mask ,
344344 },
345- model_type = ModelType .IMAGE_GENERATION ,
346345 )
347346 self .assertEqual (result ["prompt" ], "a white siamese cat" )
348347 self .assertEqual (result ["model" ], "dall-e-2" )
@@ -366,9 +365,8 @@ async def test_acall_image_generation(self, MockAsyncOpenAI):
366365 )
367366
368367 # Call the acall method with image generation
369- result = await self .client .acall (
368+ result = await self .image_client .acall (
370369 api_kwargs = self .image_generation_kwargs ,
371- model_type = ModelType .IMAGE_GENERATION ,
372370 )
373371
374372 # Assertions
@@ -379,7 +377,7 @@ async def test_acall_image_generation(self, MockAsyncOpenAI):
379377 self .assertEqual (result , self .mock_image_response )
380378
381379 # Test parse_image_generation_response
382- output = self .client .parse_image_generation_response (result )
380+ output = self .image_client .parse_image_generation_response (result )
383381 self .assertTrue (isinstance (output , GeneratorOutput ))
384382 self .assertEqual (output .data , "https://example.com/generated_image.jpg" )
385383
@@ -398,12 +396,11 @@ def test_call_image_generation(self, MockSyncOpenAI, mock_init_sync_client):
398396 )
399397
400398 # Set the sync client
401- self .client .sync_client = mock_sync_client
399+ self .image_client .sync_client = mock_sync_client
402400
403401 # Call the call method with image generation
404- result = self .client .call (
402+ result = self .image_client .call (
405403 api_kwargs = self .image_generation_kwargs ,
406- model_type = ModelType .IMAGE_GENERATION ,
407404 )
408405
409406 # Assertions
@@ -413,7 +410,7 @@ def test_call_image_generation(self, MockSyncOpenAI, mock_init_sync_client):
413410 self .assertEqual (result , self .mock_image_response )
414411
415412 # Test parse_image_generation_response
416- output = self .client .parse_image_generation_response (result )
413+ output = self .image_client .parse_image_generation_response (result )
417414 self .assertTrue (isinstance (output , GeneratorOutput ))
418415 self .assertEqual (output .data , "https://example.com/generated_image.jpg" )
419416
0 commit comments