|
23 | 23 | ImageDeleteRequest, |
24 | 24 | ImageDeleteResponse, |
25 | 25 | ImageAnalyzeRequest, |
| 26 | + ImageAnalyzeCustomRequest, |
26 | 27 | ImageAnalyzeResponse, |
27 | 28 | ImagePromptEnhancementRequest, |
28 | 29 | ImagePromptEnhancementResponse, |
|
34 | 35 | ImageSaveResponse, |
35 | 36 | TokenUsage, |
36 | 37 | InputTokensDetails, |
| 38 | + ImageGenerateWithAnalysisRequest, |
37 | 39 | ) |
38 | 40 | from backend.models.gallery import MediaType |
39 | 41 | from backend.core import llm_client, dalle_client, image_sas_token |
@@ -879,13 +881,75 @@ async def save_generated_images( |
879 | 881 | analysis_results=analysis_results if analyzed else None, |
880 | 882 | analyzed=analyzed, |
881 | 883 | ) |
882 | | - |
883 | 884 | except Exception as e: |
884 | 885 | logger.error(f"Error saving generated images: {str(e)}", exc_info=True) |
885 | 886 | raise HTTPException( |
886 | 887 | status_code=500, detail=f"Error saving generated images: {str(e)}" |
887 | 888 | ) |
888 | 889 |
|
| 890 | +@router.post("/generate-with-analysis", response_model=ImageSaveResponse) |
| 891 | +async def generate_image_with_analysis( |
| 892 | + req: ImageGenerateWithAnalysisRequest, |
| 893 | + azure_storage_service: AzureBlobStorageService = Depends( |
| 894 | + lambda: AzureBlobStorageService() |
| 895 | + ), |
| 896 | + cosmos_service: Optional[CosmosDBService] = Depends(get_cosmos_service), |
| 897 | +): |
| 898 | + """ |
| 899 | + Generate image(s), then save to storage and optionally analyze in one call. |
| 900 | + Reuses existing generation and save logic to avoid duplication. |
| 901 | + """ |
| 902 | + try: |
| 903 | + # Build generation parameters (same as /images/generate) |
| 904 | + params = { |
| 905 | + "prompt": req.prompt, |
| 906 | + "model": req.model, |
| 907 | + "n": req.n, |
| 908 | + "size": req.size, |
| 909 | + } |
| 910 | + |
| 911 | + if req.model == "gpt-image-1": |
| 912 | + if req.quality: |
| 913 | + params["quality"] = req.quality |
| 914 | + params["background"] = req.background |
| 915 | + if req.output_format and req.output_format != "png": |
| 916 | + params["output_format"] = req.output_format |
| 917 | + if req.output_format in ["webp", "jpeg"] and req.output_compression != 100: |
| 918 | + params["output_compression"] = req.output_compression |
| 919 | + if req.moderation and req.moderation != "auto": |
| 920 | + params["moderation"] = req.moderation |
| 921 | + if req.user: |
| 922 | + params["user"] = req.user |
| 923 | + |
| 924 | + # Generate images via model client |
| 925 | + response = dalle_client.generate_image(**params) |
| 926 | + |
| 927 | + # Construct generation response to feed into existing /save logic |
| 928 | + gen_response = ImageGenerationResponse( |
| 929 | + success=True, |
| 930 | + message="Image(s) generated successfully", |
| 931 | + imgen_model_response=response, |
| 932 | + token_usage=None, |
| 933 | + ) |
| 934 | + |
| 935 | + save_request = ImageSaveRequest( |
| 936 | + generation_response=gen_response, |
| 937 | + prompt=req.prompt, |
| 938 | + model=req.model, |
| 939 | + size=req.size, |
| 940 | + background=req.background, |
| 941 | + output_format=req.output_format, |
| 942 | + save_all=req.save_all, |
| 943 | + folder_path=req.folder_path, |
| 944 | + analyze=req.analyze, |
| 945 | + ) |
| 946 | + |
| 947 | + # Call existing save endpoint function directly with explicit deps |
| 948 | + return await save_generated_images(save_request, azure_storage_service, cosmos_service) |
| 949 | + except Exception as e: |
| 950 | + logger.error(f"Error in /images/generate-with-analysis: {str(e)}", exc_info=True) |
| 951 | + raise HTTPException(status_code=500, detail=str(e)) |
| 952 | + |
889 | 953 |
|
890 | 954 | @router.post("/list", response_model=ImageListResponse) |
891 | 955 | async def list_images(request: ImageListRequest): |
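For reference, a minimal client-side sketch of calling the new combined endpoint. The /images router prefix, base URL, and concrete field values below are assumptions; the field names mirror how ImageGenerateWithAnalysisRequest is read in the handler above.

import requests

# Hypothetical request body; folder path and parameter values are illustrative only.
payload = {
    "prompt": "A studio photo of a ceramic mug on a light wooden table",
    "model": "gpt-image-1",
    "n": 1,
    "size": "1024x1024",
    "background": "auto",
    "output_format": "png",
    "save_all": True,
    "folder_path": "demo",  # hypothetical target folder
    "analyze": True,  # also run the analysis step after saving
}

# Assumed local deployment; adjust host, port, and auth for the real service.
resp = requests.post(
    "http://localhost:8000/images/generate-with-analysis",
    json=payload,
    timeout=120,
)
resp.raise_for_status()
print(resp.json())  # ImageSaveResponse; includes analysis_results when analyze=True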
@@ -1054,6 +1118,164 @@ def analyze_image(req: ImageAnalyzeRequest): |
1054 | 1118 | status_code=500, detail=f"Error analyzing image: {str(e)}") |
1055 | 1119 |
|
1056 | 1120 |
|
| 1121 | +@router.post("/analyze-custom", response_model=ImageAnalyzeResponse) |
| 1122 | +def analyze_image_custom(req: ImageAnalyzeCustomRequest): |
| 1123 | + """ |
| 1124 | + Analyze an image using a custom prompt while maintaining the same response structure. |
| 1125 | + |
| 1126 | + Args: |
| 1127 | + image_path: Full Azure Blob Storage URL to the image, with or without a SAS token. |
| 1128 | + OR |
| 1129 | + base64_image: Base64-encoded image data to analyze directly. |
| 1130 | + custom_prompt: Custom instructions for the analysis. |
| 1131 | + |
| 1132 | + Returns: |
| 1133 | + Response containing description, products, tags, and feedback based on the custom prompt. |
| 1134 | + """ |
| 1135 | + try: |
| 1136 | + # Initialize image_content |
| 1137 | + image_content = None |
| 1138 | + |
| 1139 | + # Option 1: Process from URL/path |
| 1140 | + if req.image_path: |
| 1141 | + file_path = req.image_path |
| 1142 | + |
| 1143 | + # check if the path is a valid Azure blob storage path |
| 1144 | + pattern = r"^https://[a-z0-9]+\.blob\.core\.windows\.net/[a-z0-9-]+/.+" |
| 1145 | + match = re.match(pattern, file_path) |
| 1146 | + |
| 1147 | + if not match: |
| 1148 | + raise ValueError("Invalid Azure blob storage path") |
| 1149 | + else: |
| 1150 | + # check if the path contains a SAS token |
| 1151 | + if "?" not in file_path: |
| 1152 | + file_path += f"?{image_sas_token}" |
| 1153 | + |
| 1154 | + # Download the image from the URL |
| 1155 | + response = requests.get(file_path, timeout=30) |
| 1156 | + if response.status_code != 200: |
| 1157 | + raise HTTPException( |
| 1158 | + status_code=response.status_code, |
| 1159 | + detail=f"Failed to download image: HTTP {response.status_code}", |
| 1160 | + ) |
| 1161 | + |
| 1162 | + # Get image content from response |
| 1163 | + image_content = response.content |
| 1164 | + |
| 1165 | + # Option 2: Process from base64 string |
| 1166 | + elif req.base64_image: |
| 1167 | + try: |
| 1168 | + # Strip any data URL prefix, then decode base64 to binary |
| 1169 | + image_content = base64.b64decode(re.sub(r"^data:image/.+;base64,", "", req.base64_image)) |
| 1170 | + except Exception as e: |
| 1171 | + raise HTTPException( |
| 1172 | + status_code=400, detail=f"Invalid base64 image data: {str(e)}" |
| 1173 | + ) |
| 1174 | + |
| 1175 | + # Process the image with PIL to handle transparency properly (same as regular analyze) |
| 1176 | + try: |
| 1177 | + # Open the image with PIL |
| 1178 | + with Image.open(io.BytesIO(image_content)) as img: |
| 1179 | + # Check if it's a transparent PNG |
| 1180 | + has_transparency = img.mode == "RGBA" and "A" in img.getbands() |
| 1181 | + |
| 1182 | + if has_transparency: |
| 1183 | + # Create a white background |
| 1184 | + background = Image.new("RGBA", img.size, (255, 255, 255, 255)) |
| 1185 | + # Paste the image on the background |
| 1186 | + background.paste(img, (0, 0), img) |
| 1187 | + # Convert to RGB (remove alpha channel) |
| 1188 | + background = background.convert("RGB") |
| 1189 | + |
| 1190 | + # Save to bytes |
| 1191 | + img_byte_arr = io.BytesIO() |
| 1192 | + background.save(img_byte_arr, format="JPEG") |
| 1193 | + img_byte_arr.seek(0) |
| 1194 | + image_content = img_byte_arr.getvalue() |
| 1195 | + |
| 1196 | + # Also try to resize if the image is very large (LLMs have token limits) |
| 1197 | + width, height = img.size |
| 1198 | + if width > 1500 or height > 1500: |
| 1199 | + # Calculate new dimensions |
| 1200 | + max_dimension = 1500 |
| 1201 | + if width > height: |
| 1202 | + new_width = max_dimension |
| 1203 | + new_height = int(height * (max_dimension / width)) |
| 1204 | + else: |
| 1205 | + new_height = max_dimension |
| 1206 | + new_width = int(width * (max_dimension / height)) |
| 1207 | + |
| 1208 | + # Resize the image |
| 1209 | + if has_transparency: |
| 1210 | + # We already have the background image from above |
| 1211 | + resized_img = background.resize((new_width, new_height)) |
| 1212 | + else: |
| 1213 | + resized_img = img.resize((new_width, new_height)) |
| 1214 | + |
| 1215 | + # Save to bytes |
| 1216 | + img_byte_arr = io.BytesIO() |
| 1217 | + resized_img.save( |
| 1218 | + img_byte_arr, |
| 1219 | + format="JPEG" if resized_img.mode == "RGB" else "PNG", |
| 1220 | + ) |
| 1221 | + img_byte_arr.seek(0) |
| 1222 | + image_content = img_byte_arr.getvalue() |
| 1223 | + except Exception as img_error: |
| 1224 | + logger.error(f"Error processing image with PIL: {str(img_error)}") |
| 1225 | + # If PIL processing fails, continue with the original image |
| 1226 | + |
| 1227 | + # Convert to base64 |
| 1228 | + image_base64 = base64.b64encode(image_content).decode("utf-8") |
| 1229 | + # Remove data URL prefix if present |
| 1230 | + image_base64 = re.sub(r"^data:image/.+;base64,", "", image_base64) |
| 1231 | + |
| 1232 | + # Create custom system message using the provided custom prompt |
| 1233 | + custom_prompt = req.custom_prompt |
| 1234 | + if not custom_prompt or not custom_prompt.strip(): |
| 1235 | + raise HTTPException( |
| 1236 | + status_code=400, detail="Custom prompt is required for custom analysis" |
| 1237 | + ) |
| 1238 | + |
| 1239 | + custom_system_message = f"""You are an expert in analyzing images. |
| 1240 | +You are provided with a single image to analyze in detail. |
| 1241 | + |
| 1242 | +CUSTOM ANALYSIS INSTRUCTIONS: |
| 1243 | +{custom_prompt} |
| 1244 | + |
| 1245 | +Your task is to extract the following based on the custom instructions above: |
| 1246 | +1. detailed description based on the custom requirements above |
| 1247 | +2. named brands or named products visible in the image |
| 1248 | +3. metadata tags useful for organizing and searching content. Limit to the 5 most relevant tags. |
| 1249 | +4. feedback to improve the image based on the custom criteria above |
| 1250 | + |
| 1251 | +Return the result as a valid JSON object: |
| 1252 | +{{ |
| 1253 | + "description": "<Custom analysis based on provided instructions>", |
| 1254 | + "products": "<named brands / named products identified>", |
| 1255 | + "tags": ["<tag1>", "<tag2>", "<tag3>", "<tag4>", "<tag5>"], |
| 1256 | + "feedback": "<Feedback based on custom criteria>" |
| 1257 | +}} |
| 1258 | +""" |
| 1259 | + |
| 1260 | + # Analyze the image using the LLM with the custom prompt |
| 1261 | + image_analyzer = ImageAnalyzer(llm_client, settings.LLM_DEPLOYMENT) |
| 1262 | + insights = image_analyzer.image_chat(image_base64, custom_system_message) |
| 1263 | + |
| 1264 | + description = insights.get("description") |
| 1265 | + products = insights.get("products") |
| 1266 | + tags = insights.get("tags") |
| 1267 | + feedback = insights.get("feedback") |
| 1268 | + |
| 1269 | + return ImageAnalyzeResponse( |
| 1270 | + description=description, products=products, tags=tags, feedback=feedback |
| 1271 | + ) |
| 1272 | + |
| 1273 | + except Exception as e: |
| 1274 | + logger.error(f"Error analyzing image with custom prompt: {str(e)}", exc_info=True) |
| 1275 | + raise HTTPException( |
| 1276 | + status_code=500, detail=f"Error analyzing image with custom prompt: {str(e)}") |
| 1277 | + |
| 1278 | + |
1057 | 1279 | @router.post("/prompt/enhance", response_model=ImagePromptEnhancementResponse) |
1058 | 1280 | def enhance_image_prompt(req: ImagePromptEnhancementRequest): |
1059 | 1281 | """ |
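And a similarly hedged sketch for the custom-analysis endpoint added above. The blob URL, prompt text, and /images prefix are placeholders/assumptions; the request fields mirror ImageAnalyzeCustomRequest as read in the handler.

import requests

# Hypothetical request: supply either image_path (full Azure Blob Storage URL, SAS token
# optional) or base64_image, plus the required custom_prompt.
payload = {
    "image_path": "https://mystorageacct.blob.core.windows.net/images/sample.png",
    "custom_prompt": "Describe the lighting and composition, and note any visible brand logos.",
}

resp = requests.post(
    "http://localhost:8000/images/analyze-custom",
    json=payload,
    timeout=60,
)
resp.raise_for_status()
data = resp.json()  # keys: description, products, tags, feedback (ImageAnalyzeResponse)
print(data["tags"])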
|