Fix 904 celebrity tagging #990

base: main

Changes from all commits
New file: app/models/CelebrityMatcher.py (+97 lines)

```python
import pickle
import numpy as np
import os
from app.logging.setup_logging import get_logger


logger = get_logger(__name__)


class CelebrityMatcher:
    def __init__(self, encodings_path=None):
        """
        Initializes the CelebrityMatcher by loading encodings and optimizing them into a matrix.

        Args:
            encodings_path (str): Path to the pickle file containing celebrity encodings.
                If None, defaults to 'celebrity_encodings.pkl' in this directory.
        """
        if encodings_path is None:
            encodings_path = os.path.join(os.path.dirname(__file__), 'celebrity_encodings.pkl')

        self.names = []
        self.embeddings_matrix = None
        self._load_and_optimize(encodings_path)

    def _load_and_optimize(self, path):
        """
        Loads the pickle file and converts the dictionary of lists into a NumPy matrix.
        """
        # Resolve absolute path if relative path is provided
        if not os.path.isabs(path):
            if os.path.exists(path):
                pass
            else:
                # Check relative to this file location as fallback
                dir_of_file = os.path.dirname(__file__)
                potential_path = os.path.join(dir_of_file, os.path.basename(path))
                if os.path.exists(potential_path):
                    path = potential_path
```
|
Review comment (Contributor) on lines +29 to +37:

🛠️ Refactor suggestion | 🟠 Major

Simplify path resolution logic. The nested conditionals for path resolution (lines 29-37) are hard to follow and have an unusual structure with an empty `pass` branch.

♻️ Cleaner path resolution:

```diff
-        # Resolve absolute path if relative path is provided
-        if not os.path.isabs(path):
-            if os.path.exists(path):
-                pass
-            else:
-                # Check relative to this file location as fallback
-                dir_of_file = os.path.dirname(__file__)
-                potential_path = os.path.join(dir_of_file, os.path.basename(path))
-                if os.path.exists(potential_path):
-                    path = potential_path
+        # Resolve relative paths: first try as-is, then relative to this module
+        if not os.path.isabs(path) and not os.path.exists(path):
+            module_dir = os.path.dirname(__file__)
+            fallback_path = os.path.join(module_dir, os.path.basename(path))
+            if os.path.exists(fallback_path):
+                path = fallback_path
```
```python
        if not os.path.exists(path):
            logger.error(f"Celebrity encodings file not found: {path}")
            return

        try:
            with open(path, 'rb') as f:
                data = pickle.load(f)
```
|
Review comment (Contributor) on lines +43 to +45:

CRITICAL: Pickle deserialization security risk. `pickle.load` can execute arbitrary code during deserialization, so loading an untrusted encodings file is dangerous. Consider these alternatives:

🔒 Example fix using NumPy's safer format:

```python
# When saving encodings (in your encoding generation script):
np.savez_compressed(
    'celebrity_encodings.npz',
    names=names_array,
    embeddings=embeddings_array
)

# In _load_and_optimize:
try:
    data = np.load(path, allow_pickle=False)  # Disable pickle for safety
    self.names = data['names']
    self.embeddings_matrix = data['embeddings']
    logger.info(f"Loaded {len(self.names)} celebrity encodings.")
except Exception as e:
    logger.error(f"Failed to load celebrity encodings: {e}")
```

Alternatively, if you must use pickle for backward compatibility, add this warning:

```diff
         try:
+            # WARNING: Only load pickle files from trusted sources
+            # Pickle can execute arbitrary code during deserialization
             with open(path, 'rb') as f:
                 data = pickle.load(f)
```
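The reviewer's `np.savez`/`np.load` suggestion can be sanity-checked in isolation. A minimal sketch with made-up names and placeholder embeddings (the file path and data here are illustrative, not the project's real encodings):

```python
import os
import tempfile
import numpy as np

# Hypothetical stand-in data; the real encodings would be 128D FaceNet embeddings.
names = np.array(["Alice", "Bob"])
embeddings = np.zeros((2, 128), dtype=np.float32)

path = os.path.join(tempfile.gettempdir(), "celebrity_encodings_demo.npz")

# Save with NumPy's container format instead of pickle.
np.savez_compressed(path, names=names, embeddings=embeddings)

# Load with pickle explicitly disabled: only plain arrays can come back,
# so a malicious file cannot execute code during loading.
data = np.load(path, allow_pickle=False)
print(data["names"].tolist())     # → ['Alice', 'Bob']
print(data["embeddings"].shape)   # → (2, 128)
```

String arrays round-trip without pickle because NumPy stores them as fixed-width unicode, not Python objects.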
```python
            # Data is expected to be { "Name": [embedding1, embedding2], ... }
            if not isinstance(data, dict):
                logger.error("Loaded data is not a dictionary.")
                return

            all_names = []
            all_embeddings = []

            for name, embeddings_list in data.items():
                for embedding in embeddings_list:
                    all_names.append(name)
                    all_embeddings.append(embedding)

            if all_embeddings:
                self.names = np.array(all_names)
                self.embeddings_matrix = np.array(all_embeddings)
                logger.info(f"Loaded {len(self.names)} celebrity encodings for {len(data)} identities.")
            else:
                logger.warning("No embeddings found in the loaded data.")

        except Exception as e:
            logger.error(f"Failed to load celebrity encodings: {e}")

    def identify_face(self, unknown_embedding, threshold=0.7):
        """
        Identifies the celebrity from the given embedding using Euclidean distance.

        Args:
            unknown_embedding (np.ndarray): The 128D embedding of the face to identify.
            threshold (float): The distance threshold for a match.

        Returns:
            str: The name of the celebrity if found, otherwise None.
        """
        if self.embeddings_matrix is None or unknown_embedding is None:
            return None

        # Ensure unknown_embedding is the correct shape for broadcasting.
        # Calculate Euclidean distance: sqrt(sum((x - y)^2)),
        # with axis=1 to get the distance for each row in the matrix.
        distances = np.linalg.norm(self.embeddings_matrix - unknown_embedding, axis=1)

        min_distance_idx = np.argmin(distances)
        min_distance = distances[min_distance_idx]

        logger.debug(f"Closest match: {self.names[min_distance_idx]} with distance: {min_distance}")

        if min_distance < threshold:
            return self.names[min_distance_idx]
        else:
            return None
```
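The flattening step in `_load_and_optimize` turns the per-identity dict into one embeddings matrix with a parallel names array (one row per embedding, names repeated per identity). A toy run of the same logic, with invented names and 4D vectors instead of real 128D encodings:

```python
import numpy as np

# Toy encodings dict in the expected { name: [embedding1, embedding2, ...] } shape.
data = {
    "A": [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]],
    "B": [[2.0, 2.0, 2.0, 2.0]],
}

all_names, all_embeddings = [], []
for name, embeddings_list in data.items():
    for embedding in embeddings_list:
        all_names.append(name)   # name repeats once per embedding row
        all_embeddings.append(embedding)

names = np.array(all_names)
matrix = np.array(all_embeddings)
print(names.tolist())   # → ['A', 'A', 'B']
print(matrix.shape)     # → (3, 4)
```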
Review comment (Contributor) on lines +70 to +97:

Add input validation for embedding dimension. The method assumes `unknown_embedding` is a 1D, 128-dimensional array but never checks it, so a mismatched shape would either raise deep inside NumPy or broadcast incorrectly.

✅ Proposed validation:

```diff
     def identify_face(self, unknown_embedding, threshold=0.7):
         """
         Identifies the celebrity from the given embedding using Euclidean distance.
         """
         if self.embeddings_matrix is None or unknown_embedding is None:
             return None
+
+        # Validate embedding shape
+        if unknown_embedding.shape[-1] != 128:
+            logger.error(f"Invalid embedding dimension: {unknown_embedding.shape}. Expected 128D.")
+            return None
+
+        # Ensure 1D array for distance calculation
+        if unknown_embedding.ndim > 1:
+            unknown_embedding = unknown_embedding.flatten()

         # Calculate Euclidean distance
         distances = np.linalg.norm(self.embeddings_matrix - unknown_embedding, axis=1)
```
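The vectorized matching that `identify_face` performs can be illustrated standalone: one `np.linalg.norm` call scores every reference row at once via broadcasting, replacing a Python loop. A sketch with hypothetical 4D embeddings (the real ones are 128D) and made-up names:

```python
import numpy as np

# Three reference embeddings (two belong to the same identity "A").
names = np.array(["A", "A", "B"])
matrix = np.array([
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 3.0, 0.0, 0.0],
])

query = np.array([0.9, 0.0, 0.0, 0.0])

# Broadcast query against every row, then take the L2 norm along axis=1.
distances = np.linalg.norm(matrix - query, axis=1)  # ≈ [0.9, 0.1, 3.13]
best = int(np.argmin(distances))

threshold = 0.7
match = names[best] if distances[best] < threshold else None
print(match)  # → A
```

The second reference row wins with distance ≈0.1, which is under the 0.7 threshold, so the query is labeled "A".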
Changed file: app/models/FaceDetector.py

```diff
@@ -2,10 +2,12 @@
 import cv2
 from app.models.FaceNet import FaceNet
+from app.models.CelebrityMatcher import CelebrityMatcher
 from app.utils.FaceNet import FaceNet_util_preprocess_image, FaceNet_util_get_model_path
 from app.utils.YOLO import YOLO_util_get_model_path
 from app.models.YOLO import YOLO
 from app.database.faces import db_insert_face_embeddings_by_image_id
+from app.database.face_clusters import db_get_or_create_cluster_by_name
 from app.logging.setup_logging import get_logger

 # Initialize logger
@@ -20,8 +22,9 @@ def __init__(self):
             iou_threshold=0.45,
         )
         self.facenet = FaceNet(FaceNet_util_get_model_path())
+        self.celebrity_matcher = CelebrityMatcher()
         self._initialized = True
-        logger.info("FaceDetector initialized with YOLO and FaceNet models.")
+        logger.info("FaceDetector initialized with YOLO, FaceNet, and CelebrityMatcher.")

     def detect_faces(self, image_id: str, image_path: str, forSearch: bool = False):
         img = cv2.imread(image_path)
@@ -33,7 +36,7 @@ def detect_faces(self, image_id: str, image_path: str, forSearch: bool = False):
         logger.debug(f"Face detection boxes: {boxes}")
         logger.info(f"Detected {len(boxes)} faces in image {image_id}.")

-        processed_faces, embeddings, bboxes, confidences = [], [], [], []
+        processed_faces, embeddings, bboxes, confidences, cluster_ids = [], [], [], [], []

         for box, score in zip(boxes, scores):
             if score > self.yolo_detector.conf_threshold:
@@ -55,9 +58,18 @@ def detect_faces(self, image_id: str, image_path: str, forSearch: bool = False):
                 embedding = self.facenet.get_embedding(processed_face)
                 embeddings.append(embedding)

+                # Match celebrity
+                name = self.celebrity_matcher.identify_face(embedding)
+                if name:
+                    logger.info(f"Identified {name} in image {image_id}")
+                    cluster_id = db_get_or_create_cluster_by_name(name)
+                    cluster_ids.append(cluster_id)
+                else:
+                    cluster_ids.append(None)
```
Review comment (Contributor) on lines +61 to +68:

Add error handling for database operations. The call to `db_get_or_create_cluster_by_name` can fail (for example if the database is unavailable), and an unhandled exception here aborts face processing for the entire image.

🛡️ Proposed fix with error handling:

```diff
                 # Match celebrity
                 name = self.celebrity_matcher.identify_face(embedding)
                 if name:
                     logger.info(f"Identified {name} in image {image_id}")
-                    cluster_id = db_get_or_create_cluster_by_name(name)
-                    cluster_ids.append(cluster_id)
+                    try:
+                        cluster_id = db_get_or_create_cluster_by_name(name)
+                        cluster_ids.append(cluster_id)
+                    except Exception as e:
+                        logger.error(f"Failed to get/create cluster for {name}: {e}")
+                        cluster_ids.append(None)
                 else:
                     cluster_ids.append(None)
```
```diff
         if not forSearch and embeddings:
             db_insert_face_embeddings_by_image_id(
-                image_id, embeddings, confidence=confidences, bbox=bboxes
+                image_id, embeddings, confidence=confidences, bbox=bboxes, cluster_id=cluster_ids
             )

         return {
```
New file: backend/app/routes/celebrity.py (+51 lines)

```python
from fastapi import APIRouter, BackgroundTasks
from app.models.CelebrityMatcher import CelebrityMatcher
from app.database.faces import db_get_all_faces_with_cluster_names, db_update_face_cluster_ids_batch
from app.database.face_clusters import db_get_or_create_cluster_by_name
from app.logging.setup_logging import get_logger

logger = get_logger(__name__)
router = APIRouter()


def process_celebrity_scan():
    try:
        logger.info("Starting background celebrity scan...")
        matcher = CelebrityMatcher()
        faces = db_get_all_faces_with_cluster_names()
        logger.info(f"Scanning {len(faces)} faces for celebrity matches...")

        updates = []
        matches_found = 0
        for face in faces:
            embeddings = face["embeddings"]

            # Identify face
            name = matcher.identify_face(embeddings)

            if name:
                matches_found += 1
                # Check if already named correctly
                current_name = face.get("cluster_name")

                # If current name is different (or None), calculate cluster ID and queue update
                if current_name != name:
                    cluster_id = db_get_or_create_cluster_by_name(name)
                    updates.append({"face_id": face["face_id"], "cluster_id": cluster_id})
                    logger.debug(f"Face {face['face_id']} matched as {name}. Queued for update.")

        if updates:
            db_update_face_cluster_ids_batch(updates)
            logger.info(f"Successfully updated {len(updates)} faces with celebrity names.")
        else:
            logger.info(f"Scan complete. Found {matches_found} matches, but no updates were needed.")

    except Exception as e:
        logger.error(f"Error during celebrity scan: {e}")
```
|
Review comment (Contributor) on lines +10 to +43:

Performance concern: loading all faces into memory. Line 14 loads ALL faces from the database into memory at once. For large galleries with thousands or tens of thousands of faces, this could cause memory pressure or even OOM errors. Additionally, each call to `db_get_or_create_cluster_by_name` inside the loop opens its own short-lived database connection.

⚡ Proposed optimization with batched processing. Consider processing faces in batches and reusing database connections:

```diff
 def process_celebrity_scan():
+    global _scan_in_progress
+    BATCH_SIZE = 100  # Process faces in batches
+
     try:
         logger.info("Starting background celebrity scan...")
         matcher = CelebrityMatcher()
-        faces = db_get_all_faces_with_cluster_names()
-        logger.info(f"Scanning {len(faces)} faces for celebrity matches...")
+
+        # Get total count for logging
+        import sqlite3
+        from app.config.settings import DATABASE_PATH
+        conn = sqlite3.connect(DATABASE_PATH)
+        cursor = conn.cursor()
+        cursor.execute("SELECT COUNT(*) FROM faces")
+        total_faces = cursor.fetchone()[0]
+        logger.info(f"Scanning {total_faces} faces for celebrity matches...")

         updates = []
         matches_found = 0
-        for face in faces:
+        processed = 0
+
+        # Process in batches using LIMIT/OFFSET
+        while processed < total_faces:
+            cursor.execute("""
+                SELECT f.face_id, f.embeddings, fc.cluster_name
+                FROM faces f
+                LEFT JOIN face_clusters fc ON f.cluster_id = fc.cluster_id
+                ORDER BY f.face_id
+                LIMIT ? OFFSET ?
+            """, (BATCH_SIZE, processed))
+
+            batch_faces = []
+            for row in cursor.fetchall():
+                face_id, embeddings_json, cluster_name = row
+                import numpy as np
+                import json
+                embeddings = np.array(json.loads(embeddings_json))
+                batch_faces.append({
+                    "face_id": face_id,
+                    "embeddings": embeddings,
+                    "cluster_name": cluster_name,
+                })
+
+            if not batch_faces:
+                break
+
+            for face in batch_faces:
                 embeddings = face["embeddings"]
-
-            # Identify face
-            name = matcher.identify_face(embeddings)
-
-            if name:
-                matches_found += 1
-                # Check if already named correctly
-                current_name = face.get("cluster_name")
-
-                # If current name is different (or None), calculate cluster ID and queue update
-                if current_name != name:
-                    cluster_id = db_get_or_create_cluster_by_name(name)
-                    updates.append({"face_id": face["face_id"], "cluster_id": cluster_id})
-                    logger.debug(f"Face {face['face_id']} matched as {name}. Queued for update.")
+
+                name = matcher.identify_face(embeddings)
+
+                if name:
+                    matches_found += 1
+                    current_name = face.get("cluster_name")
+
+                    if current_name != name:
+                        # Reuse cursor for cluster creation
+                        cluster_id = db_get_or_create_cluster_by_name(name, cursor=cursor)
+                        updates.append({"face_id": face["face_id"], "cluster_id": cluster_id})
+                        logger.debug(f"Face {face['face_id']} matched as {name}. Queued for update.")
+
+            processed += len(batch_faces)
+            logger.info(f"Processed {processed}/{total_faces} faces ({matches_found} matches found so far)...")

         if updates:
-            db_update_face_cluster_ids_batch(updates)
+            db_update_face_cluster_ids_batch(updates, cursor=cursor)
             logger.info(f"Successfully updated {len(updates)} faces with celebrity names.")
         else:
             logger.info(f"Scan complete. Found {matches_found} matches, but no updates were needed.")
+
+        conn.commit()
     except Exception as e:
         logger.error(f"Error during celebrity scan: {e}")
+        if 'conn' in locals():
+            conn.rollback()
+    finally:
+        if 'conn' in locals():
+            conn.close()
+        with _scan_lock:
+            _scan_in_progress = False
```

Note: This requires updating `db_get_or_create_cluster_by_name` and `db_update_face_cluster_ids_batch` to accept an optional cursor parameter.
```python
@router.post("/scan")
def scan_celebrities(background_tasks: BackgroundTasks):
    """
    Triggers a background scan of all existing faces to identify celebrities.
    """
    background_tasks.add_task(process_celebrity_scan)
    return {"message": "Celebrity scan started in background."}
```
|
Review comment (Contributor) on lines +45 to +51:

Add authentication and prevent concurrent scans. The endpoint lacks authentication and has no protection against concurrent scans. Any user can trigger multiple celebrity scans simultaneously, which could overwhelm the system or cause database contention.

🔐 Proposed fix with authentication and concurrency control. Add a simple flag to prevent concurrent scans and require authentication:

```diff
+import threading
+
+# Module-level lock to prevent concurrent scans
+_scan_lock = threading.Lock()
+_scan_in_progress = False
+
 @router.post("/scan")
-def scan_celebrities(background_tasks: BackgroundTasks):
+async def scan_celebrities(background_tasks: BackgroundTasks):
     """
     Triggers a background scan of all existing faces to identify celebrities.
+    Returns 409 if a scan is already in progress.
     """
+    global _scan_in_progress
+
+    # Check if scan is already running
+    if _scan_in_progress:
+        return {"message": "Celebrity scan already in progress."}, 409
+
+    # Mark scan as in progress
+    with _scan_lock:
+        if _scan_in_progress:
+            return {"message": "Celebrity scan already in progress."}, 409
+        _scan_in_progress = True
+
     background_tasks.add_task(process_celebrity_scan)
     return {"message": "Celebrity scan started in background."}
```

Update `process_celebrity_scan` accordingly:

```diff
 def process_celebrity_scan():
+    global _scan_in_progress
     try:
         logger.info("Starting background celebrity scan...")
         # ... existing code ...
     except Exception as e:
         logger.error(f"Error during celebrity scan: {e}")
+    finally:
+        with _scan_lock:
+            _scan_in_progress = False
+        logger.info("Celebrity scan completed.")
```

Authentication note: If this application has authentication middleware, ensure this endpoint is protected. If not, consider adding at least a simple API key or rate limiting.
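The guard the reviewer proposes is a double-checked "in progress" flag behind a lock. It can be demonstrated without FastAPI; a minimal sketch (function names here are illustrative, not part of the PR):

```python
import threading

_scan_lock = threading.Lock()
_scan_in_progress = False


def try_start_scan():
    """Return True if this caller won the right to run the scan."""
    global _scan_in_progress
    with _scan_lock:
        if _scan_in_progress:
            return False          # someone else is already scanning
        _scan_in_progress = True  # claim the slot atomically under the lock
        return True


def finish_scan():
    """Release the slot when the background task completes."""
    global _scan_in_progress
    with _scan_lock:
        _scan_in_progress = False


# First caller starts the scan; a second concurrent attempt is rejected.
print(try_start_scan())  # → True
print(try_start_scan())  # → False
finish_scan()
print(try_start_scan())  # → True
```

Doing both the check and the assignment inside `with _scan_lock` is what closes the race window; checking the flag before taking the lock (as the proposed diff also does) is only an optimization.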
New file: celebrity detection demo script in the backend directory (+137 lines)

```python
import cv2
import numpy as np
import os
import sys

# Ensure backend directory is in python path to allow imports from app.*
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)  # Insert at the beginning to avoid conflicts

# We don't need the parent directory if we are running from backend and 'app' is inside backend.
# Removing parent_dir addition to avoid conflict with root's app.py

try:
    from app.models.FaceDetector import FaceDetector
    from app.models.FaceNet import FaceNet
    from app.models.CelebrityMatcher import CelebrityMatcher
    from app.utils.FaceNet import FaceNet_util_preprocess_image
    from app.logging.setup_logging import get_logger
except ImportError as e:
    print(f"Error importing modules: {e}")
    print("Please make sure you are running this script from the backend directory or with the correct PYTHONPATH.")
    sys.exit(1)

# Initialize logger (optional for demo script but good practice)
logger = get_logger(__name__)


def run_celebrity_detection_demo(image_path):
    print(f"--- Starting Celebrity Detection on {image_path} ---")

    # 1. Initialize Models
    print("Initializing models...")
    try:
        # Note: FaceDetector initializes its own FaceNet internally, but for this custom
        # pipeline we may want individual components if FaceDetector is too coupled.
        # However, FaceDetector.yolo_detector is accessible.
        face_detector = FaceDetector()

        # We could create a separate FaceNet instance for our explicit embedding step,
        # but FaceDetector already exposes one as self.facenet, so we reuse it.
        facenet = face_detector.facenet

        # Use default path (relative to the CelebrityMatcher file)
        celebrity_matcher = CelebrityMatcher()

    except Exception as e:
        print(f"Failed to initialize models: {e}")
        return

    # 2. Load Image
    if not os.path.exists(image_path):
        print(f"Image not found at {image_path}")
        return

    img = cv2.imread(image_path)
    if img is None:
        print(f"Failed to load image from {image_path}")
        return

    # 3. Detect Faces
    # FaceDetector.detect_faces() does everything, but we want to demonstrate the
    # pipeline steps, so we access the underlying YOLO detector directly.
    print("Detecting faces...")
    # The yolo_detector call returns boxes, scores, class_ids
    boxes, scores, class_ids = face_detector.yolo_detector(img)

    if len(boxes) == 0:
        print("No faces detected.")
        return

    print(f"Found {len(boxes)} faces.")

    # 4. Process Each Face
    for i, (box, score) in enumerate(zip(boxes, scores)):
        # Filter by confidence (FaceDetector uses a 0.45 threshold).
        if score < face_detector.yolo_detector.conf_threshold:
            continue

        x1, y1, x2, y2 = map(int, box)
        print(f"Processing Face {i+1} at [{x1}, {y1}, {x2}, {y2}]...")

        # Crop face with some padding, like FaceDetector does
        padding = 20
        h, w, _ = img.shape
        face_img = img[
            max(0, y1 - padding) : min(h, y2 + padding),
            max(0, x1 - padding) : min(w, x2 + padding)
        ]

        if face_img.size == 0:
            print("Empty crop, skipping.")
            continue

        # 5. Preprocess for FaceNet (resize to 160x160)
        # FaceNet_util_preprocess_image handles: Resize -> RGB -> Transpose -> ExpandDims
        # -> Normalize, so we use it directly to ensure compatibility with the model.
        try:
            # FaceNet expects the processed tensor
            preprocessed_face = FaceNet_util_preprocess_image(face_img)

            # 6. Get Embedding
            embedding = facenet.get_embedding(preprocessed_face)

            # 7. Match Celebrity
            name = celebrity_matcher.identify_face(embedding)

            if name:
                print(f"Result: Found {name} at [{x1}, {y1}, {x2}, {y2}]")
            else:
                print(f"Result: Unknown person at [{x1}, {y1}, {x2}, {y2}]")

        except Exception as e:
            print(f"Error processing face {i+1}: {e}")

    print("--- Scanning Complete ---")
```
|
Review comment (Contributor) on lines +28 to +127:

Add resource cleanup for models. The demo initializes `FaceDetector` (and through it the YOLO and FaceNet models) but provides no cleanup path if processing fails partway through.

🧹 Proposed fix with proper cleanup:

```diff
 def run_celebrity_detection_demo(image_path):
     print(f"--- Starting Celebrity Detection on {image_path} ---")

     # 1. Initialize Models
     print("Initializing models...")
     try:
         face_detector = FaceDetector()
         facenet = face_detector.facenet
         celebrity_matcher = CelebrityMatcher()
     except Exception as e:
         print(f"Failed to initialize models: {e}")
         return

+    try:
-    # 2. Load Image
-    if not os.path.exists(image_path):
+        # 2. Load Image
+        if not os.path.exists(image_path):
+            print(f"Image not found at {image_path}")
+            return
+
         # ... rest of the function logic ...
+
+        print("--- Scanning Complete ---")
+    finally:
+        # Clean up resources
+        if face_detector:
+            face_detector.close()
-    print("--- Scanning Complete ---")
```
```python
if __name__ == "__main__":
    # Example usage: you can pass an image path as an argument
    if len(sys.argv) > 1:
        target_image = sys.argv[1]
    else:
        # Default or dummy path for demonstration
        target_image = "tests/inputs/sample_celebrity.jpg"  # Adjust as needed

    run_celebrity_detection_demo(target_image)
```
New file: frontend API helper (+6 lines)

```typescript
import {apiClient} from '../axiosConfig';

export const scanCelebrities = async () => {
  const response = await apiClient.post('/celebrity/scan');
  return response.data;
};
```
Review comment (Contributor):

Critical race condition in get-or-create pattern. The check-then-insert logic is vulnerable to race conditions when multiple threads call this function concurrently with the same celebrity name. If two calls pass the SELECT check before either completes the INSERT, one will fail with a constraint violation.

Additionally, this function opens a new database connection for every call. When invoked in a loop (as in `backend/app/routes/celebrity.py` line 32), this creates many short-lived connections, which is inefficient.

🔒 Proposed fix: use INSERT OR IGNORE with proper transaction handling.
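The reviewer's committable suggestion is not included above. As an illustration of the named pattern only, and assuming a hypothetical `face_clusters` schema with a UNIQUE name column (not necessarily the project's actual one), an atomic get-or-create via `INSERT OR IGNORE` might look like:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE face_clusters ("
    "cluster_id INTEGER PRIMARY KEY, "
    "cluster_name TEXT UNIQUE NOT NULL)"
)


def get_or_create_cluster(conn, name):
    # INSERT OR IGNORE is atomic: the UNIQUE constraint decides the race,
    # the loser's insert is silently skipped, and both callers then read
    # the same row back with the SELECT.
    conn.execute(
        "INSERT OR IGNORE INTO face_clusters (cluster_name) VALUES (?)", (name,)
    )
    conn.commit()
    row = conn.execute(
        "SELECT cluster_id FROM face_clusters WHERE cluster_name = ?", (name,)
    ).fetchone()
    return row[0]


first = get_or_create_cluster(conn, "Some Celebrity")
second = get_or_create_cluster(conn, "Some Celebrity")
print(first == second)  # → True: the second call reuses the existing row
```

Sharing one connection (or cursor) across calls, as the review also suggests, would avoid the per-call connection overhead noted above.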