Skip to content

Commit eaf2e3e

Browse files
committed
Refactor out face detection logic to face_detection.py
1 parent 8217f96 commit eaf2e3e

File tree

2 files changed

+359
-237
lines changed

2 files changed

+359
-237
lines changed

coffee_ws/src/coffee_vision/coffee_vision/camera_node.py

Lines changed: 15 additions & 237 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,7 @@
1818
from cv_bridge import CvBridge
1919

2020
from .coordinate_utils import transform_camera_to_eye_coords
21-
22-
# Models directory for face detection models
23-
MODELS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
24-
os.makedirs(MODELS_DIR, exist_ok=True)
21+
from .face_detection import FaceDetector
2522

2623

2724
class FrameGrabber:
@@ -79,14 +76,11 @@ def __init__(self, node=None):
7976

8077
# Face detection
8178
self.enable_face_detection = True
82-
self.face_detector = None
83-
self.face_net = None
84-
self.face_confidence_threshold = 0.5
85-
86-
# Smoothing for face detection
87-
self.prev_faces = []
88-
self.smoothing_factor = 0.4 # Higher value = more smoothing
89-
self.smoothing_frames = 5 # Number of frames to average
79+
self.face_detector = FaceDetector(
80+
confidence_threshold=0.5,
81+
smoothing_factor=0.4,
82+
logger=self.node.get_logger() if self.node else None
83+
)
9084

9185
# # Get parameters
9286
# self.invert_x = self.get_parameter('invert_x').value
@@ -112,8 +106,6 @@ def __init__(self, node=None):
112106
self.face_ids = {} # Map of face index to recognized face ID
113107
self.last_recognition_time = 0
114108
self.recognition_timeout = 3.0 # Clear recognition data after 3 seconds
115-
116-
self.init_face_detector()
117109

118110
def publish_face_data(self, faces):
119111
"""Publish face detection data for other nodes"""
@@ -310,49 +302,7 @@ def publish_frame(self, frame):
310302
if self.node:
311303
self.node.get_logger().error(f"Error publishing frame: {e}")
312304

313-
def init_face_detector(self):
    """Load the OpenCV DNN face detection network into ``self.face_net``.

    The model and config files are looked up in ``MODELS_DIR``; when either
    one is missing, both are fetched via ``self.download_face_model``. Any
    failure along the way is reported with ``print`` and leaves
    ``self.face_net`` set to ``None`` instead of propagating.
    """
    try:
        # Locations of the network weights and its graph description.
        model_file = os.path.join(MODELS_DIR, "opencv_face_detector_uint8.pb")
        config_file = os.path.join(MODELS_DIR, "opencv_face_detector.pbtxt")

        # Fetch the files unless both are already on disk.
        models_on_disk = os.path.exists(model_file) and os.path.exists(config_file)
        if not models_on_disk:
            self.download_face_model(model_file, config_file)

        self.face_net = cv2.dnn.readNet(model_file, config_file)

        # Use CUDA when OpenCV reports at least one CUDA-enabled device,
        # otherwise fall back to the default backend on the CPU.
        cuda_available = cv2.cuda.getCudaEnabledDeviceCount() > 0
        self.face_net.setPreferableBackend(
            cv2.dnn.DNN_BACKEND_CUDA if cuda_available else cv2.dnn.DNN_BACKEND_DEFAULT
        )
        self.face_net.setPreferableTarget(
            cv2.dnn.DNN_TARGET_CUDA if cuda_available else cv2.dnn.DNN_TARGET_CPU
        )

        print("Face detector (OpenCV DNN) initialized successfully")
    except Exception as e:
        print(f"Error initializing face detector: {e}")
        self.face_net = None
339-
340-
def download_face_model(self, model_file, config_file):
    """Download the face detection model and config to the given paths.

    Each file is first written to a sibling ``<destination>.part`` file and
    only moved onto its final path once the transfer has finished. Callers
    decide whether to download by checking that the final paths exist, so a
    partially written file must never appear under the final name.

    Args:
        model_file: Destination path for the network weights (``.pb``).
        config_file: Destination path for the network config (``.pbtxt``).

    Raises:
        Exception: Whatever the download or file move raised; the error is
            printed and then re-raised unchanged.
    """
    # Local imports keep this block self-contained, as the original did for
    # urllib.request.
    import os
    import urllib.request

    # Model URLs
    model_url = "https://github.com/spmallick/learnopencv/raw/refs/heads/master/AgeGender/opencv_face_detector_uint8.pb"
    config_url = "https://raw.githubusercontent.com/spmallick/learnopencv/refs/heads/master/AgeGender/opencv_face_detector.pbtxt"

    try:
        print("Downloading face detection model...")
        for url, destination in ((model_url, model_file), (config_url, config_file)):
            partial_path = f"{destination}.part"
            try:
                urllib.request.urlretrieve(url, partial_path)
                # Publish the file only after a complete transfer.
                os.replace(partial_path, destination)
            finally:
                # Still present only when the download or the move failed.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
        print("Face detection model downloaded successfully")
    except Exception as e:
        print(f"Error downloading face model: {e}")
        raise
