Skip to content

Commit 97aabd8

Browse files
committed
Merge remote-tracking branch 'jaychao/pr-jay-24' into merge_pr
2 parents 54cadeb + 98ed87a commit 97aabd8

File tree

18 files changed

+8385
-4622
lines changed

18 files changed

+8385
-4622
lines changed

.github/workflows/ci-quick.yml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ env:
1818
NODE_VERSION: '20.x'
1919
PYTHON_VERSION: '3.10'
2020

21+
permissions:
22+
contents: read
23+
pull-requests: write
24+
issues: write
25+
2126
jobs:
2227
# ============================================
2328
# Quick validation (lint, type check, fast tests)
@@ -106,8 +111,6 @@ jobs:
106111
"RESILIENTDB_GRAPHQL_URI=https://cloud.resilientdb.com/graphql" \
107112
"JWT_SECRET=test-secret-key-do-not-use-in-production" > .env
108113
109-
- name: Run backend unit tests only (fast)
110-
# wait for mongodb to be ready on localhost:27017
111114
- name: Wait for MongoDB
112115
run: |
113116
for i in {1..30}; do
@@ -136,12 +139,14 @@ jobs:
136139
--ci
137140
138141
- name: Build check (frontend)
142+
env:
143+
CI: false
139144
run: |
140145
cd frontend
141146
npm run build
142147
143148
- name: Comment PR with quick results
144-
if: github.event_name == 'pull_request' && always()
149+
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.fork == false && always()
145150
uses: actions/github-script@v7
146151
with:
147152
script: |

README.md

Lines changed: 118 additions & 1 deletion
Large diffs are not rendered by default.

backend/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def filter(self, record):
5656
from routes.frontend import frontend_bp
5757
from routes.analytics import analytics_bp
5858
from routes.export import export_bp
59+
from routes.ai_assistant import ai_assistant_bp
5960
from services.db import redis_client
6061
from services.canvas_counter import get_canvas_draw_count
6162
from services.graphql_service import commit_transaction_via_graphql
@@ -215,6 +216,7 @@ def handle_all_exceptions(e):
215216
app.register_blueprint(submit_room_line_bp)
216217
app.register_blueprint(admin_bp)
217218
app.register_blueprint(export_bp)
219+
app.register_blueprint(ai_assistant_bp)
218220

219221
# Register versioned API v1 blueprints for external applications
220222
from api_v1.auth import auth_v1_bp

backend/routes/ai_assistant.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
from flask import Blueprint, request, jsonify
2+
# Import style transfer function as well
3+
from services.llm_service import (
4+
prompt_to_drawings,
5+
complete_shape_from_canvas,
6+
beautify_canvas_state,
7+
style_transfer_canvas,
8+
)
9+
from services.llm_service import recognize_objects_in_box
10+
# from services.image_generation_service import (
11+
# text_to_image as img_text_to_image,
12+
# )
13+
import logging
14+
import base64
15+
import io
16+
17+
# Blueprint that the @ai_assistant_bp.route(...) handlers in this module attach to.
ai_assistant_bp = Blueprint('ai_assistant', __name__)
18+
# Module-level logger, named after this module, used by every handler below.
logger = logging.getLogger(__name__)
19+
20+
21+
@ai_assistant_bp.route('/api/ai_assistant/drawing', methods=['POST'])
def text_to_drawings():
    """
    Turn a natural-language prompt into drawing instructions.

    Body: { "prompt": "<natural language description>", "canvasState": {json object} }
    Returns:
        200 with whatever ``prompt_to_drawings`` produced, serialized as JSON.
        400 when 'prompt' is missing, blank, or not a string.
        502 when the service returns a dict containing an "error" key.
        500 when anything else raises.
    """
    try:
        payload = request.get_json(silent=True) or {}
        prompt = payload.get("prompt")
        # A missing or falsy canvas state is passed to the service as an empty dict.
        canvas_state = payload.get("canvasState") or {}

        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({"error": "bad_request", "detail": "Missing or invalid 'prompt' (string)."}), 400

        logger.info("AI drawing requested")
        result = prompt_to_drawings(prompt.strip(), canvas_state)

        # Route the raw model output through the logger (debug level) instead
        # of printing it to stdout on every request.
        logger.debug("AI drawing model result: %r", result)

        # If services returned an error, surface it with 502 (bad upstream)
        if isinstance(result, dict) and "error" in result:
            logger.warning("AI drawing failed: %s", result)
            return jsonify({"error": "upstream_model_error", "detail": result}), 502

        return jsonify(result), 200
    except Exception as e:
        logger.exception("Unhandled error in /drawing")
        return jsonify({"error": "server_error", "detail": str(e)}), 500
