Commit ff1060a

add image gen
1 parent 852c212 commit ff1060a

3 files changed: +204 -3 lines changed

adalflow/adalflow/components/model_client/openai_client.py

Lines changed: 80 additions & 0 deletions
@@ -36,6 +36,7 @@
 from openai.types import (
     Completion,
     CreateEmbeddingResponse,
+    Image,
 )
 from openai.types.chat import ChatCompletionChunk, ChatCompletion
 
@@ -114,6 +115,14 @@ class OpenAIClient(ModelClient):
     For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them.
     The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini).
 
+    For image generation, use model_type=ModelType.IMAGE_GENERATION and provide:
+    - model: "dall-e-3" or "dall-e-2"
+    - prompt: Text description of the image to generate
+    - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
+    - quality: "standard" or "hd" (DALL-E 3 only)
+    - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
+    - response_format: "url" or "b64_json"
+
     Args:
         api_key (Optional[str], optional): OpenAI API key. Defaults to None.
         chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
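
Taken together, the parameters above imply the following end-to-end usage. This is a minimal sketch, not part of this commit: it assumes OPENAI_API_KEY is set in the environment, and the prompt and model_kwargs are illustrative.

    from adalflow.components.model_client.openai_client import OpenAIClient
    from adalflow.core.types import ModelType

    client = OpenAIClient()

    # convert_inputs_to_api_kwargs maps `input` to the "prompt" field and
    # fills in the DALL-E 3 defaults documented above.
    api_kwargs = client.convert_inputs_to_api_kwargs(
        input="a white siamese cat",
        model_kwargs={"model": "dall-e-3", "size": "1024x1024", "quality": "standard"},
        model_type=ModelType.IMAGE_GENERATION,
    )

    # call() returns response.data, a list of openai.types.Image objects.
    images = client.call(api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION)
    output = client.parse_image_generation_response(images)
    print(output.data)  # a URL string, since response_format defaults to "url"

Because of the defaults, the model_kwargs above could be omitted entirely for a DALL-E 3 call.
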
@@ -123,6 +132,7 @@ class OpenAIClient(ModelClient):
     - Embeddings models: https://platform.openai.com/docs/guides/embeddings
     - Chat models: https://platform.openai.com/docs/guides/text-generation
     - Vision models: https://platform.openai.com/docs/guides/vision
+    - Image models: https://platform.openai.com/docs/guides/images
     - OpenAI docs: https://platform.openai.com/docs/introduction
     """
 

@@ -292,10 +302,54 @@ def convert_inputs_to_api_kwargs(
             else:
                 messages.append({"role": "system", "content": input})
             final_model_kwargs["messages"] = messages
+        elif model_type == ModelType.IMAGE_GENERATION:
+            # For image generation, input is the prompt
+            final_model_kwargs["prompt"] = input
+            # Set defaults for DALL-E 3 if not specified
+            if "model" not in final_model_kwargs:
+                final_model_kwargs["model"] = "dall-e-3"
+            if "size" not in final_model_kwargs:
+                final_model_kwargs["size"] = "1024x1024"
+            if "quality" not in final_model_kwargs:
+                final_model_kwargs["quality"] = "standard"
+            if "n" not in final_model_kwargs:
+                final_model_kwargs["n"] = 1
+            if "response_format" not in final_model_kwargs:
+                final_model_kwargs["response_format"] = "url"
+
+            # Handle image edits and variations
+            if "image" in final_model_kwargs:
+                if isinstance(final_model_kwargs["image"], str):
+                    # If it's a file path, encode it
+                    if os.path.isfile(final_model_kwargs["image"]):
+                        final_model_kwargs["image"] = self._encode_image(final_model_kwargs["image"])
+                if "mask" in final_model_kwargs and isinstance(final_model_kwargs["mask"], str):
+                    if os.path.isfile(final_model_kwargs["mask"]):
+                        final_model_kwargs["mask"] = self._encode_image(final_model_kwargs["mask"])
         else:
             raise ValueError(f"model_type {model_type} is not supported")
         return final_model_kwargs
 
+    def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
+        """Parse the image generation response into a GeneratorOutput."""
+        try:
+            # Extract URLs or base64 data from the response
+            data = [img.url or img.b64_json for img in response]
+            # For single image responses, unwrap from list
+            if len(data) == 1:
+                data = data[0]
+            return GeneratorOutput(
+                data=data,
+                raw_response=str(response),
+            )
+        except Exception as e:
+            log.error(f"Error parsing image generation response: {e}")
+            return GeneratorOutput(
+                data=None,
+                error=str(e),
+                raw_response=str(response)
+            )
+
     @backoff.on_exception(
         backoff.expo,
         (
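
The unwrap behavior is easy to see in isolation. A minimal sketch, not part of this commit: it builds openai.types.Image objects directly so no API call is made, though OpenAIClient() still expects OPENAI_API_KEY at init.

    from openai.types import Image

    from adalflow.components.model_client.openai_client import OpenAIClient
    from adalflow.core.types import ModelType

    client = OpenAIClient()
    one = [Image(url="https://example.com/a.png")]
    two = [Image(url="https://example.com/a.png"), Image(url="https://example.com/b.png")]

    # A single image is unwrapped from the list; multiple images stay a list.
    print(client.parse_image_generation_response(one).data)  # "https://example.com/a.png"
    print(client.parse_image_generation_response(two).data)  # a list of two URLs
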
@@ -320,6 +374,19 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
                 self.chat_completion_parser = handle_streaming_response
                 return self.sync_client.chat.completions.create(**api_kwargs)
             return self.sync_client.chat.completions.create(**api_kwargs)
+        elif model_type == ModelType.IMAGE_GENERATION:
+            # Determine which image API to call based on the presence of image/mask
+            if "image" in api_kwargs:
+                if "mask" in api_kwargs:
+                    # Image edit
+                    response = self.sync_client.images.edit(**api_kwargs)
+                else:
+                    # Image variation
+                    response = self.sync_client.images.create_variation(**api_kwargs)
+            else:
+                # Image generation
+                response = self.sync_client.images.generate(**api_kwargs)
+            return response.data
         else:
             raise ValueError(f"model_type {model_type} is not supported")
 
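
The routing rule here (mirrored in acall() below) can be exercised directly. A minimal sketch, not part of this commit, assuming OPENAI_API_KEY is set:

    from adalflow.components.model_client.openai_client import OpenAIClient
    from adalflow.core.types import ModelType

    client = OpenAIClient()

    # No "image" key in api_kwargs -> images.generate
    # "image" without "mask"       -> images.create_variation
    # "image" and "mask"           -> images.edit
    data = client.call(
        api_kwargs={"model": "dall-e-3", "prompt": "a white siamese cat", "n": 1},
        model_type=ModelType.IMAGE_GENERATION,
    )
    print(data[0].url)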

@@ -346,6 +413,19 @@ async def acall(
             return await self.async_client.embeddings.create(**api_kwargs)
         elif model_type == ModelType.LLM:
             return await self.async_client.chat.completions.create(**api_kwargs)
+        elif model_type == ModelType.IMAGE_GENERATION:
+            # Determine which image API to call based on the presence of image/mask
+            if "image" in api_kwargs:
+                if "mask" in api_kwargs:
+                    # Image edit
+                    response = await self.async_client.images.edit(**api_kwargs)
+                else:
+                    # Image variation
+                    response = await self.async_client.images.create_variation(**api_kwargs)
+            else:
+                # Image generation
+                response = await self.async_client.images.generate(**api_kwargs)
+            return response.data
         else:
             raise ValueError(f"model_type {model_type} is not supported")
 
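
A matching async sketch for the acall() path, not part of this commit (assumes OPENAI_API_KEY is set):

    import asyncio

    from adalflow.components.model_client.openai_client import OpenAIClient
    from adalflow.core.types import ModelType

    async def main():
        client = OpenAIClient()
        # Same routing as the sync call(): no "image" key -> images.generate.
        images = await client.acall(
            api_kwargs={"model": "dall-e-3", "prompt": "a white siamese cat"},
            model_type=ModelType.IMAGE_GENERATION,
        )
        print(client.parse_image_generation_response(images).data)

    asyncio.run(main())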

adalflow/adalflow/core/types.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ class ModelType(Enum):
     EMBEDDER = auto()
     LLM = auto()
     RERANKER = auto()  # ranking model
+    IMAGE_GENERATION = auto()  # image generation models like DALL-E
     UNDEFINED = auto()
 
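
A minimal sketch of how downstream code selects the new member, not part of this commit:

    from adalflow.core.types import ModelType

    # Routes OpenAIClient.call()/acall() to the images API instead of
    # chat completions or embeddings.
    model_type = ModelType.IMAGE_GENERATION
    print(model_type.name)  # "IMAGE_GENERATION"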

adalflow/tests/test_openai_client.py

Lines changed: 123 additions & 3 deletions
@@ -3,7 +3,7 @@
 import os
 import base64
 
-from openai.types import CompletionUsage
+from openai.types import CompletionUsage, Image
 from openai.types.chat import ChatCompletion
 
 from adalflow.core.types import ModelType, GeneratorOutput
@@ -23,7 +23,7 @@ def setUp(self):
             "id": "cmpl-3Q8Z5J9Z1Z5z5",
             "created": 1635820005,
             "object": "chat.completion",
-            "model": "gpt-3.5-turbo",
+            "model": "gpt-4o",
             "choices": [
                 {
                     "message": {
@@ -59,9 +59,17 @@ def setUp(self):
             ),
         }
         self.mock_vision_response = ChatCompletion(**self.mock_vision_response)
+        self.mock_image_response = [
+            Image(
+                url="https://example.com/generated_image.jpg",
+                b64_json=None,
+                revised_prompt="A white siamese cat sitting elegantly",
+                model="dall-e-3",
+            )
+        ]
         self.api_kwargs = {
             "messages": [{"role": "user", "content": "Hello"}],
-            "model": "gpt-3.5-turbo",
+            "model": "gpt-4o",
         }
         self.vision_api_kwargs = {
             "messages": [
@@ -81,6 +89,13 @@ def setUp(self):
             ],
             "model": "gpt-4o",
         }
+        self.image_generation_kwargs = {
+            "model": "dall-e-3",
+            "prompt": "a white siamese cat",
+            "size": "1024x1024",
+            "quality": "standard",
+            "n": 1,
+        }
 
     def test_encode_image(self):
         # Create a temporary test image file
@@ -297,6 +312,111 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client):
         self.assertEqual(output.usage.prompt_tokens, 25)
         self.assertEqual(output.usage.total_tokens, 40)
 
+    def test_convert_inputs_to_api_kwargs_for_image_generation(self):
+        # Test basic image generation
+        result = self.client.convert_inputs_to_api_kwargs(
+            input="a white siamese cat",
+            model_kwargs={"model": "dall-e-3"},
+            model_type=ModelType.IMAGE_GENERATION,
+        )
+        self.assertEqual(result["prompt"], "a white siamese cat")
+        self.assertEqual(result["model"], "dall-e-3")
+        self.assertEqual(result["size"], "1024x1024")  # default
+        self.assertEqual(result["quality"], "standard")  # default
+        self.assertEqual(result["n"], 1)  # default
+
+        # Test image edit
+        test_image = "test_image.jpg"
+        test_mask = "test_mask.jpg"
+        try:
+            # Create test files
+            with open(test_image, "wb") as f:
+                f.write(b"fake image content")
+            with open(test_mask, "wb") as f:
+                f.write(b"fake mask content")
+
+            result = self.client.convert_inputs_to_api_kwargs(
+                input="a white siamese cat",
+                model_kwargs={
+                    "model": "dall-e-2",
+                    "image": test_image,
+                    "mask": test_mask,
+                },
+                model_type=ModelType.IMAGE_GENERATION,
+            )
+            self.assertEqual(result["prompt"], "a white siamese cat")
+            self.assertEqual(result["model"], "dall-e-2")
+            self.assertTrue(isinstance(result["image"], str))  # base64 encoded
+            self.assertTrue(isinstance(result["mask"], str))  # base64 encoded
+        finally:
+            # Cleanup
+            if os.path.exists(test_image):
+                os.remove(test_image)
+            if os.path.exists(test_mask):
+                os.remove(test_mask)
+
+    @patch("adalflow.components.model_client.openai_client.AsyncOpenAI")
+    async def test_acall_image_generation(self, MockAsyncOpenAI):
+        mock_async_client = AsyncMock()
+        MockAsyncOpenAI.return_value = mock_async_client
+
+        # Mock the image generation response
+        mock_async_client.images.generate = AsyncMock(
+            return_value=type('Response', (), {'data': self.mock_image_response})()
+        )
+
+        # Call the acall method with image generation
+        result = await self.client.acall(
+            api_kwargs=self.image_generation_kwargs,
+            model_type=ModelType.IMAGE_GENERATION,
+        )
+
+        # Assertions
+        MockAsyncOpenAI.assert_called_once()
+        mock_async_client.images.generate.assert_awaited_once_with(
+            **self.image_generation_kwargs
+        )
+        self.assertEqual(result, self.mock_image_response)
+
+        # Test parse_image_generation_response
+        output = self.client.parse_image_generation_response(result)
+        self.assertTrue(isinstance(output, GeneratorOutput))
+        self.assertEqual(output.data, "https://example.com/generated_image.jpg")
+
+    @patch(
+        "adalflow.components.model_client.openai_client.OpenAIClient.init_sync_client"
+    )
+    @patch("adalflow.components.model_client.openai_client.OpenAI")
+    def test_call_image_generation(self, MockSyncOpenAI, mock_init_sync_client):
+        mock_sync_client = Mock()
+        MockSyncOpenAI.return_value = mock_sync_client
+        mock_init_sync_client.return_value = mock_sync_client
+
+        # Mock the image generation response
+        mock_sync_client.images.generate = Mock(
+            return_value=type('Response', (), {'data': self.mock_image_response})()
+        )
+
+        # Set the sync client
+        self.client.sync_client = mock_sync_client
+
+        # Call the call method with image generation
+        result = self.client.call(
+            api_kwargs=self.image_generation_kwargs,
+            model_type=ModelType.IMAGE_GENERATION,
+        )
+
+        # Assertions
+        mock_sync_client.images.generate.assert_called_once_with(
+            **self.image_generation_kwargs
+        )
+        self.assertEqual(result, self.mock_image_response)
+
+        # Test parse_image_generation_response
+        output = self.client.parse_image_generation_response(result)
+        self.assertTrue(isinstance(output, GeneratorOutput))
+        self.assertEqual(output.data, "https://example.com/generated_image.jpg")
+
 
 if __name__ == "__main__":
     unittest.main()
