
Commit d121c18

fix kwargs implementation
1 parent 4145595 commit d121c18

File tree: 3 files changed (+31, −55 lines)


adalflow/adalflow/components/model_client/openai_client.py

Lines changed: 16 additions & 12 deletions
@@ -106,6 +106,12 @@ class OpenAIClient(ModelClient):
     Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client.
     (2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project.
 
+    Args:
+        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
+        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
+        input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
+        model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
+
     Note:
         We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API.
         We do not know how OpenAI is doing the formating or what prompt they have added.
@@ -120,14 +126,9 @@ class OpenAIClient(ModelClient):
         - prompt: Text description of the image to generate
         - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
         - quality: "standard" or "hd" (DALL-E 3 only)
-        - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
+        - n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
         - response_format: "url" or "b64_json"
 
-    Args:
-        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
-        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
-            Default is `get_first_message_content`.
-
     References:
         - Embeddings models: https://platform.openai.com/docs/guides/embeddings
         - Chat models: https://platform.openai.com/docs/guides/text-generation
@@ -141,11 +142,15 @@ def __init__(
         api_key: Optional[str] = None,
         chat_completion_parser: Callable[[Completion], Any] = None,
         input_type: Literal["text", "messages"] = "text",
+        model_type: ModelType = ModelType.LLM,
     ):
         r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.
 
         Args:
             api_key (Optional[str], optional): OpenAI API key. Defaults to None.
+            chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
+            input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
+            model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
         """
         super().__init__()
         self._api_key = api_key
@@ -155,6 +160,7 @@ def __init__(
             chat_completion_parser or get_first_message_content
         )
         self._input_type = input_type
+        self.model_type = model_type
 
     def init_sync_client(self):
         api_key = self._api_key or os.getenv("OPENAI_API_KEY")
@@ -229,7 +235,6 @@ def convert_inputs_to_api_kwargs(
         self,
         input: Optional[Any] = None,
         model_kwargs: Dict = {},
-        model_type: ModelType = ModelType.UNDEFINED,
     ) -> Dict:
         r"""
         Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -254,21 +259,20 @@ def convert_inputs_to_api_kwargs(
                 - mask: Path to the mask image
                 For variations (DALL-E 2 only):
                 - image: Path to the input image
-            model_type: The type of model (EMBEDDER, LLM, or IMAGE_GENERATION)
 
         Returns:
             Dict: API-specific kwargs for the model call
         """
 
         final_model_kwargs = model_kwargs.copy()
-        if model_type == ModelType.EMBEDDER:
+        if self.model_type == ModelType.EMBEDDER:
            if isinstance(input, str):
                input = [input]
            # convert input to input
            if not isinstance(input, Sequence):
                raise TypeError("input must be a sequence of text")
            final_model_kwargs["input"] = input
-        elif model_type == ModelType.LLM:
+        elif self.model_type == ModelType.LLM:
            # convert input to messages
            messages: List[Dict[str, str]] = []
            images = final_model_kwargs.pop("images", None)
@@ -313,7 +317,7 @@ def convert_inputs_to_api_kwargs(
            else:
                messages.append({"role": "system", "content": input})
            final_model_kwargs["messages"] = messages
-        elif model_type == ModelType.IMAGE_GENERATION:
+        elif self.model_type == ModelType.IMAGE_GENERATION:
            # For image generation, input is the prompt
            final_model_kwargs["prompt"] = input
            # Ensure model is specified
@@ -358,7 +362,7 @@ def convert_inputs_to_api_kwargs(
            else:
                raise ValueError(f"Invalid operation: {operation}")
        else:
-            raise ValueError(f"model_type {model_type} is not supported")
+            raise ValueError(f"model_type {self.model_type} is not supported")
        return final_model_kwargs
 
     def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
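
In short, this file moves `model_type` from a per-call argument of `convert_inputs_to_api_kwargs` to a constructor argument stored on the client. A minimal sketch of the resulting call pattern, assuming `ModelType` is importable from `adalflow.core.types` (the import path is not shown in this diff):

    from adalflow.components.model_client.openai_client import OpenAIClient
    from adalflow.core.types import ModelType  # assumed import path

    # Each client is now bound to a single model type at construction time.
    llm_client = OpenAIClient()  # defaults to ModelType.LLM
    embed_client = OpenAIClient(model_type=ModelType.EMBEDDER)

    # convert_inputs_to_api_kwargs no longer accepts model_type;
    # it branches on self.model_type instead.
    api_kwargs = llm_client.convert_inputs_to_api_kwargs(
        input="Hello",
        model_kwargs={"model": "gpt-4o-mini"},
    )
    # For an LLM client this yields {"model": "gpt-4o-mini", "messages": [...]}.

The trade-off of this design is that one client instance can no longer serve two model types; code that previously shared a client across an LLM call and an image-generation call now constructs one client per type, as the tutorial changes below show.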

adalflow/adalflow/core/generator.py

Lines changed: 4 additions & 30 deletions
@@ -70,21 +70,11 @@ class Generator(GradComponent, CachedEngine, CallbackManager):
         template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT<core-default_prompt_template>`.
         prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None.
         output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None.
-        trainable_params (Optional[List[str]], optional): The list of trainable parameters. Defaults to [].
-
-    Note:
-        The output_processors will be applied to the string output of the model completion. And the result will be stored in the data field of the output.
-        And we encourage you to only use it to parse the response to data format you will use later.
+        name (Optional[str], optional): The name of the generator. Defaults to None.
+        cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
+        use_cache (bool, optional): Whether to use cache. Defaults to False.
     """
 
-    model_type: ModelType = ModelType.LLM
-    model_client: ModelClient  # for better type checking
-
-    _use_cache: bool = False
-    _kwargs: Dict[str, Any] = (
-        {}
-    )  # to create teacher generator from student TODO: might reaccess this
-
     def __init__(
         self,
         *,
@@ -100,8 +90,6 @@ def __init__(
         # args for the cache
         cache_path: Optional[str] = None,
         use_cache: bool = False,
-        # args for model type
-        model_type: ModelType = ModelType.LLM,
     ) -> None:
         r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables:
         - task_desc_str
@@ -112,17 +100,6 @@ def __init__(
         - steps_str
         You can preset the prompt kwargs to fill in the variables in the prompt using prompt_kwargs.
         But you can replace the prompt and set any variables you want and use the prompt_kwargs to fill in the variables.
-
-        Args:
-            model_client (ModelClient): The model client to use for the generator.
-            model_kwargs (Dict[str, Any], optional): The model kwargs to pass to the model client. Defaults to {}. Please refer to :ref:`ModelClient<components-model_client>` for the details on how to set the model_kwargs for your specific model if it is from our library.
-            template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT<core-default_prompt_template>`.
-            prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None.
-            output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None.
-            name (Optional[str], optional): The name of the generator. Defaults to None.
-            cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
-            use_cache (bool, optional): Whether to use cache. Defaults to False.
-            model_type (ModelType, optional): The type of model (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
         """
 
         if not isinstance(model_client, ModelClient):
@@ -134,7 +111,6 @@ def __init__(
         template = template or DEFAULT_ADALFLOW_SYSTEM_PROMPT
 
         # create the cache path and initialize the cache engine
-
         self.set_cache_path(
             cache_path, model_client, model_kwargs.get("model", "default")
         )
@@ -146,7 +122,7 @@ def __init__(
         CallbackManager.__init__(self)
 
         self.name = name or self.__class__.__name__
-        self.model_type = model_type
+        self.model_type = model_client.model_type  # Get model type from client
 
         self._init_prompt(template, prompt_kwargs)
 
@@ -177,7 +153,6 @@ def __init__(
             "name": name,
             "cache_path": cache_path,
             "use_cache": use_cache,
-            "model_type": model_type,
         }
         self._teacher: Optional["Generator"] = None
         self._trace_api_kwargs: Dict[str, Any] = (
@@ -351,7 +326,6 @@ def _pre_call(self, prompt_kwargs: Dict, model_kwargs: Dict) -> Dict[str, Any]:
         api_kwargs = self.model_client.convert_inputs_to_api_kwargs(
             input=prompt_str,
             model_kwargs=composed_model_kwargs,
-            model_type=self.model_type,
         )
         return api_kwargs
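
The `Generator` changes mirror the client: the `model_type` parameter is removed, the generator copies `model_client.model_type` in `__init__`, and `_pre_call` no longer forwards it. A minimal sketch of the new wiring, with the same assumed import paths as above:

    from adalflow.core.generator import Generator  # assumed import path
    from adalflow.components.model_client.openai_client import OpenAIClient
    from adalflow.core.types import ModelType  # assumed import path

    # The client, not the Generator, now decides the model type.
    dalle_gen = Generator(
        model_client=OpenAIClient(model_type=ModelType.IMAGE_GENERATION),
        model_kwargs={"model": "dall-e-3", "size": "1024x1024", "quality": "standard", "n": 1},
    )
    # __init__ sets self.model_type = model_client.model_type, so:
    assert dalle_gen.model_type == ModelType.IMAGE_GENERATION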

tutorials/multimodal_client_testing_examples.py

Lines changed: 11 additions & 13 deletions
@@ -25,7 +25,7 @@
 
 def test_basic_generation():
     """Test basic text generation"""
-    client = OpenAIClient()
+    client = OpenAIClient()  # Default model_type is LLM
     gen = Generator(
         model_client=client,
         model_kwargs={
@@ -40,7 +40,7 @@ def test_basic_generation():
 
 def test_invalid_image_url():
     """Test Generator output with invalid image URL"""
-    client = OpenAIClient()
+    client = OpenAIClient()  # Default model_type is LLM
     gen = Generator(
         model_client=client,
         model_kwargs={
@@ -56,16 +56,15 @@ def test_invalid_image_url():
 
 def test_invalid_image_generation():
     """Test DALL-E generation with invalid parameters"""
-    client = OpenAIClient()
+    client = OpenAIClient(model_type=ModelType.IMAGE_GENERATION)
     gen = Generator(
         model_client=client,
         model_kwargs={
             "model": "dall-e-3",
             "size": "invalid_size",  # Invalid size parameter
             "quality": "standard",
             "n": 1
-        },
-        model_type=ModelType.IMAGE_GENERATION
+        }
     )
 
     print("\n=== Testing Invalid DALL-E Parameters ===")
@@ -74,11 +73,10 @@ def test_invalid_image_generation():
 
 def test_vision_and_generation():
     """Test both vision analysis and image generation"""
-    client = OpenAIClient()
-
-    # 1. Test Vision Analysis
+    # 1. Test Vision Analysis with LLM client
+    vision_client = OpenAIClient()  # Default model_type is LLM
     vision_gen = Generator(
-        model_client=client,
+        model_client=vision_client,
         model_kwargs={
             "model": "gpt-4o-mini",
             "images": "https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png",
@@ -90,16 +88,16 @@ def test_vision_and_generation():
     print("\n=== Vision Analysis ===")
     print(f"Description: {vision_response.raw_response}")
 
-    # 2. Test DALL-E Image Generation
+    # 2. Test DALL-E Image Generation with IMAGE_GENERATION client
+    dalle_client = OpenAIClient(model_type=ModelType.IMAGE_GENERATION)
     dalle_gen = Generator(
-        model_client=client,
+        model_client=dalle_client,
         model_kwargs={
             "model": "dall-e-3",
             "size": "1024x1024",
             "quality": "standard",
             "n": 1
-        },
-        model_type=ModelType.IMAGE_GENERATION
+        }
     )
 
     # For image generation, input_str becomes the prompt
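
For completeness, a hypothetical import block and entry point for the tutorial above; the excerpt does not show the file's actual imports, so these paths are assumptions:

    # Assumed imports for the test functions shown in this diff.
    from adalflow.core.generator import Generator
    from adalflow.core.types import ModelType
    from adalflow.components.model_client.openai_client import OpenAIClient

    if __name__ == "__main__":
        # Requires OPENAI_API_KEY to be set in the environment.
        test_basic_generation()
        test_invalid_image_url()
        test_invalid_image_generation()
        test_vision_and_generation()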