49+
50+
51+
@ai_assistant_bp.route('/api/ai_assistant/complete', methods=['POST'])
def shape_completion():
    """
    Ask the model to complete the shape described by the current canvas.

    Body: { "canvasState": { ... } }
    Returns:
        200 with the suggestion from ``complete_shape_from_canvas``
        ({ complete, confidence, object{ color, lineWidth, pathData{...} } }).
        400 when 'canvasState' is missing or not a JSON object.
        502 when the service returns a dict containing an "error" key
        (same convention as the /drawing and /recognize handlers).
        500 when anything else raises.
    """
    try:
        payload = request.get_json(silent=True) or {}
        canvas_state = payload.get("canvasState")

        # Validate once, before calling the model; the message names the
        # actual request key so clients can see what to send.
        if not isinstance(canvas_state, dict):
            return jsonify({
                "error": "bad_request",
                "detail": "Missing or invalid 'canvasState' (object)."
            }), 400

        logger.info("AI shape completion requested")
        suggestion = complete_shape_from_canvas(canvas_state)

        # Surface service-reported failures as a bad-upstream response rather
        # than returning the error payload with a 200.
        if isinstance(suggestion, dict) and "error" in suggestion:
            logger.warning("AI shape completion failed: %s", suggestion)
            return jsonify({"error": "upstream_model_error", "detail": suggestion}), 502

        return jsonify(suggestion), 200
    except Exception as e:
        logger.exception("Unhandled error in /complete")
        return jsonify({"error": "server_error", "detail": str(e)}), 500
76+
77+
78+
@ai_assistant_bp.route('/api/ai_assistant/image', methods=['POST'])
def text_to_image():
    """
    Generate an image for a text prompt and return it as a PNG data URL.

    Body: { "prompt": "<string>", "width"?: int, "height"?: int, "style"?: str }
    Returns:
        200 with { "imageDataUrl": "data:image/png;base64,..." }.
        400 when 'prompt' is missing, blank, or not a string.
        502 when importing or calling the image generation service raises.
        500 when anything else raises.
    """
    try:
        body = request.get_json(silent=True) or {}
        prompt = body.get("prompt", "")

        # Missing or falsy options fall back to these defaults.
        options = {
            "width": body.get("width") or 512,
            "height": body.get("height") or 512,
            "style": body.get("style") or "default",
        }

        if not (isinstance(prompt, str) and prompt.strip()):
            bad_request = {
                "error": "bad_request",
                "detail": "Missing or invalid 'prompt' (string)."
            }
            return jsonify(bad_request), 400

        logger.info("AI text-to-image requested")

        # The service is imported lazily so a failing import is reported as a
        # 502 for this request only.
        try:
            from services.image_generation_service import text_to_image as img_text_to_image
            pil_image = img_text_to_image(prompt.strip(), **options)
        except Exception as e:
            logger.exception("Image generation failed: %s", e)
            return jsonify({"error": "image_generation_failed", "detail": str(e)}), 502

        # Serialize the image to PNG in memory and wrap it in a data URL.
        png_buffer = io.BytesIO()
        pil_image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

        return jsonify({"imageDataUrl": f"data:image/png;base64,{encoded}"}), 200

    except Exception as e:
        logger.exception("Unhandled error in /image")
        return jsonify({"error": "server_error", "detail": str(e)}), 500
