Skip to content

Commit 3bff071

Browse files
committed
Added logic to return image metadata from the API to the UI
1 parent c0554cc commit 3bff071

File tree

9 files changed

+151
-33
lines changed

9 files changed

+151
-33
lines changed

api/llm/agent.py

Lines changed: 78 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import atexit
22
import os
3+
import re
4+
from datetime import datetime
35
from functools import lru_cache
46
from typing import List, Optional
57

@@ -68,9 +70,7 @@ def get_agent():
6870
return _agent_executor
6971

7072

71-
def chat_with_agent(
72-
message: str, user_id: str = "default", selected_images: Optional[List[dict]] = None
73-
) -> str:
73+
def chat_with_agent(message: str, user_id: str = "default", selected_images: Optional[List[dict]] = None) -> tuple[str, Optional[dict]]:
7474
"""
7575
Send a message to the agent and get a response.
7676
@@ -80,7 +80,7 @@ def chat_with_agent(
8080
selected_images: List of selected image objects (optional)
8181
8282
Returns:
83-
The agent's response as a string
83+
Tuple of (agent_response, generated_image_data)
8484
"""
8585
agent = get_agent()
8686

@@ -89,9 +89,7 @@ def chat_with_agent(
8989
if selected_images and len(selected_images) > 0:
9090
image_context = "\n\nSelected Images:\n"
9191
for i, img in enumerate(selected_images, 1):
92-
image_context += (
93-
f"{i}. {img.get('title', 'Untitled')} (ID: {img.get('id', 'unknown')})\n"
94-
)
92+
image_context += f"{i}. {img.get('title', 'Untitled')} (ID: {img.get('id', 'unknown')})\n"
9593
image_context += f" Type: {img.get('type', 'unknown')}\n"
9694
image_context += f" Description: {img.get('description', 'No description')}\n"
9795
if img.get("url"):
@@ -103,20 +101,86 @@ def chat_with_agent(
103101
config = {"configurable": {"thread_id": user_id}}
104102

105103
# Get response from agent
106-
response = agent.invoke(
107-
{"messages": [{"role": "user", "content": full_message}]}, config=config
108-
)
104+
response = agent.invoke({"messages": [{"role": "user", "content": full_message}]}, config=config)
109105

110106
# Extract the last message from the agent
107+
agent_response = "I'm sorry, I couldn't process your request. Please try again."
108+
generated_image_data = None
109+
111110
if response and "messages" in response and len(response["messages"]) > 0:
112111
last_message = response["messages"][-1]
113112
# Handle both AIMessage objects and dictionaries
114113
if hasattr(last_message, "content"):
115-
return last_message.content
114+
agent_response = last_message.content
116115
elif isinstance(last_message, dict) and "content" in last_message:
117-
return last_message["content"]
118-
119-
return "I'm sorry, I couldn't process your request. Please try again."
116+
agent_response = last_message["content"]
117+
118+
# Check if any tools were used (image generation)
119+
if "intermediate_steps" in response and response["intermediate_steps"]:
120+
for step in response["intermediate_steps"]:
121+
if len(step) >= 2 and "generate_image" in str(step[0]):
122+
# Extract image data from the tool result
123+
tool_result = step[1]
124+
if "Image ID:" in tool_result:
125+
# Parse the image ID and title from the response
126+
image_id_match = re.search(r"Image ID: ([a-f0-9-]+)", tool_result)
127+
title_match = re.search(r"Title: (.+?)(?:\n|$)", tool_result)
128+
129+
if image_id_match:
130+
image_id = image_id_match.group(1)
131+
title = title_match.group(1) if title_match else "Generated Image"
132+
133+
# Get metadata from S3
134+
import boto3
135+
136+
s3_client = boto3.client(
137+
"s3",
138+
region_name=os.environ.get("AWS_REGION", "us-east-1"),
139+
aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"),
140+
aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"),
141+
)
142+
143+
bucket_name = os.environ.get("AWS_S3_BUCKET_NAME")
144+
if bucket_name:
145+
try:
146+
# Get metadata from S3
147+
metadata_response = s3_client.head_object(Bucket=bucket_name, Key=f"users/{user_id}/images/{image_id}")
148+
metadata = metadata_response.get("Metadata", {})
149+
150+
# Generate presigned URL
151+
presigned_url = s3_client.generate_presigned_url(
152+
"get_object",
153+
Params={
154+
"Bucket": bucket_name,
155+
"Key": f"users/{user_id}/images/{image_id}",
156+
},
157+
ExpiresIn=7200, # 2 hours
158+
)
159+
160+
generated_image_data = {
161+
"id": image_id,
162+
"url": presigned_url,
163+
"title": metadata.get("title", title),
164+
"description": f"AI-generated image: {metadata.get('generationPrompt', 'Based on your request')}",
165+
"timestamp": metadata.get("uploadedAt", datetime.now().isoformat()),
166+
"type": "generated",
167+
}
168+
except Exception as e:
169+
print(f"Error getting S3 metadata: {e}")
170+
# Fallback to basic data
171+
generated_image_data = {
172+
"id": image_id,
173+
"url": "", # Will be empty if we can't generate URL
174+
"title": title,
175+
"description": "AI-generated image",
176+
"timestamp": datetime.now().isoformat(),
177+
"type": "generated",
178+
}
179+
# Add error message to agent response
180+
agent_response += "\n\n⚠️ Note: I generated the image successfully,\
181+
but there was an issue retrieving it from the database."
182+
183+
return agent_response, generated_image_data
120184

121185

122186
if __name__ == "__main__":

api/llm/prompt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,9 @@
44
You are a helpful AI image editing assistant. You help users with image editing
55
tasks and provide guidance on how to modify their images.
66
7+
You can generate images using the generate_image tool. However, remember that
8+
you are only allowed to generate one image per user's request. You are NOT allowed
9+
to generate more than one image per user's request, no matter how many images the user
10+
wants to generate per request (e.g. generate 10 images for me based on this one image).
11+
712
"""

api/llm/tools.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def generate_image(
2121
prompt: str,
2222
user_id: str,
2323
image_url: str,
24+
title: str = "Generated Image",
2425
) -> str:
2526
"""
2627
Generate an image based on a prompt.
@@ -37,9 +38,7 @@ def generate_image(
3738
}
3839

3940
# Generate image using Replicate
40-
version = (
41-
"stability-ai/sdxl:" "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc"
42-
)
41+
version = "stability-ai/sdxl:" "7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc"
4342
output = replicate.run(
4443
version,
4544
input=input,
@@ -68,16 +67,18 @@ def generate_image(
6867
# Upload to S3
6968
try:
7069
s3_result = upload_generated_image_to_s3(
71-
image_data=image_data, image_id=image_id, user_id=user_id, prompt=prompt
70+
image_data=image_data,
71+
image_id=image_id,
72+
user_id=user_id,
73+
prompt=prompt,
74+
title=title,
7275
)
7376

7477
if s3_result["success"]:
7578
return f"Image generated successfully! User can find it his/her gallery. \
76-
Image ID: {image_id}"
79+
Image ID: {image_id}, Title: {title}"
7780
else:
78-
return (
79-
f"Image generated but failed to save: {s3_result.get('error', 'Unknown error')}"
80-
)
81+
return f"Image generated but failed to save: {s3_result.get('error', 'Unknown error')}"
8182

8283
except Exception as e:
8384
return f"Image generated but failed to save to storage: {str(e)}"

api/llm/utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from botocore.exceptions import ClientError
77

88

9-
def upload_generated_image_to_s3(
10-
image_data: bytes, image_id: str, user_id: str, prompt: str
11-
) -> Dict[str, Any]:
9+
def upload_generated_image_to_s3(image_data: bytes, image_id: str, user_id: str, prompt: str, title: str = "Generated Image") -> Dict[str, Any]:
1210
"""
1311
Upload a generated image to S3.
1412
@@ -17,9 +15,10 @@ def upload_generated_image_to_s3(
1715
image_id: Unique identifier for the image
1816
user_id: User identifier
1917
prompt: The prompt used to generate the image
18+
title: Custom title for the image
2019
2120
Returns:
22-
Dict with success status and URL or error message
21+
Dict with success status, URL, and metadata or error message
2322
"""
2423
try:
2524
# Initialize S3 client
@@ -44,7 +43,7 @@ def upload_generated_image_to_s3(
4443
Body=image_data,
4544
ContentType="image/png",
4645
Metadata={
47-
"title": "Generated Image", # TODO: add title to the image provided by agent
46+
"title": title,
4847
"imageId": image_id,
4948
"userId": user_id,
5049
"uploadedAt": datetime.now().isoformat(),
@@ -60,7 +59,11 @@ def upload_generated_image_to_s3(
6059
ExpiresIn=7200, # 2 hours
6160
)
6261

63-
return {"success": True, "url": presigned_url, "image_id": image_id}
62+
# Get metadata from S3
63+
metadata_response = s3_client.head_object(Bucket=bucket_name, Key=key)
64+
metadata = metadata_response.get("Metadata", {})
65+
66+
return {"success": True, "url": presigned_url, "image_id": image_id, "metadata": metadata}
6467

6568
except ClientError as e:
6669
return {"success": False, "error": str(e)}

api/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ where = ["."]
4040
include = ["llm*", "server*"]
4141

4242
[tool.black]
43-
line-length = 100
43+
line-length = 150
4444
target-version = ["py310"]
4545

4646
[tool.ruff]
47-
line-length = 100
47+
line-length = 150
4848
target-version = "py310"
4949
fix = true
5050
unsafe-fixes = true

api/server/main.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,19 @@ class ChatRequest(BaseModel):
1818
user_id: Optional[str] = None
1919

2020

21+
class GeneratedImage(BaseModel):
22+
id: str
23+
url: str
24+
title: str
25+
description: str
26+
timestamp: str
27+
type: str = "generated"
28+
29+
2130
class ChatResponse(BaseModel):
2231
response: str
2332
status: str = "success"
33+
generated_image: Optional[GeneratedImage] = None
2434

2535

2636
@app.get("/")
@@ -42,18 +52,24 @@ async def chat_endpoint(request: ChatRequest):
4252
request: ChatRequest containing message, selected_images, and user_id
4353
4454
Returns:
45-
ChatResponse with AI response and status.
55+
ChatResponse with AI response, status, and optional generated image metadata.
4656
"""
4757
try:
4858
# Use the LLM agent to get a response
4959
user_id = request.user_id or "default"
50-
response = chat_with_agent(
60+
response, generated_image_data = chat_with_agent(
5161
message=request.message,
5262
user_id=user_id,
5363
selected_images=request.selected_images,
5464
)
5565

56-
return ChatResponse(response=response, status="success")
66+
# Create response with optional generated image
67+
chat_response = ChatResponse(response=response, status="success")
68+
69+
if generated_image_data:
70+
chat_response.generated_image = GeneratedImage(**generated_image_data)
71+
72+
return chat_response
5773

5874
except Exception as e:
5975
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")

src/app/page.tsx

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,24 @@ export default function Home() {
122122

123123
const aiMessage = createMessage(apiResponse.response, "agent");
124124
addMessage(aiMessage);
125+
126+
// Handle generated image if present
127+
if (apiResponse.generated_image) {
128+
const generatedImage: ImageItem = {
129+
id: apiResponse.generated_image.id,
130+
url: apiResponse.generated_image.url,
131+
title: apiResponse.generated_image.title,
132+
description: apiResponse.generated_image.description,
133+
timestamp: new Date(apiResponse.generated_image.timestamp),
134+
type: "generated" as const,
135+
};
136+
137+
// Add to images list
138+
setImages((prev) => [...prev, generatedImage]);
139+
140+
// Scroll to show the new image
141+
setTimeout(scrollToRight, 100);
142+
}
125143
} catch (error) {
126144
console.error("Error getting AI response:", error);
127145
const fallbackMessage = createMessage(

src/lib/actions.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import {
66
GetObjectCommand,
77
} from "@aws-sdk/client-s3";
88
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";
9+
import type { GeneratedImage } from "./types";
910

1011
interface ChatRequest {
1112
message: string;
@@ -23,6 +24,7 @@ interface ChatRequest {
2324
interface ChatResponse {
2425
response: string;
2526
status: string;
27+
generated_image?: GeneratedImage;
2628
}
2729

2830
interface UploadResponse {

src/lib/types.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,12 @@ export interface ImageItem {
1313
timestamp: Date;
1414
type: "uploaded" | "generated" | "sample";
1515
}
16+
17+
export interface GeneratedImage {
18+
id: string;
19+
url: string;
20+
title: string;
21+
description: string;
22+
timestamp: string;
23+
type: string;
24+
}

0 commit comments

Comments
 (0)