33import os
44import base64
55
6- from openai .types import CompletionUsage
6+ from openai .types import CompletionUsage , Image
77from openai .types .chat import ChatCompletion
88
99from adalflow .core .types import ModelType , GeneratorOutput
@@ -23,7 +23,7 @@ def setUp(self):
2323 "id" : "cmpl-3Q8Z5J9Z1Z5z5" ,
2424 "created" : 1635820005 ,
2525 "object" : "chat.completion" ,
26- "model" : "gpt-3.5-turbo " ,
26+ "model" : "gpt-4o " ,
2727 "choices" : [
2828 {
2929 "message" : {
@@ -59,9 +59,17 @@ def setUp(self):
5959 ),
6060 }
6161 self .mock_vision_response = ChatCompletion (** self .mock_vision_response )
62+ self .mock_image_response = [
63+ Image (
64+ url = "https://example.com/generated_image.jpg" ,
65+ b64_json = None ,
66+ revised_prompt = "A white siamese cat sitting elegantly" ,
67+ model = "dall-e-3" ,
68+ )
69+ ]
6270 self .api_kwargs = {
6371 "messages" : [{"role" : "user" , "content" : "Hello" }],
64- "model" : "gpt-3.5-turbo " ,
72+ "model" : "gpt-4o " ,
6573 }
6674 self .vision_api_kwargs = {
6775 "messages" : [
@@ -81,6 +89,13 @@ def setUp(self):
8189 ],
8290 "model" : "gpt-4o" ,
8391 }
92+ self .image_generation_kwargs = {
93+ "model" : "dall-e-3" ,
94+ "prompt" : "a white siamese cat" ,
95+ "size" : "1024x1024" ,
96+ "quality" : "standard" ,
97+ "n" : 1 ,
98+ }
8499
85100 def test_encode_image (self ):
86101 # Create a temporary test image file
@@ -297,6 +312,111 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client):
297312 self .assertEqual (output .usage .prompt_tokens , 25 )
298313 self .assertEqual (output .usage .total_tokens , 40 )
299314
315+ def test_convert_inputs_to_api_kwargs_for_image_generation (self ):
316+ # Test basic image generation
317+ result = self .client .convert_inputs_to_api_kwargs (
318+ input = "a white siamese cat" ,
319+ model_kwargs = {"model" : "dall-e-3" },
320+ model_type = ModelType .IMAGE_GENERATION ,
321+ )
322+ self .assertEqual (result ["prompt" ], "a white siamese cat" )
323+ self .assertEqual (result ["model" ], "dall-e-3" )
324+ self .assertEqual (result ["size" ], "1024x1024" ) # default
325+ self .assertEqual (result ["quality" ], "standard" ) # default
326+ self .assertEqual (result ["n" ], 1 ) # default
327+
328+ # Test image edit
329+ test_image = "test_image.jpg"
330+ test_mask = "test_mask.jpg"
331+ try :
332+ # Create test files
333+ with open (test_image , "wb" ) as f :
334+ f .write (b"fake image content" )
335+ with open (test_mask , "wb" ) as f :
336+ f .write (b"fake mask content" )
337+
338+ result = self .client .convert_inputs_to_api_kwargs (
339+ input = "a white siamese cat" ,
340+ model_kwargs = {
341+ "model" : "dall-e-2" ,
342+ "image" : test_image ,
343+ "mask" : test_mask ,
344+ },
345+ model_type = ModelType .IMAGE_GENERATION ,
346+ )
347+ self .assertEqual (result ["prompt" ], "a white siamese cat" )
348+ self .assertEqual (result ["model" ], "dall-e-2" )
349+ self .assertTrue (isinstance (result ["image" ], str )) # base64 encoded
350+ self .assertTrue (isinstance (result ["mask" ], str )) # base64 encoded
351+ finally :
352+ # Cleanup
353+ if os .path .exists (test_image ):
354+ os .remove (test_image )
355+ if os .path .exists (test_mask ):
356+ os .remove (test_mask )
357+
358+ @patch ("adalflow.components.model_client.openai_client.AsyncOpenAI" )
359+ async def test_acall_image_generation (self , MockAsyncOpenAI ):
360+ mock_async_client = AsyncMock ()
361+ MockAsyncOpenAI .return_value = mock_async_client
362+
363+ # Mock the image generation response
364+ mock_async_client .images .generate = AsyncMock (
365+ return_value = type ('Response' , (), {'data' : self .mock_image_response })()
366+ )
367+
368+ # Call the acall method with image generation
369+ result = await self .client .acall (
370+ api_kwargs = self .image_generation_kwargs ,
371+ model_type = ModelType .IMAGE_GENERATION ,
372+ )
373+
374+ # Assertions
375+ MockAsyncOpenAI .assert_called_once ()
376+ mock_async_client .images .generate .assert_awaited_once_with (
377+ ** self .image_generation_kwargs
378+ )
379+ self .assertEqual (result , self .mock_image_response )
380+
381+ # Test parse_image_generation_response
382+ output = self .client .parse_image_generation_response (result )
383+ self .assertTrue (isinstance (output , GeneratorOutput ))
384+ self .assertEqual (output .data , "https://example.com/generated_image.jpg" )
385+
386+ @patch (
387+ "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client"
388+ )
389+ @patch ("adalflow.components.model_client.openai_client.OpenAI" )
390+ def test_call_image_generation (self , MockSyncOpenAI , mock_init_sync_client ):
391+ mock_sync_client = Mock ()
392+ MockSyncOpenAI .return_value = mock_sync_client
393+ mock_init_sync_client .return_value = mock_sync_client
394+
395+ # Mock the image generation response
396+ mock_sync_client .images .generate = Mock (
397+ return_value = type ('Response' , (), {'data' : self .mock_image_response })()
398+ )
399+
400+ # Set the sync client
401+ self .client .sync_client = mock_sync_client
402+
403+ # Call the call method with image generation
404+ result = self .client .call (
405+ api_kwargs = self .image_generation_kwargs ,
406+ model_type = ModelType .IMAGE_GENERATION ,
407+ )
408+
409+ # Assertions
410+ mock_sync_client .images .generate .assert_called_once_with (
411+ ** self .image_generation_kwargs
412+ )
413+ self .assertEqual (result , self .mock_image_response )
414+
415+ # Test parse_image_generation_response
416+ output = self .client .parse_image_generation_response (result )
417+ self .assertTrue (isinstance (output , GeneratorOutput ))
418+ self .assertEqual (output .data , "https://example.com/generated_image.jpg" )
419+
300420
301421if __name__ == "__main__" :
302422 unittest .main ()
0 commit comments