119+
120+
121+
@ai_assistant_bp.route("/api/ai_assistant/beautify", methods=["POST"])
def beautify_sketch():
    """
    Run the canvas through ``beautify_canvas_state`` and return its output.

    Body: { "canvasState": { ... } }
    Returns:
        200 with the service result (a dict that contains an "objects" key).
        400 when 'canvasState' is missing or not a JSON object.
        502 when the service result is not a dict or lacks "objects".
        500 when anything else raises.
    """
    try:
        body = request.get_json(silent=True) or {}
        canvas_state = body.get("canvasState")

        if not isinstance(canvas_state, dict):
            bad_request = {
                "error": "bad_request",
                "detail": "Missing or invalid 'canvasState' (object)."
            }
            return jsonify(bad_request), 400

        beautified = beautify_canvas_state(canvas_state)

        # Only a dict carrying "objects" counts as a usable model response.
        usable = isinstance(beautified, dict) and "objects" in beautified
        if usable:
            return jsonify(beautified), 200

        logger.warning("Beautify returned invalid payload: %r", beautified)
        upstream_error = {
            "error": "upstream_model_error",
            "detail": "Beautify model returned invalid payload."
        }
        return jsonify(upstream_error), 502

    except Exception as e:
        logger.exception("Unhandled error in /beautify")
        return jsonify({"error": "server_error", "detail": str(e)}), 500
149+
150+
151+
@ai_assistant_bp.route('/api/ai_assistant/style', methods=['POST'])
def style_transfer():
    """
    Restyle the canvas according to a free-text style description.

    Body: { "canvasState": {...}, "stylePrompt": "<string describing style e.g. 'Van Gogh oil painting'>" }
    Returns:
        200 with the result of ``style_transfer_canvas``; if that result is a
        dict containing "error", 200 with { "objects": [...] } holding the
        caller's original objects instead.
        400 when 'canvasState' is not an object or 'stylePrompt' is missing,
        blank, or not a string.
        500 when anything else raises.
    """
    try:
        body = request.get_json(silent=True) or {}
        canvas_state = body.get('canvasState')
        style_prompt = body.get('stylePrompt')

        if not isinstance(canvas_state, dict):
            return jsonify({"error": "bad_request", "detail": "Missing or invalid 'canvasState' (object)."}), 400

        prompt_ok = isinstance(style_prompt, str) and bool(style_prompt.strip())
        if not prompt_ok:
            return jsonify({"error": "bad_request", "detail": "Missing or invalid 'stylePrompt' (string)."}), 400

        logger.info('AI style transfer requested')
        styled = style_transfer_canvas(canvas_state, style_prompt.strip())

        model_failed = isinstance(styled, dict) and "error" in styled
        if not model_failed:
            # Normal successful response
            return jsonify(styled), 200

        # The model reported an error: log it and hand back the untouched
        # canvas objects so the client can keep working.
        logger.warning('Style transfer model error, falling back to original canvas: %s', styled)
        fallback_objects = canvas_state.get("objects") or canvas_state.get("drawings") or []
        return jsonify({"objects": fallback_objects}), 200
    except Exception as e:
        logger.exception('Unhandled error in /style')
        return jsonify({"error": "server_error", "detail": str(e)}), 500
182+
183+
184+
@ai_assistant_bp.route('/api/ai_assistant/recognize', methods=['POST'])
def recognize():
    """
    Identify what is drawn inside a selection box on the canvas.

    Body: { "canvasObjects": [...], "box": { x,y,width,height }, "bounds": { width, height } }
          ("objects" is accepted as an alternative key for "canvasObjects".)
    Returns:
        200 with the result of ``recognize_objects_in_box``
        ({ label, confidence, explanation }).
        400 when the canvas objects value is not a list.
        502 when the service returns a dict containing an "error" key.
        500 when anything else raises.
    """
    try:
        body = request.get_json(silent=True) or {}

        # Falsy or missing values collapse to empty defaults.
        canvas_objects = body.get('canvasObjects') or body.get('objects') or []
        box = body.get('box') or {}
        bounds = body.get('bounds') or {}

        if not isinstance(canvas_objects, list):
            return jsonify({"error": "bad_request", "detail": "'canvasObjects' must be a list."}), 400

        recognition = recognize_objects_in_box(canvas_objects, box, bounds)

        upstream_failed = isinstance(recognition, dict) and 'error' in recognition
        if upstream_failed:
            logger.warning('Recognition upstream error: %s', recognition)
            return jsonify({"error": "upstream_model_error", "detail": recognition}), 502

        return jsonify(recognition), 200
    except Exception as e:
        logger.exception('Unhandled error in /recognize')
        return jsonify({"error": "server_error", "detail": str(e)}), 500
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
2+
from config import OPENAI_API_KEY
3+
4+
5+
def _extract_b64_image(resp):
    """Return the first image's base64 payload from an Images API response, or None.

    Handles both attribute-style responses (``resp.data[0].b64_json``) and
    plain-dict responses (``resp["data"][0]["b64_json"]``).
    """
    data = resp.get("data") if isinstance(resp, dict) else getattr(resp, "data", None)
    if not isinstance(data, (list, tuple)) or not data:
        return None
    first = data[0]
    if isinstance(first, dict):
        return first.get("b64_json") or None
    return getattr(first, "b64_json", None) or None


