Commit 98177d9

Merge pull request #313 from SylphAI-Inc/multimodal
Add support for multimodal openai - early version
2 parents 3e46f8e + d8aa41c commit 98177d9

File tree

7 files changed: +1103 -13 lines changed

adalflow/adalflow/components/model_client/openai_client.py

Lines changed: 173 additions & 10 deletions
@@ -1,6 +1,7 @@
 """OpenAI ModelClient integration."""
 
 import os
+import base64
 from typing import (
     Dict,
     Sequence,
@@ -35,6 +36,7 @@
 from openai.types import (
     Completion,
     CreateEmbeddingResponse,
+    Image,
 )
 from openai.types.chat import ChatCompletionChunk, ChatCompletion

@@ -99,7 +101,7 @@ def get_probabilities(completion: ChatCompletion) -> List[List[TokenLogProb]]:
 class OpenAIClient(ModelClient):
     __doc__ = r"""A component wrapper for the OpenAI API client.
 
-    Support both embedding and chat completion API.
+    Support both embedding and chat completion API, including multimodal capabilities.
 
     Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client.
     (2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project.
@@ -110,6 +112,17 @@ class OpenAIClient(ModelClient):
         Instead
         - use :ref:`OutputParser<components-output_parsers>` for response parsing and formating.
 
+    For multimodal inputs, provide images in model_kwargs["images"] as a path, URL, or list of them.
+    The model must support vision capabilities (e.g., gpt-4o, gpt-4o-mini, o1, o1-mini).
+
+    For image generation, use model_type=ModelType.IMAGE_GENERATION and provide:
+    - model: "dall-e-3" or "dall-e-2"
+    - prompt: Text description of the image to generate
+    - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
+    - quality: "standard" or "hd" (DALL-E 3 only)
+    - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
+    - response_format: "url" or "b64_json"
+
     Args:
         api_key (Optional[str], optional): OpenAI API key. Defaults to None.
         chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
@@ -118,6 +131,8 @@ class OpenAIClient(ModelClient):
     References:
         - Embeddings models: https://platform.openai.com/docs/guides/embeddings
         - Chat models: https://platform.openai.com/docs/guides/text-generation
+        - Vision models: https://platform.openai.com/docs/guides/vision
+        - Image models: https://platform.openai.com/docs/guides/images
         - OpenAI docs: https://platform.openai.com/docs/introduction
     """

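Note: a minimal usage sketch of the multimodal path documented above (not part of this commit). It assumes OPENAI_API_KEY is set in the environment; the image URL is a placeholder, and gpt-4o is one of the vision-capable models named in the docstring.

from adalflow.core import Generator
from adalflow.components.model_client.openai_client import OpenAIClient

# "images" and "detail" are popped from model_kwargs and folded into the
# user message by convert_inputs_to_api_kwargs below.
gen = Generator(
    model_client=OpenAIClient(),
    model_kwargs={
        "model": "gpt-4o",
        "images": "https://example.com/photo.jpg",  # path, URL, or a list of them
        "detail": "auto",  # 'auto', 'low', or 'high'
    },
)
print(gen({"input_str": "Describe this image."}))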
@@ -200,7 +215,7 @@ def track_completion_usage(
     def parse_embedding_response(
         self, response: CreateEmbeddingResponse
     ) -> EmbedderOutput:
-        r"""Parse the embedding response to a structure LightRAG components can understand.
+        r"""Parse the embedding response to a structure Adalflow components can understand.
 
         Should be called in ``Embedder``.
         """
@@ -218,7 +233,20 @@ def convert_inputs_to_api_kwargs(
     ) -> Dict:
         r"""
         Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
-        Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format
+        Convert the Component's standard input, and system_input(chat model) and model_kwargs into API-specific format.
+        For multimodal inputs, images can be provided in model_kwargs["images"] as a string path, URL, or list of them.
+        The model specified in model_kwargs["model"] must support multimodal capabilities when using images.
+
+        Args:
+            input: The input text or messages to process
+            model_kwargs: Additional parameters including:
+                - images: Optional image source(s) as path, URL, or list of them
+                - detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto'
+                - model: The model to use (must support multimodal inputs if images are provided)
+            model_type: The type of model (EMBEDDER or LLM)
+
+        Returns:
+            Dict: API-specific kwargs for the model call
         """
 
         final_model_kwargs = model_kwargs.copy()
@@ -232,6 +260,8 @@ def convert_inputs_to_api_kwargs(
         elif model_type == ModelType.LLM:
             # convert input to messages
             messages: List[Dict[str, str]] = []
+            images = final_model_kwargs.pop("images", None)
+            detail = final_model_kwargs.pop("detail", "auto")
 
             if self._input_type == "messages":
                 system_start_tag = "<START_OF_SYSTEM_PROMPT>"
@@ -248,19 +278,74 @@ def convert_inputs_to_api_kwargs(
                 if match:
                     system_prompt = match.group(1)
                     input_str = match.group(2)
-
                 else:
                     print("No match found.")
                 if system_prompt and input_str:
                     messages.append({"role": "system", "content": system_prompt})
-                    messages.append({"role": "user", "content": input_str})
+                    if images:
+                        content = [{"type": "text", "text": input_str}]
+                        if isinstance(images, (str, dict)):
+                            images = [images]
+                        for img in images:
+                            content.append(self._prepare_image_content(img, detail))
+                        messages.append({"role": "user", "content": content})
+                    else:
+                        messages.append({"role": "user", "content": input_str})
             if len(messages) == 0:
-                messages.append({"role": "system", "content": input})
+                if images:
+                    content = [{"type": "text", "text": input}]
+                    if isinstance(images, (str, dict)):
+                        images = [images]
+                    for img in images:
+                        content.append(self._prepare_image_content(img, detail))
+                    messages.append({"role": "user", "content": content})
+                else:
+                    messages.append({"role": "system", "content": input})
             final_model_kwargs["messages"] = messages
+        elif model_type == ModelType.IMAGE_GENERATION:
+            # For image generation, input is the prompt
+            final_model_kwargs["prompt"] = input
+            # Ensure model is specified
+            if "model" not in final_model_kwargs:
+                raise ValueError("model must be specified for image generation")
+            # Set defaults for DALL-E 3 if not specified
+            final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
+            final_model_kwargs["quality"] = final_model_kwargs.get("quality", "standard")
+            final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
+            final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")
+
+            # Handle image edits and variations
+            image = final_model_kwargs.get("image")
+            if isinstance(image, str) and os.path.isfile(image):
+                final_model_kwargs["image"] = self._encode_image(image)
+
+            mask = final_model_kwargs.get("mask")
+            if isinstance(mask, str) and os.path.isfile(mask):
+                final_model_kwargs["mask"] = self._encode_image(mask)
         else:
             raise ValueError(f"model_type {model_type} is not supported")
         return final_model_kwargs
 
+    def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
+        """Parse the image generation response into a GeneratorOutput."""
+        try:
+            # Extract URLs or base64 data from the response
+            data = [img.url or img.b64_json for img in response]
+            # For single image responses, unwrap from list
+            if len(data) == 1:
+                data = data[0]
+            return GeneratorOutput(
+                data=data,
+                raw_response=str(response),
+            )
+        except Exception as e:
+            log.error(f"Error parsing image generation response: {e}")
+            return GeneratorOutput(
+                data=None,
+                error=str(e),
+                raw_response=str(response)
+            )
+
     @backoff.on_exception(
         backoff.expo,
         (
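Note: a sketch of what the new IMAGE_GENERATION branch above yields (assumes OPENAI_API_KEY is set; the prompt text is illustrative).

from adalflow.core.types import ModelType
from adalflow.components.model_client.openai_client import OpenAIClient

client = OpenAIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="a watercolor fox",
    model_kwargs={"model": "dall-e-3"},
    model_type=ModelType.IMAGE_GENERATION,
)
# DALL-E 3 defaults are filled in:
# {"model": "dall-e-3", "prompt": "a watercolor fox", "size": "1024x1024",
#  "quality": "standard", "n": 1, "response_format": "url"}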
@@ -285,6 +370,19 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED):
                 self.chat_completion_parser = handle_streaming_response
                 return self.sync_client.chat.completions.create(**api_kwargs)
             return self.sync_client.chat.completions.create(**api_kwargs)
+        elif model_type == ModelType.IMAGE_GENERATION:
+            # Determine which image API to call based on the presence of image/mask
+            if "image" in api_kwargs:
+                if "mask" in api_kwargs:
+                    # Image edit
+                    response = self.sync_client.images.edit(**api_kwargs)
+                else:
+                    # Image variation
+                    response = self.sync_client.images.create_variation(**api_kwargs)
+            else:
+                # Image generation
+                response = self.sync_client.images.generate(**api_kwargs)
+            return response.data
         else:
             raise ValueError(f"model_type {model_type} is not supported")

@@ -311,6 +409,19 @@ async def acall(
             return await self.async_client.embeddings.create(**api_kwargs)
         elif model_type == ModelType.LLM:
             return await self.async_client.chat.completions.create(**api_kwargs)
+        elif model_type == ModelType.IMAGE_GENERATION:
+            # Determine which image API to call based on the presence of image/mask
+            if "image" in api_kwargs:
+                if "mask" in api_kwargs:
+                    # Image edit
+                    response = await self.async_client.images.edit(**api_kwargs)
+                else:
+                    # Image variation
+                    response = await self.async_client.images.create_variation(**api_kwargs)
+            else:
+                # Image generation
+                response = await self.async_client.images.generate(**api_kwargs)
+            return response.data
         else:
             raise ValueError(f"model_type {model_type} is not supported")

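Note: both call() and acall() route one IMAGE_GENERATION entry point to three OpenAI image APIs. A sketch of the dispatch (file names would be placeholders; per convert_inputs_to_api_kwargs above, local paths in "image"/"mask" are base64-encoded before the call):

from adalflow.core.types import ModelType
from adalflow.components.model_client.openai_client import OpenAIClient

client = OpenAIClient()
# no "image" and no "mask"  -> images.generate(...)
# "image" only              -> images.create_variation(...)
# "image" and "mask"        -> images.edit(...)
data = client.call(
    api_kwargs={"model": "dall-e-2", "prompt": "a cat", "n": 1, "response_format": "url"},
    model_type=ModelType.IMAGE_GENERATION,  # takes the generate(...) path here
)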
@@ -332,22 +443,74 @@ def to_dict(self) -> Dict[str, Any]:
         output = super().to_dict(exclude=exclude)
         return output
 
+    def _encode_image(self, image_path: str) -> str:
+        """Encode image to base64 string.
 
+        Args:
+            image_path: Path to image file.
+
+        Returns:
+            Base64 encoded image string.
+
+        Raises:
+            ValueError: If the file cannot be read or doesn't exist.
+        """
+        try:
+            with open(image_path, "rb") as image_file:
+                return base64.b64encode(image_file.read()).decode("utf-8")
+        except FileNotFoundError:
+            raise ValueError(f"Image file not found: {image_path}")
+        except PermissionError:
+            raise ValueError(f"Permission denied when reading image file: {image_path}")
+        except Exception as e:
+            raise ValueError(f"Error encoding image {image_path}: {str(e)}")
+
+    def _prepare_image_content(
+        self, image_source: Union[str, Dict[str, Any]], detail: str = "auto"
+    ) -> Dict[str, Any]:
+        """Prepare image content for API request.
+
+        Args:
+            image_source: Either a path to local image or a URL.
+            detail: Image detail level ('auto', 'low', or 'high').
+
+        Returns:
+            Formatted image content for API request.
+        """
+        if isinstance(image_source, str):
+            if image_source.startswith(("http://", "https://")):
+                return {
+                    "type": "image_url",
+                    "image_url": {"url": image_source, "detail": detail},
+                }
+            else:
+                base64_image = self._encode_image(image_source)
+                return {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{base64_image}",
+                        "detail": detail,
+                    },
+                }
+        return image_source
+
+
+# Example usage:
 # if __name__ == "__main__":
 #     from adalflow.core import Generator
 #     from adalflow.utils import setup_env, get_logger
-
+#
 #     log = get_logger(level="DEBUG")
-
+#
 #     setup_env()
 #     prompt_kwargs = {"input_str": "What is the meaning of life?"}
-
+#
 #     gen = Generator(
 #         model_client=OpenAIClient(),
 #         model_kwargs={"model": "gpt-3.5-turbo", "stream": True},
 #     )
 #     gen_response = gen(prompt_kwargs)
 #     print(f"gen_response: {gen_response}")
-
+#
 #     for genout in gen_response.data:
 #         print(f"genout: {genout}")
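Note: an end-to-end sketch for the new image-generation path, mirroring the commented chat example above (assumes OPENAI_API_KEY is set; model and prompt are illustrative).

from adalflow.core.types import ModelType
from adalflow.components.model_client.openai_client import OpenAIClient

client = OpenAIClient()
api_kwargs = client.convert_inputs_to_api_kwargs(
    input="a watercolor fox in the snow",
    model_kwargs={"model": "dall-e-3", "size": "1024x1024"},
    model_type=ModelType.IMAGE_GENERATION,
)
# call() returns response.data, i.e. a list of openai.types.Image
images = client.call(api_kwargs=api_kwargs, model_type=ModelType.IMAGE_GENERATION)
output = client.parse_image_generation_response(images)
print(output.data)  # a URL (or base64 string) per the response_format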

adalflow/adalflow/core/types.py

Lines changed: 1 addition & 0 deletions

@@ -58,6 +58,7 @@ class ModelType(Enum):
     EMBEDDER = auto()
     LLM = auto()
     RERANKER = auto()  # ranking model
+    IMAGE_GENERATION = auto()  # image generation models like DALL-E
     UNDEFINED = auto()