305+
356306

357307
def start(self, camera_index, backend=cv2.CAP_ANY):
358308
with self.lock:
@@ -441,189 +391,16 @@ def toggle_face_detection(self, enable):
441391
"""Enable or disable face detection"""
442392
with self.lock:
443393
self.enable_face_detection = enable
444-
if enable and self.face_net is None:
445-
self.init_face_detector()
446394

447395
# Reset face tracking when toggling
448-
self.prev_faces = []
396+
if self.face_detector:
397+
self.face_detector.reset_tracking()
449398

450-
def detect_faces_dnn(self, frame):
    """Detect faces in a frame using OpenCV's DNN-based face detector.

    Args:
        frame: Image to run detection on; must expose ``shape`` with height
            and width as its first two entries. ``None`` or a zero-sized
            frame yields no detections instead of raising.

    Returns:
        A list of dicts, one per detection whose score exceeds
        ``self.face_confidence_threshold``, with keys ``x1``, ``y1``,
        ``x2``, ``y2`` (box corners clamped to the frame), ``center_x``,
        ``center_y``, ``radius`` (half of the longer box side) and
        ``confidence`` (a plain Python float). Returns ``[]`` when no
        network is loaded.
    """
    if self.face_net is None or frame is None:
        return []

    # Get frame dimensions; an empty frame cannot contain any face.
    h, w = frame.shape[:2]
    if h == 0 or w == 0:
        return []

    # Prepare input blob for the network
    blob = cv2.dnn.blobFromImage(frame, 1.0, (300, 300), [104, 117, 123], False, False)
    self.face_net.setInput(blob)

    # Run forward pass
    detections = self.face_net.forward()

    # Parse detections
    faces = []
    for i in range(detections.shape[2]):
        score = detections[0, 0, i, 2]
        if not score > self.face_confidence_threshold:
            continue

        # Box corners are multiplied by the frame size to get frame
        # coordinates, then clamped so they stay inside the frame.
        x1 = max(0, min(int(detections[0, 0, i, 3] * w), w - 1))
        y1 = max(0, min(int(detections[0, 0, i, 4] * h), h - 1))
        x2 = max(0, min(int(detections[0, 0, i, 5] * w), w - 1))
        y2 = max(0, min(int(detections[0, 0, i, 6] * h), h - 1))

        # Skip boxes that collapsed to zero or negative extent.
        if x2 <= x1 or y2 <= y1:
            continue

        faces.append({
            'x1': x1, 'y1': y1, 'x2': x2, 'y2': y2,
            'center_x': (x1 + x2) // 2,
            'center_y': (y1 + y2) // 2,
            'radius': max(x2 - x1, y2 - y1) // 2,
            # Store a built-in float so an array scalar type does not leak
            # into the face dicts handed to the rest of the pipeline.
            'confidence': float(score),
        })

    return faces
399+
492400

493-
def smooth_face_detections(self, faces):
    """Apply temporal smoothing to face detections to reduce flickering.

    New detections are greedily matched to the closest face from the
    previous call (center distance below 100, in the units of
    ``center_x``/``center_y``). Matched faces get their center and radius
    blended with ``self.smoothing_factor`` as the weight of the previous
    value, and their box is rebuilt from the blended center and radius.
    Previous faces that found no match linger with decaying confidence.

    The tracking state in ``self.prev_faces`` is kept separate from both
    the input and the returned list: ``faces`` is never modified, and later
    calls never alter a list or dict that was returned earlier.

    Args:
        faces: List of face dicts for the current frame; may be empty.

    Returns:
        A new list of face dicts to display for the current frame.
    """
    # Private copies of the tracked faces; decay and matching below modify
    # these, never dicts that a caller may still be holding.
    previous = [dict(face) for face in self.prev_faces]

    if not faces:
        if not previous:
            return []
        # Nothing detected this frame: keep showing the previous faces for
        # a while, slowly reducing their confidence.
        for face in previous:
            face['confidence'] *= 0.8  # Decay factor
        previous = [face for face in previous if face['confidence'] > 0.2]
        self.prev_faces = previous
        return [dict(face) for face in previous]

    if not previous:
        # First detection, just use it
        self.prev_faces = [dict(face) for face in faces]
        return [dict(face) for face in faces]

    keep = self.smoothing_factor

    def blend(key, old_face, new_face):
        """Weighted average of one field, truncated to int."""
        return int(keep * old_face[key] + (1 - keep) * new_face[key])

    tracked = []
    for new_face in faces:
        # Find the closest previous face that is still unmatched.
        best_match = None
        min_distance = float('inf')
        for i, prev_face in enumerate(previous):
            dx = new_face['center_x'] - prev_face['center_x']
            dy = new_face['center_y'] - prev_face['center_y']
            distance = (dx * dx + dy * dy) ** 0.5
            if distance < min_distance:
                min_distance = distance
                best_match = i

        if best_match is None or min_distance >= 100:  # Threshold distance
            # No match, add as new face
            tracked.append(dict(new_face))
            continue

        # Remove the matched face to prevent double matching
        prev_face = previous.pop(best_match)

        # Smooth position and size, then rebuild the bounding box from the
        # smoothed center and radius.
        cx = blend('center_x', prev_face, new_face)
        cy = blend('center_y', prev_face, new_face)
        r = blend('radius', prev_face, new_face)
        tracked.append({
            'center_x': cx,
            'center_y': cy,
            'radius': r,
            'confidence': new_face['confidence'],
            'x1': cx - r,
            'y1': cy - r,
            'x2': cx + r,
            'y2': cy + r,
        })

    # Add any remaining unmatched previous faces with decayed confidence
    for face in previous:
        face['confidence'] *= 0.5  # Faster decay for unmatched faces
        if face['confidence'] > 0.3:  # Only keep if still confident enough
            tracked.append(face)

    # Update previous faces for next frame; hand the caller its own copies.
    self.prev_faces = tracked
    return [dict(face) for face in tracked]