def _measure_multiline(draw, text, font):
    """Return (width, height) of multi-line *text* across Pillow versions."""
    try:
        left, top, right, bottom = draw.multiline_textbbox((0, 0), text, font=font)
        return right - left, bottom - top
    except (AttributeError, ValueError):
        # Older Pillow: the bbox API is missing or rejects this font, while
        # the legacy size API (removed in Pillow 10) is still present.
        return draw.multiline_textsize(text, font=font)


def text_to_image(prompt: str, width: int = 512, height: int = 512, style: str = "default"):
    """
    Generate an image for the given prompt and return it as a PIL.Image.

    The OpenAI Images API is tried first. If that attempt fails for any reason
    (package missing, request error, no base64 data in the response), a blank
    white placeholder of the requested size, with a small "AI Image" label
    when it can be drawn, is returned so the endpoint stays usable for UI
    development.

    Args:
        prompt: Text description passed to the image model.
        width: Requested image width in pixels.
        height: Requested image height in pixels.
        style: Accepted for interface compatibility; currently unused.

    Returns:
        A PIL.Image instance (generated image or placeholder).

    Raises:
        RuntimeError: if the placeholder cannot be built either (for example
            Pillow is not installed or the dimensions are invalid).
    """
    import logging

    log = logging.getLogger(__name__)

    # Try OpenAI Images (if package available)
    try:
        import base64
        import io

        from openai import OpenAI
        from PIL import Image

        client = OpenAI(api_key=OPENAI_API_KEY)
        # NOTE(review): the Images API may accept only certain sizes per
        # model - confirm that arbitrary "{width}x{height}" values are valid.
        resp = client.images.generate(
            model="gpt-image-1",
            prompt=prompt,
            size=f"{width}x{height}"
        )
        b64 = _extract_b64_image(resp)
        if b64:
            return Image.open(io.BytesIO(base64.b64decode(b64)))
        log.warning("Image API response contained no base64 image data; using placeholder")
    except Exception:
        # Best effort: any failure here falls through to the placeholder, but
        # is logged instead of being silently discarded.
        log.warning("OpenAI image generation unavailable; using placeholder", exc_info=True)

    # Placeholder fallback: create a simple blank image (Pillow may be missing)
    try:
        from PIL import Image, ImageDraw, ImageFont

        img = Image.new("RGBA", (width, height), (255, 255, 255, 255))
        draw = ImageDraw.Draw(img)
        # Draw a small placeholder label in the center if fonts available
        try:
            font = ImageFont.load_default()
            text = "AI\nImage"
            w, h = _measure_multiline(draw, text, font)
            draw.multiline_text(
                ((width - w) / 2, (height - h) / 2),
                text,
                fill=(120, 120, 120),
                font=font,
                align="center",
            )
        except Exception:
            # The label is cosmetic; a blank placeholder is still acceptable.
            log.debug("Could not draw placeholder label", exc_info=True)
        return img
    except Exception as e:
        raise RuntimeError("No image generation backend available: " + str(e)) from e
61+

0 commit comments

Comments
 (0)