Skip to content
Merged
41 changes: 39 additions & 2 deletions autogpt_platform/backend/backend/blocks/ai_image_customizer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from enum import Enum
from typing import Literal

Expand All @@ -19,14 +20,28 @@
SchemaField,
)
from backend.integrations.providers import ProviderName
from backend.util.file import MediaFileType
from backend.util.file import MediaFileType, store_media_file


class GeminiImageModel(str, Enum):
NANO_BANANA = "google/nano-banana"
NANO_BANANA_PRO = "google/nano-banana-pro"


class AspectRatio(str, Enum):
MATCH_INPUT_IMAGE = "match_input_image"
ASPECT_1_1 = "1:1"
ASPECT_2_3 = "2:3"
ASPECT_3_2 = "3:2"
ASPECT_3_4 = "3:4"
ASPECT_4_3 = "4:3"
ASPECT_4_5 = "4:5"
ASPECT_5_4 = "5:4"
ASPECT_9_16 = "9:16"
ASPECT_16_9 = "16:9"
ASPECT_21_9 = "21:9"


class OutputFormat(str, Enum):
JPG = "jpg"
PNG = "png"
Expand Down Expand Up @@ -69,6 +84,11 @@ class Input(BlockSchemaInput):
default=[],
title="Input Images",
)
aspect_ratio: AspectRatio = SchemaField(
description="Aspect ratio of the generated image",
default=AspectRatio.MATCH_INPUT_IMAGE,
title="Aspect Ratio",
)
output_format: OutputFormat = SchemaField(
description="Format of the output image",
default=OutputFormat.PNG,
Expand All @@ -92,6 +112,7 @@ def __init__(self):
"prompt": "Make the scene more vibrant and colorful",
"model": GeminiImageModel.NANO_BANANA,
"images": [],
"aspect_ratio": AspectRatio.MATCH_INPUT_IMAGE,
"output_format": OutputFormat.JPG,
"credentials": TEST_CREDENTIALS_INPUT,
},
Expand All @@ -116,11 +137,25 @@ async def run(
**kwargs,
) -> BlockOutput:
try:
# Convert local file paths to Data URIs (base64) so Replicate can access them
processed_images = await asyncio.gather(
*(
store_media_file(
graph_exec_id=graph_exec_id,
file=img,
user_id=user_id,
return_content=True,
)
for img in input_data.images
)
)

result = await self.run_model(
api_key=credentials.api_key,
model_name=input_data.model.value,
prompt=input_data.prompt,
images=input_data.images,
images=processed_images,
aspect_ratio=input_data.aspect_ratio.value,
output_format=input_data.output_format.value,
)
yield "image_url", result
Expand All @@ -133,12 +168,14 @@ async def run_model(
model_name: str,
prompt: str,
images: list[MediaFileType],
aspect_ratio: str,
output_format: str,
) -> MediaFileType:
client = ReplicateClient(api_token=api_key.get_secret_value())

input_params: dict = {
"prompt": prompt,
"aspect_ratio": aspect_ratio,
"output_format": output_format,
}

Expand Down
Loading