diff --git a/adalflow/adalflow/components/model_client/openai_client.py b/adalflow/adalflow/components/model_client/openai_client.py
index 1e22d5ef4..63461c2a3 100644
--- a/adalflow/adalflow/components/model_client/openai_client.py
+++ b/adalflow/adalflow/components/model_client/openai_client.py
@@ -106,6 +106,11 @@ class OpenAIClient(ModelClient):
     Users (1) simplify use ``Embedder`` and ``Generator`` components by passing OpenAIClient() as the model_client.
     (2) can use this as an example to create their own API client or extend this class(copying and modifing the code) in their own project.

+    Args:
+        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
+        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
+        input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
+
     Note:
         We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API.
         We do not know how OpenAI is doing the formating or what prompt they have added.
@@ -120,14 +125,9 @@ class OpenAIClient(ModelClient):
     - prompt: Text description of the image to generate
     - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
     - quality: "standard" or "hd" (DALL-E 3 only)
-    - n: Number of images to generate (1 for DALL-E 3, 1-10 for DALL-E 2)
+    - n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
     - response_format: "url" or "b64_json"

-    Args:
-        api_key (Optional[str], optional): OpenAI API key. Defaults to None.
-        chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
-            Default is `get_first_message_content`.
-
     References:
     - Embeddings models: https://platform.openai.com/docs/guides/embeddings
     - Chat models: https://platform.openai.com/docs/guides/text-generation
@@ -146,6 +146,8 @@ def __init__(
         Args:
             api_key (Optional[str], optional): OpenAI API key. Defaults to None.
+            chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
+            input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
         """
         super().__init__()
         self._api_key = api_key
@@ -229,7 +231,7 @@ def convert_inputs_to_api_kwargs(
         self,
         input: Optional[Any] = None,
         model_kwargs: Dict = {},
-        model_type: ModelType = ModelType.UNDEFINED,
+        model_type: ModelType = ModelType.UNDEFINED,  # Now required in practice
     ) -> Dict:
         r"""
         Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -243,11 +245,24 @@ def convert_inputs_to_api_kwargs(
                 - images: Optional image source(s) as path, URL, or list of them
                 - detail: Image detail level ('auto', 'low', or 'high'), defaults to 'auto'
                 - model: The model to use (must support multimodal inputs if images are provided)
-            model_type: The type of model (EMBEDDER or LLM)
+                For image generation:
+                - model: "dall-e-3" or "dall-e-2"
+                - size: "1024x1024", "1024x1792", or "1792x1024" for DALL-E 3; "256x256", "512x512", or "1024x1024" for DALL-E 2
+                - quality: "standard" or "hd" (DALL-E 3 only)
+                - n: Number of images (1 for DALL-E 3, 1-10 for DALL-E 2)
+                - response_format: "url" or "b64_json"
+                For image edits (DALL-E 2 only):
+                - image: Path to the input image
+                - mask: Path to the mask image
+                For variations (DALL-E 2 only):
+                - image: Path to the input image
+            model_type: The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Required.

         Returns:
             Dict: API-specific kwargs for the model call
         """
+        if model_type == ModelType.UNDEFINED:
+            raise ValueError("model_type must be specified")
         final_model_kwargs = model_kwargs.copy()

         if model_type == ModelType.EMBEDDER:
@@ -308,24 +323,43 @@ def convert_inputs_to_api_kwargs(
             # Ensure model is specified
             if "model" not in final_model_kwargs:
                 raise ValueError("model must be specified for image generation")
-            # Set defaults for DALL-E 3 if not specified
-            final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
-            final_model_kwargs["quality"] = final_model_kwargs.get(
-                "quality", "standard"
-            )
-            final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
-            final_model_kwargs["response_format"] = final_model_kwargs.get(
-                "response_format", "url"
-            )
-
-            # Handle image edits and variations
-            image = final_model_kwargs.get("image")
-            if isinstance(image, str) and os.path.isfile(image):
-                final_model_kwargs["image"] = self._encode_image(image)
-
-            mask = final_model_kwargs.get("mask")
-            if isinstance(mask, str) and os.path.isfile(mask):
-                final_model_kwargs["mask"] = self._encode_image(mask)
+            # Set defaults for image generation
+            if "operation" not in final_model_kwargs:
+                final_model_kwargs["operation"] = "generate"  # Default operation
+
+            operation = final_model_kwargs.pop("operation")
+
+            if operation == "generate":
+                # Set defaults for DALL-E 3 if not specified
+                final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
+                final_model_kwargs["quality"] = final_model_kwargs.get("quality", "standard")
+                final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
+                final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url")
+
+            elif operation in ["edit", "variation"]:
+                if "model" not in final_model_kwargs or final_model_kwargs["model"] != "dall-e-2":
+                    raise ValueError(f"{operation} operation is only available with DALL-E 2")
+
+                # Handle image input
+                image_path = final_model_kwargs.get("image")
+                if not image_path or not os.path.isfile(image_path):
+                    raise ValueError(f"Valid image path must be provided for {operation}")
+                final_model_kwargs["image"] = open(image_path, "rb")
+
+                # Handle mask for edit operation
+                if operation == "edit":
+                    mask_path = final_model_kwargs.get("mask")
+                    if not mask_path or not os.path.isfile(mask_path):
+                        raise ValueError("Valid mask path must be provided for edit operation")
+                    final_model_kwargs["mask"] = open(mask_path, "rb")
+
+                # Set defaults
+                final_model_kwargs["size"] = final_model_kwargs.get("size", "1024x1024")
+                final_model_kwargs["n"] = final_model_kwargs.get("n", 1)
final_model_kwargs["response_format"] = final_model_kwargs.get("response_format", "url") + + else: + raise ValueError(f"Invalid operation: {operation}") else: raise ValueError(f"model_type {model_type} is not supported") return final_model_kwargs @@ -361,6 +395,9 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE """ kwargs is the combined input and model_kwargs. Support streaming call. """ + if model_type == ModelType.UNDEFINED: + raise ValueError("model_type must be specified") + log.info(f"api_kwargs: {api_kwargs}") if model_type == ModelType.EMBEDDER: return self.sync_client.embeddings.create(**api_kwargs) @@ -371,18 +408,25 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE return self.sync_client.chat.completions.create(**api_kwargs) return self.sync_client.chat.completions.create(**api_kwargs) elif model_type == ModelType.IMAGE_GENERATION: - # Determine which image API to call based on the presence of image/mask - if "image" in api_kwargs: - if "mask" in api_kwargs: - # Image edit + operation = api_kwargs.pop("operation", "generate") + + try: + if operation == "generate": + response = self.sync_client.images.generate(**api_kwargs) + elif operation == "edit": response = self.sync_client.images.edit(**api_kwargs) - else: - # Image variation + elif operation == "variation": response = self.sync_client.images.create_variation(**api_kwargs) - else: - # Image generation - response = self.sync_client.images.generate(**api_kwargs) - return response.data + else: + raise ValueError(f"Invalid operation: {operation}") + + return response.data + finally: + # Clean up file handles if they exist + if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"): + api_kwargs["image"].close() + if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"): + api_kwargs["mask"].close() else: raise ValueError(f"model_type {model_type} is not supported") @@ -403,6 +447,9 @@ async def acall( """ kwargs is the combined input and model_kwargs """ + if model_type == ModelType.UNDEFINED: + raise ValueError("model_type must be specified") + if self.async_client is None: self.async_client = self.init_async_client() if model_type == ModelType.EMBEDDER: @@ -410,20 +457,25 @@ async def acall( elif model_type == ModelType.LLM: return await self.async_client.chat.completions.create(**api_kwargs) elif model_type == ModelType.IMAGE_GENERATION: - # Determine which image API to call based on the presence of image/mask - if "image" in api_kwargs: - if "mask" in api_kwargs: - # Image edit + operation = api_kwargs.pop("operation", "generate") + + try: + if operation == "generate": + response = await self.async_client.images.generate(**api_kwargs) + elif operation == "edit": response = await self.async_client.images.edit(**api_kwargs) + elif operation == "variation": + response = await self.async_client.images.create_variation(**api_kwargs) else: - # Image variation - response = await self.async_client.images.create_variation( - **api_kwargs - ) - else: - # Image generation - response = await self.async_client.images.generate(**api_kwargs) - return response.data + raise ValueError(f"Invalid operation: {operation}") + + return response.data + finally: + # Clean up file handles if they exist + if "image" in api_kwargs and hasattr(api_kwargs["image"], "close"): + api_kwargs["image"].close() + if "mask" in api_kwargs and hasattr(api_kwargs["mask"], "close"): + api_kwargs["mask"].close() else: raise ValueError(f"model_type {model_type} is not supported") 
diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py
index baedd8fb7..b2bb072b7 100644
--- a/adalflow/adalflow/core/generator.py
+++ b/adalflow/adalflow/core/generator.py
@@ -70,21 +70,12 @@ class Generator(GradComponent, CachedEngine, CallbackManager):
         template (Optional[str], optional): The template for the prompt. Defaults to :ref:`DEFAULT_ADALFLOW_SYSTEM_PROMPT`.
         prompt_kwargs (Optional[Dict], optional): The preset prompt kwargs to fill in the variables in the prompt. Defaults to None.
         output_processors (Optional[Component], optional): The output processors after model call. It can be a single component or a chained component via ``Sequential``. Defaults to None.
-        trainable_params (Optional[List[str]], optional): The list of trainable parameters. Defaults to [].
-
-    Note:
-        The output_processors will be applied to the string output of the model completion. And the result will be stored in the data field of the output.
-        And we encourage you to only use it to parse the response to data format you will use later.
+        name (Optional[str], optional): The name of the generator. Defaults to None.
+        cache_path (Optional[str], optional): The path to save the cache. Defaults to None.
+        use_cache (bool, optional): Whether to use cache. Defaults to False.
+        model_type (ModelType, optional): The type of the model. Defaults to ModelType.LLM.
     """

-    model_type: ModelType = ModelType.LLM
-    model_client: ModelClient  # for better type checking
-
-    _use_cache: bool = False
-    _kwargs: Dict[str, Any] = (
-        {}
-    )  # to create teacher generator from student TODO: might reaccess this
-
     def __init__(
         self,
         *,
@@ -100,6 +91,7 @@ def __init__(
         # args for the cache
         cache_path: Optional[str] = None,
         use_cache: bool = False,
+        model_type: ModelType = ModelType.LLM,  # Add model_type parameter with default
     ) -> None:
         r"""The default prompt is set to the DEFAULT_ADALFLOW_SYSTEM_PROMPT. It has the following variables:
         - task_desc_str
@@ -121,7 +113,6 @@ def __init__(
         template = template or DEFAULT_ADALFLOW_SYSTEM_PROMPT

         # create the cache path and initialize the cache engine
-
         self.set_cache_path(
             cache_path, model_client, model_kwargs.get("model", "default")
         )
@@ -133,6 +124,7 @@ def __init__(
         CallbackManager.__init__(self)

         self.name = name or self.__class__.__name__
+        self.model_type = model_type  # Use the passed model_type instead of getting from client

         self._init_prompt(template, prompt_kwargs)
diff --git a/adalflow/tests/test_generator.py b/adalflow/tests/test_generator.py
index a15c302a5..e6631f10f 100644
--- a/adalflow/tests/test_generator.py
+++ b/adalflow/tests/test_generator.py
@@ -15,6 +15,7 @@
 from adalflow.core.model_client import ModelClient
 from adalflow.components.model_client.groq_client import GroqAPIClient
 from adalflow.tracing import GeneratorStateLogger
+from adalflow.core.types import ModelType


 class TestGenerator(IsolatedAsyncioTestCase):
@@ -32,7 +33,7 @@ def setUp(self):
         )
         self.mock_api_client = mock_api_client

-        self.generator = Generator(model_client=mock_api_client)
+        self.generator = Generator(model_client=mock_api_client, model_type=ModelType.LLM)
         self.save_dir = "./tests/log"
         self.project_name = "TestGenerator"
         self.filename = "prompt_logger_test.json"
@@ -182,7 +183,7 @@ def test_groq_client_call(self, mock_call):
         template = "Hello, {{ input_str }}!"
         # Initialize the Generator with the mocked client
-        generator = Generator(model_client=self.client, template=template)
+        generator = Generator(model_client=self.client, template=template, model_type=ModelType.LLM)

         # Call the generator and get the output
         output = generator.call(prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs)
diff --git a/docs/source/tutorials/multimodal_client.rst b/docs/source/tutorials/multimodal_client.rst
new file mode 100644
index 000000000..a27547406
--- /dev/null
+++ b/docs/source/tutorials/multimodal_client.rst
@@ -0,0 +1,107 @@
+Multimodal Client Tutorial
+==========================
+
+This tutorial demonstrates how to use the OpenAI client for different types of tasks: text generation, vision analysis, and image generation.
+
+Model Types
+-----------
+
+The OpenAI client supports three types of operations:
+
+1. Text/Chat Completion (``ModelType.LLM``)
+   - Standard text generation
+   - Vision analysis (with vision-capable chat models such as GPT-4o)
+2. Image Generation (``ModelType.IMAGE_GENERATION``)
+   - DALL-E image generation
+3. Embeddings (``ModelType.EMBEDDER``)
+   - Text embeddings
+
+Basic Usage
+-----------
+
+The model type is specified when creating a ``Generator`` instance:
+
+.. code-block:: python
+
+    from adalflow.core import Generator
+    from adalflow.components.model_client.openai_client import OpenAIClient
+    from adalflow.core.types import ModelType
+
+    # Create the client
+    client = OpenAIClient()
+
+    # For text generation
+    gen = Generator(
+        model_client=client,
+        model_kwargs={"model": "gpt-4", "max_tokens": 100},
+        model_type=ModelType.LLM  # Specify LLM type
+    )
+    response = gen({"input_str": "Hello, world!"})
+
+Vision Tasks
+------------
+
+Vision tasks use ``ModelType.LLM`` since they are handled by vision-capable chat models:
+
+.. code-block:: python
+
+    # Vision analysis
+    vision_gen = Generator(
+        model_client=client,
+        model_kwargs={
+            "model": "gpt-4o-mini",
+            "images": "path/to/image.jpg",
+            "max_tokens": 300,
+        },
+        model_type=ModelType.LLM  # Vision uses LLM type
+    )
+    response = vision_gen({"input_str": "What do you see in this image?"})
+
+Image Generation
+----------------
+
+For DALL-E image generation, use ``ModelType.IMAGE_GENERATION``:
+
+.. code-block:: python
+
+    # Image generation with DALL-E
+    dalle_gen = Generator(
+        model_client=client,
+        model_kwargs={
+            "model": "dall-e-3",
+            "size": "1024x1024",
+            "quality": "standard",
+            "n": 1,
+        },
+        model_type=ModelType.IMAGE_GENERATION  # Specify image generation type
+    )
+    response = dalle_gen({"input_str": "A cat playing with yarn"})
+
+Backward Compatibility
+----------------------
+
+For backward compatibility with existing code:
+
+1. ``model_type`` defaults to ``ModelType.LLM`` if not specified
+2. Older models that only support text continue to work with ``ModelType.LLM``
+3. The OpenAI client handles the appropriate API endpoints based on the model type
+
+Error Handling
+--------------
+
+The client includes error handling for:
+
+1. Invalid model types for operations
+2. Invalid image URLs or file paths
+3. Unsupported model capabilities
+4. API errors and rate limits
+
+Complete Example
+----------------
+
+See the complete example in ``tutorials/multimodal_client_testing_examples.py``, which demonstrates:
+
+1. Basic text generation
+2. Vision analysis with image input
+3. DALL-E image generation
+4. Error handling for invalid inputs
\ No newline at end of file
diff --git a/tests/test_generator.py b/tests/test_generator.py
new file mode 100644
index 000000000..0519ecba6
--- /dev/null
+++ b/tests/test_generator.py
@@ -0,0 +1 @@
+
\ No newline at end of file
diff --git a/tutorials/multimodal_client_testing_examples.py b/tutorials/multimodal_client_testing_examples.py
index ee3a171dd..7e42a42b6 100644
--- a/tutorials/multimodal_client_testing_examples.py
+++ b/tutorials/multimodal_client_testing_examples.py
@@ -19,17 +19,13 @@
 from adalflow.core.types import ModelType


-class ImageGenerator(Generator):
-    """Generator subclass for image generation."""
-
-    model_type = ModelType.IMAGE_GENERATION
-
-
 def test_basic_generation():
     """Test basic text generation"""
-    client = OpenAIClient()
+    client = OpenAIClient()  # For text/chat completion
     gen = Generator(
-        model_client=client, model_kwargs={"model": "gpt-4o-mini", "max_tokens": 100}
+        model_client=client,
+        model_kwargs={"model": "gpt-4o-mini", "max_tokens": 100},
+        model_type=ModelType.LLM  # Explicitly specify model type
     )

     print("\n=== Testing Basic Generation ===")
@@ -39,7 +35,7 @@ def test_basic_generation():

 def test_invalid_image_url():
     """Test Generator output with invalid image URL"""
-    client = OpenAIClient()
+    client = OpenAIClient()  # For vision tasks
     gen = Generator(
         model_client=client,
         model_kwargs={
@@ -47,6 +43,7 @@ def test_invalid_image_url():
             "images": "https://invalid.url/nonexistent.jpg",
             "max_tokens": 300,
         },
+        model_type=ModelType.LLM  # Vision tasks use LLM type
     )

     print("\n=== Testing Invalid Image URL ===")
@@ -56,8 +53,8 @@ def test_invalid_image_url():

 def test_invalid_image_generation():
     """Test DALL-E generation with invalid parameters"""
-    client = OpenAIClient()
-    gen = ImageGenerator(
+    client = OpenAIClient()  # For image generation
+    gen = Generator(
         model_client=client,
         model_kwargs={
             "model": "dall-e-3",
@@ -65,6 +62,7 @@ def test_invalid_image_generation():
             "quality": "standard",
             "n": 1,
         },
+        model_type=ModelType.IMAGE_GENERATION  # Specify image generation type
     )

     print("\n=== Testing Invalid DALL-E Parameters ===")
@@ -74,16 +72,16 @@ def test_invalid_image_generation():

 def test_vision_and_generation():
     """Test both vision analysis and image generation"""
-    client = OpenAIClient()
-
-    # 1. Test Vision Analysis
+    # 1. Test Vision Analysis with LLM client
+    vision_client = OpenAIClient()  # For vision tasks
     vision_gen = Generator(
-        model_client=client,
+        model_client=vision_client,
         model_kwargs={
             "model": "gpt-4o-mini",
             "images": "https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png",
             "max_tokens": 300,
         },
+        model_type=ModelType.LLM  # Vision tasks use LLM type
     )

     vision_response = vision_gen(
@@ -93,14 +91,16 @@ def test_vision_and_generation():
     print(f"Description: {vision_response.raw_response}")

     # 2. Test DALL-E Image Generation
-    dalle_gen = ImageGenerator(
-        model_client=client,
+    dalle_client = OpenAIClient()  # For image generation
+    dalle_gen = Generator(
+        model_client=dalle_client,
         model_kwargs={
             "model": "dall-e-3",
             "size": "1024x1024",
             "quality": "standard",
             "n": 1,
         },
+        model_type=ModelType.IMAGE_GENERATION  # Specify image generation type
     )

     # For image generation, input_str becomes the prompt
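The diff adds hard failures on ``ModelType.UNDEFINED`` in ``convert_inputs_to_api_kwargs``, ``call``, and ``acall``, but none of the tests above exercise them. A regression test along these lines could pin that behavior down. This is a sketch only, not part of the diff; it assumes pytest is available and that ``OpenAIClient`` tolerates a dummy API key at construction time, which is safe here because the guard raises before any network call.

# Sketch only: regression test for the new ModelType.UNDEFINED guards.
import pytest

from adalflow.components.model_client.openai_client import OpenAIClient
from adalflow.core.types import ModelType


def test_undefined_model_type_is_rejected():
    client = OpenAIClient(api_key="sk-dummy")  # never used; the guard raises first

    with pytest.raises(ValueError, match="model_type must be specified"):
        client.convert_inputs_to_api_kwargs(
            input="hello",
            model_kwargs={"model": "gpt-4o-mini"},
            model_type=ModelType.UNDEFINED,
        )

    with pytest.raises(ValueError, match="model_type must be specified"):
        client.call(api_kwargs={}, model_type=ModelType.UNDEFINED)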