|
20 | 20 | DirectoryModel, DirectoryDetailResponse, RemoveDirectoryResponse, RemoveDirectoryRequest, CreateQueryRequest, \ |
21 | 21 | CreateQueryResponse, GeneratorInfo, SearchLogsResponse, QueryLogEntry, \ |
22 | 22 | ServiceStatusResponse, ServiceLogResponse, SearchResponse, SearchRequest, UpdateDirectoryResponse, \ |
23 | | - UpdateDirectoryRequest |
| 23 | + UpdateDirectoryRequest, GeneratePoolRequest, GeneratePoolResponse, GuideImageData, EmbeddingData |
24 | 24 | from indexing import image_indexing_service |
25 | 25 | from utils import aggregate_rankings, pil_image_to_base64, Timer |
26 | 26 | from version import VERSION as BACKEND_VERSION |
@@ -310,6 +310,61 @@ async def service_log(): |
310 | 310 | return ServiceLogResponse(log="Service log not implemented yet.") |
311 | 311 |
|
312 | 312 |
|
| 313 | +@app.post("/variance-analysis/generate-pool", response_model=GeneratePoolResponse) |
| 314 | +async def generate_pool(request: GeneratePoolRequest): |
| 315 | + """ |
| 316 | + Generate a pool of guide images for variance analysis. |
| 317 | + This endpoint generates M_pool guide images and computes embeddings for all embedders. |
| 318 | + """ |
| 319 | + # Prepare generation config - we need to generate pool_size images |
| 320 | + generation_request = request.generation_config.model_dump() |
| 321 | + for engine in generation_request["engines"]: |
| 322 | + engine["prompt"] = request.query |
| 323 | + |
| 324 | + # Adjust to generate the requested pool size |
| 325 | + # We'll use multiple engines if needed to reach pool_size |
| 326 | + original_num_images = generation_request.get("num_images", 1) |
| 327 | + original_num_engines = generation_request.get("num_engines_to_use", 1) |
| 328 | + |
| 329 | + # Calculate how many images per engine we need |
| 330 | + images_per_engine = max(1, request.pool_size // original_num_engines) |
| 331 | + generation_request["num_images"] = images_per_engine |
| 332 | + |
| 333 | + # Generate images |
| 334 | + generated_images = image_generator.generate(generation_request) |
| 335 | + |
| 336 | + # Limit to pool_size if we generated more |
| 337 | + generated_images = generated_images[:request.pool_size] |
| 338 | + |
| 339 | + # Get all embedders |
| 340 | + embedders = embedder_manager.get_image_embedders() |
| 341 | + embedder_names = list(embedders.keys()) |
| 342 | + |
| 343 | + # Compute embeddings for all guide images using all embedders |
| 344 | + guide_images_data = [] |
| 345 | + for idx, (image, engine_name) in enumerate(generated_images): |
| 346 | + embeddings_data = [] |
| 347 | + for embedder_name, embedder in embedders.items(): |
| 348 | + embedding = embedder.embed(image) |
| 349 | + embeddings_data.append(EmbeddingData( |
| 350 | + embedder_name=embedder_name, |
| 351 | + embedding=embedding.tolist() if hasattr(embedding, 'tolist') else embedding |
| 352 | + )) |
| 353 | + |
| 354 | + guide_images_data.append(GuideImageData( |
| 355 | + image_index=idx, |
| 356 | + base64_image=pil_image_to_base64(image), |
| 357 | + embeddings=embeddings_data |
| 358 | + )) |
| 359 | + |
| 360 | + return GeneratePoolResponse( |
| 361 | + query=request.query, |
| 362 | + pool_size=len(guide_images_data), |
| 363 | + guide_images=guide_images_data, |
| 364 | + embedder_names=embedder_names |
| 365 | + ) |
| 366 | + |
| 367 | + |
313 | 368 | from routes.gallery import router as gallery_router |
314 | 369 |
|
315 | 370 | app.include_router(gallery_router) |
0 commit comments