diff --git a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py index 3ab8acb31da4..83178e924d26 100644 --- a/autogpt_platform/backend/backend/blocks/ai_image_customizer.py +++ b/autogpt_platform/backend/backend/blocks/ai_image_customizer.py @@ -1,3 +1,4 @@ +import asyncio from enum import Enum from typing import Literal @@ -19,7 +20,7 @@ SchemaField, ) from backend.integrations.providers import ProviderName -from backend.util.file import MediaFileType +from backend.util.file import MediaFileType, store_media_file class GeminiImageModel(str, Enum): @@ -27,6 +28,20 @@ class GeminiImageModel(str, Enum): NANO_BANANA_PRO = "google/nano-banana-pro" +class AspectRatio(str, Enum): + MATCH_INPUT_IMAGE = "match_input_image" + ASPECT_1_1 = "1:1" + ASPECT_2_3 = "2:3" + ASPECT_3_2 = "3:2" + ASPECT_3_4 = "3:4" + ASPECT_4_3 = "4:3" + ASPECT_4_5 = "4:5" + ASPECT_5_4 = "5:4" + ASPECT_9_16 = "9:16" + ASPECT_16_9 = "16:9" + ASPECT_21_9 = "21:9" + + class OutputFormat(str, Enum): JPG = "jpg" PNG = "png" @@ -69,6 +84,11 @@ class Input(BlockSchemaInput): default=[], title="Input Images", ) + aspect_ratio: AspectRatio = SchemaField( + description="Aspect ratio of the generated image", + default=AspectRatio.MATCH_INPUT_IMAGE, + title="Aspect Ratio", + ) output_format: OutputFormat = SchemaField( description="Format of the output image", default=OutputFormat.PNG, @@ -92,6 +112,7 @@ def __init__(self): "prompt": "Make the scene more vibrant and colorful", "model": GeminiImageModel.NANO_BANANA, "images": [], + "aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE, "output_format": OutputFormat.JPG, "credentials": TEST_CREDENTIALS_INPUT, }, @@ -116,11 +137,25 @@ async def run( **kwargs, ) -> BlockOutput: try: + # Convert local file paths to Data URIs (base64) so Replicate can access them + processed_images = await asyncio.gather( + *( + store_media_file( + graph_exec_id=graph_exec_id, + file=img, + user_id=user_id, + return_content=True, + ) + for img in input_data.images + ) + ) + result = await self.run_model( api_key=credentials.api_key, model_name=input_data.model.value, prompt=input_data.prompt, - images=input_data.images, + images=processed_images, + aspect_ratio=input_data.aspect_ratio.value, output_format=input_data.output_format.value, ) yield "image_url", result @@ -133,12 +168,14 @@ async def run_model( model_name: str, prompt: str, images: list[MediaFileType], + aspect_ratio: str, output_format: str, ) -> MediaFileType: client = ReplicateClient(api_token=api_key.get_secret_value()) input_params: dict = { "prompt": prompt, + "aspect_ratio": aspect_ratio, "output_format": output_format, }