Skip to content

Commit 057b8d4

Browse files
committed
Update backend
1 parent e219534 commit 057b8d4

File tree

2 files changed

+81
-1
lines changed

2 files changed

+81
-1
lines changed

backend/main.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
DirectoryModel, DirectoryDetailResponse, RemoveDirectoryResponse, RemoveDirectoryRequest, CreateQueryRequest, \
2121
CreateQueryResponse, GeneratorInfo, SearchLogsResponse, QueryLogEntry, \
2222
ServiceStatusResponse, ServiceLogResponse, SearchResponse, SearchRequest, UpdateDirectoryResponse, \
23-
UpdateDirectoryRequest
23+
UpdateDirectoryRequest, GeneratePoolRequest, GeneratePoolResponse, GuideImageData, EmbeddingData
2424
from indexing import image_indexing_service
2525
from utils import aggregate_rankings, pil_image_to_base64, Timer
2626
from version import VERSION as BACKEND_VERSION
@@ -310,6 +310,61 @@ async def service_log():
310310
return ServiceLogResponse(log="Service log not implemented yet.")
311311

312312

313+
@app.post("/variance-analysis/generate-pool", response_model=GeneratePoolResponse)
314+
async def generate_pool(request: GeneratePoolRequest):
315+
"""
316+
Generate a pool of guide images for variance analysis.
317+
This endpoint generates M_pool guide images and computes embeddings for all embedders.
318+
"""
319+
# Prepare generation config - we need to generate pool_size images
320+
generation_request = request.generation_config.model_dump()
321+
for engine in generation_request["engines"]:
322+
engine["prompt"] = request.query
323+
324+
# Adjust to generate the requested pool size
325+
# We'll use multiple engines if needed to reach pool_size
326+
original_num_images = generation_request.get("num_images", 1)
327+
original_num_engines = generation_request.get("num_engines_to_use", 1)
328+
329+
# Calculate how many images per engine we need
330+
images_per_engine = max(1, request.pool_size // original_num_engines)
331+
generation_request["num_images"] = images_per_engine
332+
333+
# Generate images
334+
generated_images = image_generator.generate(generation_request)
335+
336+
# Limit to pool_size if we generated more
337+
generated_images = generated_images[:request.pool_size]
338+
339+
# Get all embedders
340+
embedders = embedder_manager.get_image_embedders()
341+
embedder_names = list(embedders.keys())
342+
343+
# Compute embeddings for all guide images using all embedders
344+
guide_images_data = []
345+
for idx, (image, engine_name) in enumerate(generated_images):
346+
embeddings_data = []
347+
for embedder_name, embedder in embedders.items():
348+
embedding = embedder.embed(image)
349+
embeddings_data.append(EmbeddingData(
350+
embedder_name=embedder_name,
351+
embedding=embedding.tolist() if hasattr(embedding, 'tolist') else embedding
352+
))
353+
354+
guide_images_data.append(GuideImageData(
355+
image_index=idx,
356+
base64_image=pil_image_to_base64(image),
357+
embeddings=embeddings_data
358+
))
359+
360+
return GeneratePoolResponse(
361+
query=request.query,
362+
pool_size=len(guide_images_data),
363+
guide_images=guide_images_data,
364+
embedder_names=embedder_names
365+
)
366+
367+
313368
from routes.gallery import router as gallery_router
314369

315370
app.include_router(gallery_router)

backend/models/schemas.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,28 @@ class ServiceStatusResponse(BaseModel):
124124

125125
class ServiceLogResponse(BaseModel):
126126
log: str
127+
128+
129+
# Variance Analysis Schemas
130+
class GeneratePoolRequest(BaseModel):
131+
query: str = Field(..., description="Query string")
132+
pool_size: int = Field(20, description="Number of guide images to generate (M_pool)")
133+
generation_config: GenerationConfig = Field(..., description="Configuration for image generation")
134+
135+
136+
class EmbeddingData(BaseModel):
137+
embedder_name: str
138+
embedding: List[float]
139+
140+
141+
class GuideImageData(BaseModel):
142+
image_index: int
143+
base64_image: str
144+
embeddings: List[EmbeddingData]
145+
146+
147+
class GeneratePoolResponse(BaseModel):
148+
query: str
149+
pool_size: int
150+
guide_images: List[GuideImageData]
151+
embedder_names: List[str]

0 commit comments

Comments
 (0)