
Commit 839d018

multimodal kwargs update
1 parent a3d6f0f commit 839d018

File tree

3 files changed: +194, -25 lines changed

adalflow/adalflow/components/model_client/openai_client.py

Lines changed: 13 additions & 5 deletions
@@ -140,27 +140,27 @@ class OpenAIClient(ModelClient):
     def __init__(
         self,
         api_key: Optional[str] = None,
+        model_type: ModelType = ModelType.LLM,
         chat_completion_parser: Callable[[Completion], Any] = None,
         input_type: Literal["text", "messages"] = "text",
-        model_type: ModelType = ModelType.LLM,
     ):
-        r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.
+        r"""Initialize the OpenAI client.
 
         Args:
             api_key (Optional[str], optional): OpenAI API key. Defaults to None.
-            chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
-            input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
             model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
+            chat_completion_parser (Callable[[Completion], Any], optional): A function to parse chat completions. Defaults to None.
+            input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
         """
         super().__init__()
         self._api_key = api_key
+        self.model_type = model_type
         self.sync_client = self.init_sync_client()
         self.async_client = None  # only initialize if the async call is called
         self.chat_completion_parser = (
             chat_completion_parser or get_first_message_content
         )
         self._input_type = input_type
-        self.model_type = model_type
 
     def init_sync_client(self):
         api_key = self._api_key or os.getenv("OPENAI_API_KEY")
@@ -235,6 +235,7 @@ def convert_inputs_to_api_kwargs(
         self,
         input: Optional[Any] = None,
         model_kwargs: Dict = {},
+        model_type: ModelType = ModelType.UNDEFINED,
     ) -> Dict:
         r"""
         Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -259,6 +260,7 @@ def convert_inputs_to_api_kwargs(
                 - mask: Path to the mask image
             For variations (DALL-E 2 only):
                 - image: Path to the input image
+            model_type: The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION)
 
         Returns:
             Dict: API-specific kwargs for the model call
@@ -397,6 +399,9 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINED
         """
         kwargs is the combined input and model_kwargs. Support streaming call.
         """
+        # Use self.model_type if no model_type is provided or if UNDEFINED
+        model_type = self.model_type if model_type == ModelType.UNDEFINED else model_type
+
         log.info(f"api_kwargs: {api_kwargs}")
         if model_type == ModelType.EMBEDDER:
             return self.sync_client.embeddings.create(**api_kwargs)
@@ -446,6 +451,9 @@ async def acall(
         """
         kwargs is the combined input and model_kwargs
         """
+        # Use self.model_type if no model_type is provided or if UNDEFINED
+        model_type = self.model_type if model_type == ModelType.UNDEFINED else model_type
+
         if self.async_client is None:
             self.async_client = self.init_async_client()
         if model_type == ModelType.EMBEDDER:
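
In effect, the client now carries its model type and call/acall fall back to it when no model_type is passed. A minimal sketch of the resulting usage, mirroring the updated tests (assumes OPENAI_API_KEY is set in the environment; the prompt and model name are placeholders):

    from adalflow.components.model_client import OpenAIClient
    from adalflow.core.types import ModelType

    # The model type is set once on the client instead of on every call.
    dalle_client = OpenAIClient(model_type=ModelType.IMAGE_GENERATION)

    # Build the provider-specific kwargs for an image-generation request.
    api_kwargs = dalle_client.convert_inputs_to_api_kwargs(
        input="a white siamese cat",
        model_kwargs={"model": "dall-e-3"},
    )

    # No model_type argument: call() sees ModelType.UNDEFINED and falls back
    # to self.model_type (IMAGE_GENERATION) set in the constructor.
    response = dalle_client.call(api_kwargs=api_kwargs)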

adalflow/tests/test_openai_client.py

Lines changed: 17 additions & 20 deletions
@@ -18,7 +18,12 @@ def getenv_side_effect(key):
 
 class TestOpenAIClient(unittest.IsolatedAsyncioTestCase):
     def setUp(self):
-        self.client = OpenAIClient(api_key="fake_api_key")
+        # Default client for LLM tests
+        self.client = OpenAIClient(api_key="fake_api_key", model_type=ModelType.LLM)
+
+        # Client for image generation tests
+        self.image_client = OpenAIClient(api_key="fake_api_key", model_type=ModelType.IMAGE_GENERATION)
+
         self.mock_response = {
             "id": "cmpl-3Q8Z5J9Z1Z5z5",
             "created": 1635820005,
@@ -152,7 +157,6 @@ def test_convert_inputs_to_api_kwargs_with_images(self):
         result = self.client.convert_inputs_to_api_kwargs(
             input="Describe this image",
             model_kwargs=model_kwargs,
-            model_type=ModelType.LLM,
         )
         expected_content = [
             {"type": "text", "text": "Describe this image"},
@@ -175,7 +179,6 @@ def test_convert_inputs_to_api_kwargs_with_images(self):
         result = self.client.convert_inputs_to_api_kwargs(
             input="Compare these images",
             model_kwargs=model_kwargs,
-            model_type=ModelType.LLM,
         )
         expected_content = [
             {"type": "text", "text": "Compare these images"},
@@ -202,15 +205,13 @@ async def test_acall_llm(self, MockAsyncOpenAI):
         MockAsyncOpenAI.return_value = mock_async_client
 
         # Mock the response
-
         mock_async_client.chat.completions.create = AsyncMock(
             return_value=self.mock_response
         )
 
         # Call the _acall method
-
         result = await self.client.acall(
-            api_kwargs=self.api_kwargs, model_type=ModelType.LLM
+            api_kwargs=self.api_kwargs,
         )
 
         # Assertions
@@ -236,7 +237,7 @@ def test_call(self, MockSyncOpenAI, mock_init_sync_client):
         self.client.sync_client = mock_sync_client
 
         # Call the call method
-        result = self.client.call(api_kwargs=self.api_kwargs, model_type=ModelType.LLM)
+        result = self.client.call(api_kwargs=self.api_kwargs)
 
         # Assertions
         mock_sync_client.chat.completions.create.assert_called_once_with(
@@ -264,7 +265,7 @@ async def test_acall_llm_with_vision(self, MockAsyncOpenAI):
 
         # Call the _acall method with vision model
         result = await self.client.acall(
-            api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM
+            api_kwargs=self.vision_api_kwargs,
         )
 
         # Assertions
@@ -293,7 +294,7 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client):
 
         # Call the call method with vision model
         result = self.client.call(
-            api_kwargs=self.vision_api_kwargs, model_type=ModelType.LLM
+            api_kwargs=self.vision_api_kwargs,
         )
 
         # Assertions
@@ -314,10 +315,9 @@ def test_call_with_vision(self, MockSyncOpenAI, mock_init_sync_client):
 
     def test_convert_inputs_to_api_kwargs_for_image_generation(self):
         # Test basic image generation
-        result = self.client.convert_inputs_to_api_kwargs(
+        result = self.image_client.convert_inputs_to_api_kwargs(
             input="a white siamese cat",
             model_kwargs={"model": "dall-e-3"},
-            model_type=ModelType.IMAGE_GENERATION,
         )
         self.assertEqual(result["prompt"], "a white siamese cat")
         self.assertEqual(result["model"], "dall-e-3")
@@ -335,14 +335,13 @@ def test_convert_inputs_to_api_kwargs_for_image_generation(self):
         with open(test_mask, "wb") as f:
             f.write(b"fake mask content")
 
-        result = self.client.convert_inputs_to_api_kwargs(
+        result = self.image_client.convert_inputs_to_api_kwargs(
             input="a white siamese cat",
             model_kwargs={
                 "model": "dall-e-2",
                 "image": test_image,
                 "mask": test_mask,
             },
-            model_type=ModelType.IMAGE_GENERATION,
         )
         self.assertEqual(result["prompt"], "a white siamese cat")
         self.assertEqual(result["model"], "dall-e-2")
@@ -366,9 +365,8 @@ async def test_acall_image_generation(self, MockAsyncOpenAI):
         )
 
         # Call the acall method with image generation
-        result = await self.client.acall(
+        result = await self.image_client.acall(
             api_kwargs=self.image_generation_kwargs,
-            model_type=ModelType.IMAGE_GENERATION,
         )
 
         # Assertions
@@ -379,7 +377,7 @@ async def test_acall_image_generation(self, MockAsyncOpenAI):
         self.assertEqual(result, self.mock_image_response)
 
         # Test parse_image_generation_response
-        output = self.client.parse_image_generation_response(result)
+        output = self.image_client.parse_image_generation_response(result)
         self.assertTrue(isinstance(output, GeneratorOutput))
         self.assertEqual(output.data, "https://example.com/generated_image.jpg")
 
@@ -398,12 +396,11 @@ def test_call_image_generation(self, MockSyncOpenAI, mock_init_sync_client):
         )
 
         # Set the sync client
-        self.client.sync_client = mock_sync_client
+        self.image_client.sync_client = mock_sync_client
 
         # Call the call method with image generation
-        result = self.client.call(
+        result = self.image_client.call(
             api_kwargs=self.image_generation_kwargs,
-            model_type=ModelType.IMAGE_GENERATION,
         )
 
         # Assertions
@@ -413,7 +410,7 @@ def test_call_image_generation(self, MockSyncOpenAI, mock_init_sync_client):
         self.assertEqual(result, self.mock_image_response)
 
         # Test parse_image_generation_response
-        output = self.client.parse_image_generation_response(result)
+        output = self.image_client.parse_image_generation_response(result)
         self.assertTrue(isinstance(output, GeneratorOutput))
         self.assertEqual(output.data, "https://example.com/generated_image.jpg")

Lines changed: 164 additions & 0 deletions
@@ -0,0 +1,164 @@
1+
Multimodal Client Tutorial
2+
=======================
3+
4+
This tutorial demonstrates how to use the OpenAI client for different types of tasks: text generation, vision analysis, and image generation.
5+
6+
Basic Setup
7+
----------
8+
9+
First, make sure you have your OpenAI API key set in your environment:
10+
11+
.. code-block:: bash
12+
13+
export OPENAI_API_KEY='your_api_key_here'
14+
15+
The OpenAI client supports three different model types:
16+
17+
- ``ModelType.LLM`` - For text generation and vision tasks (default)
18+
- ``ModelType.IMAGE_GENERATION`` - For DALL-E image generation
19+
- ``ModelType.EMBEDDER`` - For text embeddings
20+
21+
Note that most recent OpenAI models (like GPT-4) support both text and vision tasks by default, so you can use the same client for both.
22+
23+
Text and Vision Tasks
24+
------------------
25+
26+
For text generation and vision tasks, you can use the default ``ModelType.LLM`` with any OpenAI multimodal model:
27+
28+
.. code-block:: python
29+
30+
from adalflow.core import Generator
31+
from adalflow.components.model_client import OpenAIClient
32+
from adalflow.core.types import ModelType
33+
34+
# Default model_type is LLM
35+
client = OpenAIClient()
36+
generator = Generator(
37+
model_client=client,
38+
model_kwargs={"model": "gpt-4", "max_tokens": 100}
39+
)
40+
41+
# Text generation
42+
text_response = generator({"input_str": "Hello, world!"})
43+
print(f"Text Response: {text_response.raw_response}")
44+
45+
# Vision analysis with the same client
46+
vision_response = generator(
47+
prompt_kwargs={"input_str": "What do you see in this image?"},
48+
model_kwargs={
49+
"model": "gpt-4", # Same model can handle both text and images
50+
"images": "https://example.com/image.jpg",
51+
"max_tokens": 300,
52+
},
53+
)
54+
print(f"Vision Response: {vision_response.raw_response}")
55+
56+
Image Generation
57+
--------------
58+
59+
For DALL-E image generation, explicitly set ``ModelType.IMAGE_GENERATION``:
60+
61+
.. code-block:: python
62+
63+
dalle_client = OpenAIClient(model_type=ModelType.IMAGE_GENERATION)
64+
dalle_gen = Generator(
65+
model_client=dalle_client,
66+
model_kwargs={
67+
"model": "dall-e-3",
68+
"size": "1024x1024",
69+
"quality": "standard",
70+
"n": 1,
71+
},
72+
)
73+
74+
# For image generation, input_str becomes the prompt
75+
response = dalle_gen(
76+
{"input_str": "A happy siamese cat playing with a red ball of yarn"}
77+
)
78+
print(f"Generated Image URL: {response.data}")
79+
80+
Error Handling
81+
------------
82+
83+
The client includes built-in error handling for common issues:
84+
85+
1. Invalid Image URLs:
86+
87+
.. code-block:: python
88+
89+
# The client will properly handle invalid image URLs
90+
gen = Generator(
91+
model_client=client,
92+
model_kwargs={
93+
"model": "gpt-4",
94+
"images": "https://invalid.url/nonexistent.jpg",
95+
},
96+
)
97+
response = gen({"input_str": "What do you see?"})
98+
# Will return GeneratorOutput with error information
99+
100+
2. Invalid Parameters:
101+
102+
.. code-block:: python
103+
104+
# The client will catch invalid parameters
105+
gen = Generator(
106+
model_client=dalle_client,
107+
model_kwargs={
108+
"model": "dall-e-3",
109+
"size": "invalid_size", # Invalid size
110+
},
111+
)
112+
response = gen({"input_str": "A cat"})
113+
# Will return GeneratorOutput with error information
114+
115+
Response Structure
116+
---------------
117+
118+
All responses are returned as ``GeneratorOutput`` objects with these fields:
119+
120+
- ``data``: The processed output (e.g., generated text, image URL)
121+
- ``raw_response``: The raw response string
122+
- ``error``: Any error messages (None if successful)
123+
- ``usage``: Token usage information for text generation
124+
- ``metadata``: Additional metadata (if any)
125+
126+
For example:
127+
128+
.. code-block:: python
129+
130+
# Text/Vision response
131+
GeneratorOutput(
132+
data='Hello! How can I assist you today?',
133+
error=None,
134+
usage=CompletionUsage(completion_tokens=10, prompt_tokens=45, total_tokens=55),
135+
raw_response='Hello! How can I assist you today?'
136+
)
137+
138+
# Image generation response
139+
GeneratorOutput(
140+
data='https://...image-url...',
141+
error=None,
142+
raw_response='[Image(url="https://...")]'
143+
)
144+
145+
Best Practices
146+
------------
147+
148+
1. **Model Type Selection**:
149+
- Use the default ``ModelType.LLM`` for text and vision tasks (they use the same models)
150+
- Explicitly set ``ModelType.IMAGE_GENERATION`` only for DALL-E
151+
- Use ``ModelType.EMBEDDER`` for embedding generation
152+
153+
2. **Error Handling**:
154+
- Always check the ``error`` field in the response
155+
- Handle both API errors and invalid parameter errors
156+
157+
3. **Resource Management**:
158+
- Monitor token usage through the ``usage`` field
159+
- Be mindful of image sizes and quality settings for DALL-E
160+
161+
4. **Model Selection**:
162+
- Most recent OpenAI models support both text and vision:
163+
- ``gpt-4`` for both text and vision tasks
164+
- ``dall-e-3`` or ``dall-e-2`` for image generation
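
The tutorial mentions ``ModelType.EMBEDDER`` but does not show it. A minimal sketch, not part of this commit, under the assumption that AdalFlow's ``Embedder`` component accepts ``model_client``/``model_kwargs`` the same way ``Generator`` does (the embedding model name is a placeholder):

    from adalflow.core import Embedder
    from adalflow.components.model_client import OpenAIClient
    from adalflow.core.types import ModelType

    # Assumed wiring: the client is constructed with ModelType.EMBEDDER,
    # so its call() routes to the embeddings endpoint by default.
    embedder = Embedder(
        model_client=OpenAIClient(model_type=ModelType.EMBEDDER),
        model_kwargs={"model": "text-embedding-3-small"},  # placeholder model name
    )

    output = embedder("a white siamese cat")
    print(output.data)  # embedding vectors, if the assumed API holds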