401+
571402

572-
def draw_face_circles(self, frame, faces):
    """Draw transparent circles over detected faces with IDs.

    For every face a filled circle is drawn on a copy of the frame and
    blended back in, and a rectangle plus ID/confidence labels are drawn on
    the frame itself. Summary counters are added in the top-left corner.

    Args:
        frame: Image to annotate; it is modified in place.
        faces: Sized collection of face dicts with keys ``x1``, ``y1``,
            ``x2``, ``y2``, ``center_x``, ``center_y`` and ``radius``;
            ``id``, ``confidence`` and ``height`` are optional.

    Returns:
        The annotated ``frame``; returned untouched when ``faces`` is empty
        or ``None``.
    """
    if not faces:
        return frame

    # Create an overlay for transparency
    overlay = frame.copy()

    recog_count = 0
    for face in faces:
        # Get face ID if available
        face_id = face.get('id', 'Unknown')
        recognized = face_id != 'Unknown'
        if recognized:
            recog_count += 1
            color = (0, 200, 255)  # Orange for recognized faces
        else:
            color = (0, 255, 0)  # Green for detected faces

        # Filled circle goes on the overlay so it ends up translucent.
        cv2.circle(overlay,
                   (face['center_x'], face['center_y']),
                   face['radius'],
                   color,
                   -1)

        # Draw rectangle around face
        cv2.rectangle(frame, (face['x1'], face['y1']), (face['x2'], face['y2']), color, 2)

        # Display face ID and confidence
        face_conf = face.get('confidence', 0.0)
        id_text = f"ID: {face_id}" if recognized else "Unknown"
        conf_text = f"Conf: {face_conf:.2f}"

        cv2.putText(frame, id_text, (face['x1'], face['y1'] - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        # Confidence label sits 20 below the bottom edge of the box. The
        # box height is resolved first so the offset is always relative to
        # y1, whether or not the dict carries an explicit 'height'.
        box_height = face.get('height', face['y2'] - face['y1'])
        cv2.putText(frame, conf_text, (face['x1'], face['y1'] + box_height + 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    # Blend the overlay with the original frame for transparency
    alpha = 0.3  # Transparency factor
    cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame)

    # Add indicator text if faces detected
    cv2.putText(frame, f"Faces: {len(faces)}", (10, 70),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    # Display number of recognized faces
    if recog_count > 0:
        cv2.putText(frame, f"Recognized: {recog_count}/{len(faces)}", (10, 110),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

    return frame
403+
627404

628405
def _capture_loop(self):
629406
"""Main capture loop for camera frames"""
@@ -752,7 +529,8 @@ def _process_loop(self):
752529

753530
if should_detect:
754531
detection_start = time.time()
755-
faces = self.detect_faces_dnn(frame)
532+
faces = self.face_detector.detect_faces(frame)
533+
faces = self.face_detector.smooth_detections(faces)
756534
detection_time = time.time() - detection_start
757535

758536
# If detection took too long, increase skip frames
@@ -779,7 +557,7 @@ def _process_loop(self):
779557

780558
# Draw faces if available
781559
if self.current_faces:
782-
frame = self.draw_face_circles(frame, self.current_faces)
560+
frame = self.face_detector.draw_debug_overlay(frame, self.current_faces)
783561

784562
# Update FPS counter
785563
frame_count += 1

0 commit comments

Comments
 (0)