Skip to content

Commit fd80974

Browse files
committed
no subclass for image generation proposal
1 parent 605f1f0 commit fd80974

File tree

3 files changed

+104
-42
lines changed

3 files changed

+104
-42
lines changed

adalflow/adalflow/components/model_client/openai_client.py

Lines changed: 83 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -243,7 +243,18 @@ def convert_inputs_to_api_kwargs(
243243
- images: Optional image source(s) as path, URL, or list of them
244244
- detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto'
245245
- model: The model to use (must support multimodal inputs if images are provided)
246-
model_type: The type of model (EMBEDDER or LLM)
246+
For image generation:
247+
- model: "dall-e-3" or "dall-e-2"
248+
- size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
249+
- quality: "standard" or "hd" (DALL-E 3 only)
250+
- n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
251+
- response_format: "url" or "b64_json"
252+
For image edits (DALL-E 2 only):
253+
- image: Path to the input image
254+
- mask: Path to the mask image
255+
For variations (DALL-E 2 only):
256+
- image: Path to the input image
257+
model_type: The type of model (EMBEDDER, LLM, or IMAGE_GENERATION)
247258
248259
Returns:
249260
Dict: API-specific kwargs for the model call
@@ -308,20 +319,44 @@ def convert_inputs_to_api_kwargs(
308319
# Ensure model is specified
309320
if "model" not in final_model_kwargs:
310321
raise ValueError("model must be specified for image generation")
311-
# Set defaults for DALL-E 3 if not specified
312-
final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
313-
final_model_kwargs["quality"] = final_model_kwargs.get("quality", "standard")
314-
final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
315-
final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")
316-
317-
# Handle image edits and variations
318-
image = final_model_kwargs.get("image")
319-
if isinstance(image, str) and os.path.isfile(image):
320-
final_model_kwargs["image"] = self._encode_image(image)
321322

322-
mask = final_model_kwargs.get("mask")
323-
if isinstance(mask, str) and os.path.isfile(mask):
324-
final_model_kwargs["mask"] = self._encode_image(mask)
323+
# Set defaults for image generation
324+
if "operation" not in final_model_kwargs:
325+
final_model_kwargs["operation"] = "generate" # Default operation
326+
327+
operation = final_model_kwargs.pop("operation")
328+
329+
if operation == "generate":
330+
# Set defaults for DALL-E 3 if not specified
331+
final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
332+
final_model_kwargs["quality"] = final_model_kwargs.get("quality", "standard")
333+
final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
334+
final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")
335+
336+
elif operation in ["edit", "variation"]:
337+
if "model" not in final_model_kwargs or final_model_kwargs["model"] != "dall-e-2":
338+
raise ValueError(f"{operation} operation is only available with DALL-E 2")
339+
340+
# Handle image input
341+
image_path = final_model_kwargs.get("image")
342+
if not image_path or not os.path.isfile(image_path):
343+
raise ValueError(f"Valid image path must be provided for {operation}")
344+
final_model_kwargs["image"] = open(image_path, "rb")
345+
346+
# Handle mask for edit operation
347+
if operation == "edit":
348+
mask_path = final_model_kwargs.get("mask")
349+
if not mask_path or not os.path.isfile(mask_path):
350+
raise ValueError("Valid mask path must be provided for edit operation")
351+
final_model_kwargs["mask"] = open(mask_path, "rb")
352+
353+
# Set defaults
354+
final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
355+
final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
356+
final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")
357+
358+
else:
359+
raise ValueError(f"Invalid operation: {operation}")
325360
else:
326361
raise ValueError(f"model_type {model_type} is not supported")
327362
return final_model_kwargs
@@ -371,18 +406,25 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
371406
return self.sync_client.chat.completions.create(**api_kwargs)
372407
return self.sync_client.chat.completions.create(**api_kwargs)
373408
elif model_type == ModelType.IMAGE_GENERATION:
374-
# Determine which image API to call based on the presence of image/mask
375-
if "image" in api_kwargs:
376-
if "mask" in api_kwargs:
377-
# Image edit
409+
operation = api_kwargs.pop("operation", "generate")
410+
411+
try:
412+
if operation == "generate":
413+
response = self.sync_client.images.generate(**api_kwargs)
414+
elif operation == "edit":
378415
response = self.sync_client.images.edit(**api_kwargs)
379-
else:
380-
# Image variation
416+
elif operation == "variation":
381417
response = self.sync_client.images.create_variation(**api_kwargs)
382-
else:
383-
# Image generation
384-
response = self.sync_client.images.generate(**api_kwargs)
385-
return response.data
418+
else:
419+
raise ValueError(f"Invalid operation: {operation}")
420+
421+
return response.data
422+
finally:
423+
# Clean up file handles if they exist
424+
if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"):
425+
api_kwargs["image"].close()
426+
if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"):
427+
api_kwargs["mask"].close()
386428
else:
387429
raise ValueError(f"model_type {model_type} is not supported")
388430

@@ -410,18 +452,25 @@ async def acall(
410452
elif model_type == ModelType.LLM:
411453
return await self.async_client.chat.completions.create(**api_kwargs)
412454
elif model_type == ModelType.IMAGE_GENERATION:
413-
# Determine which image API to call based on the presence of image/mask
414-
if "image" in api_kwargs:
415-
if "mask" in api_kwargs:
416-
# Image edit
455+
operation = api_kwargs.pop("operation", "generate")
456+
457+
try:
458+
if operation == "generate":
459+
response = await self.async_client.images.generate(**api_kwargs)
460+
elif operation == "edit":
417461
response = await self.async_client.images.edit(**api_kwargs)
418-
else:
419-
# Image variation
462+
elif operation == "variation":
420463
response = await self.async_client.images.create_variation(**api_kwargs)
421-
else:
422-
# Image generation
423-
response = await self.async_client.images.generate(**api_kwargs)
424-
return response.data
464+
else:
465+
raise ValueError(f"Invalid operation: {operation}")
466+
467+
return response.data
468+
finally:
469+
# Clean up file handles if they exist
470+
if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"):
471+
api_kwargs["image"].close()
472+
if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"):
473+
api_kwargs["mask"].close()
425474
else:
426475
raise ValueError(f"model_type {model_type} is not supported")
427476

adalflow/adalflow/core/generator.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,8 @@ def __init__(
100100
# args for the cache
101101
cache_path: Optional[str] = None,
102102
use_cache: bool = False,
103+
# args for model type
104+
model_type: ModelType = ModelType.LLM,
103105
) -> None:
104106
r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables:
105107
- task_desc_str
@@ -110,6 +112,17 @@ def __init__(
110112
- steps_str
111113
You can preset the prompt kwargs to fill in the variables in the prompt using prompt_kwargs.
112114
But you can replace the prompt and set any variables you want and use the prompt_kwargs to fill in the variables.
115+
116+
Args:
117+
model_client (ModelClient): The model client to use for the generator.
118+
model_kwargs (Dict[str, Any], optional): The model kwargs to pass to the model client. Defaults to {}. Please refer to :ref:`ModelClient<components-model_client>` for the details on how to set the model_kwargs for your specific model if it is from our library.
119+
template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT<core-default_prompt_template>`.
120+
prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None.
121+
output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None.
122+
name (Optional[str], optional): The name of the generator. Defaults to None.
123+
cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
124+
use_cache (bool, optional): Whether to use cache. Defaults to False.
125+
model_type (ModelType, optional): The type of model (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
113126
"""
114127

115128
if not isinstance(model_client, ModelClient):
@@ -133,6 +146,7 @@ def __init__(
133146
CallbackManager.__init__(self)
134147

135148
self.name = name or self.__class__.__name__
149+
self.model_type = model_type
136150

137151
self._init_prompt(template, prompt_kwargs)
138152

@@ -163,6 +177,7 @@ def __init__(
163177
"name": name,
164178
"cache_path": cache_path,
165179
"use_cache": use_cache,
180+
"model_type": model_type,
166181
}
167182
self._teacher: Optional["Generator"] = None
168183
self._trace_api_kwargs: Dict[str, Any] = (

tutorials/multimodal_client_testing_examples.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,10 +23,6 @@
2323
from typing import List
2424
from numpy.linalg import norm
2525

26-
class ImageGenerator(Generator):
27-
"""Generator subclass for image generation."""
28-
model_type = ModelType.IMAGE_GENERATION
29-
3026
def test_basic_generation():
3127
"""Test basic text generation"""
3228
client = OpenAIClient()
@@ -61,14 +57,15 @@ def test_invalid_image_url():
6157
def test_invalid_image_generation():
6258
"""Test DALL-E generation with invalid parameters"""
6359
client = OpenAIClient()
64-
gen = ImageGenerator(
60+
gen = Generator(
6561
model_client=client,
6662
model_kwargs={
6763
"model": "dall-e-3",
6864
"size": "invalid_size", # Invalid size parameter
6965
"quality": "standard",
7066
"n": 1
71-
}
67+
},
68+
model_type=ModelType.IMAGE_GENERATION
7269
)
7370

7471
print("\n=== Testing Invalid DALL-E Parameters ===")
@@ -94,14 +91,15 @@ def test_vision_and_generation():
9491
print(f"Description: {vision_response.raw_response}")
9592

9693
# 2. Test DALL-E Image Generation
97-
dalle_gen = ImageGenerator(
94+
dalle_gen = Generator(
9895
model_client=client,
9996
model_kwargs={
10097
"model": "dall-e-3",
10198
"size": "1024x1024",
10299
"quality": "standard",
103100
"n": 1
104-
}
101+
},
102+
model_type=ModelType.IMAGE_GENERATION
105103
)
106104

107105
# For image generation, input_str becomes the prompt

0 commit comments

Comments
 (0)