diff --git a/__pycache__/app.cpython-310.pyc b/__pycache__/app.cpython-310.pyc new file mode 100644 index 000000000..8efd1dcf0 Binary files /dev/null and b/__pycache__/app.cpython-310.pyc differ diff --git a/backend/app/database/face_clusters.py b/backend/app/database/face_clusters.py index ceac7f556..34cd07b1e 100644 --- a/backend/app/database/face_clusters.py +++ b/backend/app/database/face_clusters.py @@ -1,4 +1,5 @@ import sqlite3 +import uuid from typing import Optional, List, Dict, TypedDict, Union from app.config.settings import DATABASE_PATH @@ -349,3 +350,29 @@ def db_get_images_by_cluster_id( return images finally: conn.close() + + +def db_get_or_create_cluster_by_name(name: str) -> str: + """Gets an existing cluster ID by name or creates a new one.""" + conn = sqlite3.connect(DATABASE_PATH) + cursor = conn.cursor() + try: + # Check if exists + cursor.execute( + "SELECT cluster_id FROM face_clusters WHERE cluster_name = ?", (name,) + ) + result = cursor.fetchone() + if result: + return result[0] + + # Create new + cluster_id = str(uuid.uuid4()) + cursor.execute( + "INSERT INTO face_clusters (cluster_id, cluster_name) VALUES (?, ?)", + (cluster_id, name), + ) + conn.commit() + return cluster_id + finally: + conn.close() + diff --git a/backend/app/models/CelebrityMatcher.py b/backend/app/models/CelebrityMatcher.py new file mode 100644 index 000000000..9fb6b5d6e --- /dev/null +++ b/backend/app/models/CelebrityMatcher.py @@ -0,0 +1,97 @@ +import pickle +import numpy as np +import os +from app.logging.setup_logging import get_logger + +logger = get_logger(__name__) + +class CelebrityMatcher: + def __init__(self, encodings_path=None): + """ + Initializes the CelebrityMatcher by loading encodings and optimizing them into a matrix. + + Args: + encodings_path (str): Path to the pickle file containing celebrity encodings. + If None, defaults to 'celebrity_encodings.pkl' in this directory. + """ + if encodings_path is None: + encodings_path = os.path.join(os.path.dirname(__file__), 'celebrity_encodings.pkl') + + self.names = [] + self.embeddings_matrix = None + self._load_and_optimize(encodings_path) + + def _load_and_optimize(self, path): + """ + Loads the pickle file and converts the dictionary of lists into a NumPy matrix. + """ + # Resolve absolute path if relative path is provided + if not os.path.isabs(path): + if os.path.exists(path): + pass + else: + # Check relative to this file location as fallback + dir_of_file = os.path.dirname(__file__) + potential_path = os.path.join(dir_of_file, os.path.basename(path)) + if os.path.exists(potential_path): + path = potential_path + + if not os.path.exists(path): + logger.error(f"Celebrity encodings file not found: {path}") + return + + try: + with open(path, 'rb') as f: + data = pickle.load(f) + + # Data is expected to be { "Name": [embedding1, embedding2], ... } + if not isinstance(data, dict): + logger.error("Loaded data is not a dictionary.") + return + + all_names = [] + all_embeddings = [] + + for name, embeddings_list in data.items(): + for embedding in embeddings_list: + all_names.append(name) + all_embeddings.append(embedding) + + if all_embeddings: + self.names = np.array(all_names) + self.embeddings_matrix = np.array(all_embeddings) + logger.info(f"Loaded {len(self.names)} celebrity encodings for {len(data)} identities.") + else: + logger.warning("No embeddings found in the loaded data.") + + except Exception as e: + logger.error(f"Failed to load celebrity encodings: {e}") + + def identify_face(self, unknown_embedding, threshold=0.7): + """ + Identifies the celebrity from the given embedding using Euclidean distance. + + Args: + unknown_embedding (np.ndarray): The 128D embedding of the face to identify. + threshold (float): The distance threshold for a match. + + Returns: + str: The name of the celebrity if found, otherwise None. + """ + if self.embeddings_matrix is None or unknown_embedding is None: + return None + + # Ensure unknown_embedding is the correct shape for broadcasting + # Calculate Euclidean distance: sqrt(sum((x - y)^2)) + # axis=1 to calculate distance for each row in the matrix + distances = np.linalg.norm(self.embeddings_matrix - unknown_embedding, axis=1) + + min_distance_idx = np.argmin(distances) + min_distance = distances[min_distance_idx] + + logger.debug(f"Closest match: {self.names[min_distance_idx]} with distance: {min_distance}") + + if min_distance < threshold: + return self.names[min_distance_idx] + else: + return None diff --git a/backend/app/models/FaceDetector.py b/backend/app/models/FaceDetector.py index 9e10fd5fc..6893c1970 100644 --- a/backend/app/models/FaceDetector.py +++ b/backend/app/models/FaceDetector.py @@ -2,10 +2,12 @@ import cv2 from app.models.FaceNet import FaceNet +from app.models.CelebrityMatcher import CelebrityMatcher from app.utils.FaceNet import FaceNet_util_preprocess_image, FaceNet_util_get_model_path from app.utils.YOLO import YOLO_util_get_model_path from app.models.YOLO import YOLO from app.database.faces import db_insert_face_embeddings_by_image_id +from app.database.face_clusters import db_get_or_create_cluster_by_name from app.logging.setup_logging import get_logger # Initialize logger @@ -20,8 +22,9 @@ def __init__(self): iou_threshold=0.45, ) self.facenet = FaceNet(FaceNet_util_get_model_path()) + self.celebrity_matcher = CelebrityMatcher() self._initialized = True - logger.info("FaceDetector initialized with YOLO and FaceNet models.") + logger.info("FaceDetector initialized with YOLO, FaceNet, and CelebrityMatcher.") def detect_faces(self, image_id: str, image_path: str, forSearch: bool = False): img = cv2.imread(image_path) @@ -33,7 +36,7 @@ def detect_faces(self, image_id: str, image_path: str, forSearch: bool = False): logger.debug(f"Face detection boxes: {boxes}") logger.info(f"Detected {len(boxes)} faces in image {image_id}.") - processed_faces, embeddings, bboxes, confidences = [], [], [], [] + processed_faces, embeddings, bboxes, confidences, cluster_ids = [], [], [], [], [] for box, score in zip(boxes, scores): if score > self.yolo_detector.conf_threshold: @@ -55,9 +58,18 @@ def detect_faces(self, image_id: str, image_path: str, forSearch: bool = False): embedding = self.facenet.get_embedding(processed_face) embeddings.append(embedding) + # Match celebrity + name = self.celebrity_matcher.identify_face(embedding) + if name: + logger.info(f"Identified {name} in image {image_id}") + cluster_id = db_get_or_create_cluster_by_name(name) + cluster_ids.append(cluster_id) + else: + cluster_ids.append(None) + if not forSearch and embeddings: db_insert_face_embeddings_by_image_id( - image_id, embeddings, confidence=confidences, bbox=bboxes + image_id, embeddings, confidence=confidences, bbox=bboxes, cluster_id=cluster_ids ) return { diff --git a/backend/app/models/celebrity_encodings.pkl b/backend/app/models/celebrity_encodings.pkl new file mode 100644 index 000000000..284db6644 Binary files /dev/null and b/backend/app/models/celebrity_encodings.pkl differ diff --git a/backend/app/routes/celebrity.py b/backend/app/routes/celebrity.py new file mode 100644 index 000000000..9a1b48462 --- /dev/null +++ b/backend/app/routes/celebrity.py @@ -0,0 +1,51 @@ +from fastapi import APIRouter, BackgroundTasks +from app.models.CelebrityMatcher import CelebrityMatcher +from app.database.faces import db_get_all_faces_with_cluster_names, db_update_face_cluster_ids_batch +from app.database.face_clusters import db_get_or_create_cluster_by_name +from app.logging.setup_logging import get_logger + +logger = get_logger(__name__) +router = APIRouter() + +def process_celebrity_scan(): + try: + logger.info("Starting background celebrity scan...") + matcher = CelebrityMatcher() + faces = db_get_all_faces_with_cluster_names() + logger.info(f"Scanning {len(faces)} faces for celebrity matches...") + + updates = [] + matches_found = 0 + for face in faces: + embeddings = face["embeddings"] + + # Identify face + name = matcher.identify_face(embeddings) + + if name: + matches_found += 1 + # Check if already named correctly + current_name = face.get("cluster_name") + + # If current name is different (or None), calculate cluster ID and queue update + if current_name != name: + cluster_id = db_get_or_create_cluster_by_name(name) + updates.append({"face_id": face["face_id"], "cluster_id": cluster_id}) + logger.debug(f"Face {face['face_id']} matched as {name}. Queued for update.") + + if updates: + db_update_face_cluster_ids_batch(updates) + logger.info(f"Successfully updated {len(updates)} faces with celebrity names.") + else: + logger.info(f"Scan complete. Found {matches_found} matches, but no updates were needed.") + + except Exception as e: + logger.error(f"Error during celebrity scan: {e}") + +@router.post("/scan") +def scan_celebrities(background_tasks: BackgroundTasks): + """ + Triggers a background scan of all existing faces to identify celebrities. + """ + background_tasks.add_task(process_celebrity_scan) + return {"message": "Celebrity scan started in background."} diff --git a/backend/celebrity_detection_demo.py b/backend/celebrity_detection_demo.py new file mode 100644 index 000000000..20ad596a6 --- /dev/null +++ b/backend/celebrity_detection_demo.py @@ -0,0 +1,137 @@ +import cv2 +import numpy as np +import os +import sys + +# Ensure backend directory is in python path to allow imports from app.* +current_dir = os.path.dirname(os.path.abspath(__file__)) +if current_dir not in sys.path: + sys.path.insert(0, current_dir) # Insert at the beginning to avoid conflicts + +# We don't need the parent directory if we are running from backend and 'app' is inside backend +# Removing parent_dir addition to avoid conflict with root's app.py + +try: + from app.models.FaceDetector import FaceDetector + from app.models.FaceNet import FaceNet + from app.models.CelebrityMatcher import CelebrityMatcher + from app.utils.FaceNet import FaceNet_util_preprocess_image + from app.logging.setup_logging import get_logger +except ImportError as e: + print(f"Error importing modules: {e}") + print("Please make sure you are running this script from the backend directory or with the correct PYTHONPATH.") + sys.exit(1) + +# Initialize logger (optional for demo script but good practice) +logger = get_logger(__name__) + +def run_celebrity_detection_demo(image_path): + print(f"--- Starting Celebrity Detection on {image_path} ---") + + # 1. Initialize Models + print("Initializing models...") + try: + # Note: FaceDetector initializes its own FaceNet internally, but for this custom pipeline + # as requested, we might want to use individual components if FaceDetector is too coupled. + # However, FaceDetector.yolo_detector is accessible. + face_detector = FaceDetector() + + # We need a separate FaceNet instance for our explicit embedding step + # if we don't want to use face_detector internals, + # OR we can reuse face_detector.facenet if it exposes what we need. + # FaceDetector has self.facenet. + facenet = face_detector.facenet + + # Use default path (rel to CelebrityMatcher file) + celebrity_matcher = CelebrityMatcher() + + except Exception as e: + print(f"Failed to initialize models: {e}") + return + + # 2. Load Image + if not os.path.exists(image_path): + print(f"Image not found at {image_path}") + return + + img = cv2.imread(image_path) + if img is None: + print(f"Failed to load image from {image_path}") + return + + # 3. Detect Faces + # Using the YOLO detector inside FaceDetector as requested ("Use FaceDetector to find face bounding boxes") + # FaceDetector.detect_faces() does everything, but we want to demonstrate the pipeline steps. + # So we access the underlying YOLO detector. + print("Detecting faces...") + # The yolo_detector call returns boxes, scores, class_ids + boxes, scores, class_ids = face_detector.yolo_detector(img) + + if len(boxes) == 0: + print("No faces detected.") + return + + print(f"Found {len(boxes)} faces.") + + # 4. Process Each Face + for i, (box, score) in enumerate(zip(boxes, scores)): + # Filter by confidence if needed (YOLO class usually handles this internally or returns all) + # FaceDetector uses 0.45 threshold. + if score < face_detector.yolo_detector.conf_threshold: + continue + + x1, y1, x2, y2 = map(int, box) + print(f"Processing Face {i+1} at [{x1}, {y1}, {x2}, {y2}]...") + + # Crop Face (with some padding like FaceDetector does) + padding = 20 + h, w, _ = img.shape + face_img = img[ + max(0, y1 - padding) : min(h, y2 + padding), + max(0, x1 - padding) : min(w, x2 + padding) + ] + + if face_img.size == 0: + print("Empty crop, skipping.") + continue + + # 5. Preprocess for FaceNet (Resize to 160x160) + # The user requested: "preprocess it (resize to 160x160)" + # FaceNet_util_preprocess_image handles: Resize -> RGB -> Transpose -> ExpandDims -> Normalize + # We'll use the util to ensure compatibility with the model. + # If manual resize is strictly required separately: + # resized_face = cv2.resize(face_img, (160, 160)) + # preprocessed_face = FaceNet_util_preprocess_image_from_resized(resized_face) + # But FaceNet_util_preprocess_image includes the resize, so we use it directly. + + try: + # FaceNet expects the processed tensor + preprocessed_face = FaceNet_util_preprocess_image(face_img) + + # 6. Get Embedding + embedding = facenet.get_embedding(preprocessed_face) + + # 7. Match Celebrity + name = celebrity_matcher.identify_face(embedding) + + if name: + print(f"Result: Found {name} at [{x1}, {y1}, {x2}, {y2}]") + # Optional: specific logging format + else: + print(f"Result: Unknown person at [{x1}, {y1}, {x2}, {y2}]") + + except Exception as e: + print(f"Error processing face {i+1}: {e}") + + print("--- Scanning Complete ---") + +if __name__ == "__main__": + # Example usage + # You can pass an image path as an argument + if len(sys.argv) > 1: + target_image = sys.argv[1] + else: + # Default or dummy path for demonstration + target_image = "tests/inputs/sample_celebrity.jpg" # Adjust as needed + + run_celebrity_detection_demo(target_image) diff --git a/backend/main.py b/backend/main.py index db591cd97..c386b1330 100644 --- a/backend/main.py +++ b/backend/main.py @@ -131,7 +131,6 @@ async def root(): user_preferences_router, prefix="/user-preferences", tags=["User Preferences"] ) - # Entry point for running with: python3 main.py if __name__ == "__main__": multiprocessing.freeze_support() # Required for Windows diff --git a/frontend/src/api/api-functions/celebrity.ts b/frontend/src/api/api-functions/celebrity.ts new file mode 100644 index 000000000..2335b7fb2 --- /dev/null +++ b/frontend/src/api/api-functions/celebrity.ts @@ -0,0 +1,6 @@ +import {apiClient} from '../axiosConfig'; + +export const scanCelebrities = async () => { + const response = await apiClient.post('/celebrity/scan'); + return response.data; +}; diff --git a/frontend/src/api/api-functions/index.ts b/frontend/src/api/api-functions/index.ts index 5d6f2fa8c..a532c0467 100644 --- a/frontend/src/api/api-functions/index.ts +++ b/frontend/src/api/api-functions/index.ts @@ -4,3 +4,5 @@ export * from './images'; export * from './folders'; export * from './user_preferences'; export * from './health'; +export * from './celebrity'; + diff --git a/frontend/src/pages/Home/Home.tsx b/frontend/src/pages/Home/Home.tsx index 83c9e5c83..0a89a5e41 100644 --- a/frontend/src/pages/Home/Home.tsx +++ b/frontend/src/pages/Home/Home.tsx @@ -1,5 +1,9 @@ import { useEffect, useRef, useState } from 'react'; import { useDispatch, useSelector } from 'react-redux'; +import { useMutation } from '@tanstack/react-query'; +import { scanCelebrities } from '@/api/api-functions'; +import { Button } from '@/components/ui/button'; +import { Sparkles } from 'lucide-react'; import { ChronologicalGallery, MonthMarker, @@ -28,6 +32,25 @@ export const Home = () => { enabled: !isSearchActive, }); + const scanMutation = useMutation({ + mutationFn: scanCelebrities, + }); + + useMutationFeedback( + { + isPending: scanMutation.isPending, + isSuccess: scanMutation.isSuccess, + isError: scanMutation.isError, + error: scanMutation.error + }, + { + loadingMessage: 'Scanning for celebrities...', + successMessage: 'Celebrity scan started in background.', + errorTitle: 'Scan Failed', + errorMessage: 'Failed to start celebrity scan.', + }, + ); + useMutationFeedback( { isPending: isLoading, isSuccess, isError, error }, { @@ -52,6 +75,19 @@ export const Home = () => { return (