Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 21 additions & 45 deletions examples/pose_landmarker/raspberry_pi/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import mediapipe as mp
import numpy as np

from picamera2 import Picamera2
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.framework.formats import landmark_pb2
Expand All @@ -37,32 +38,18 @@

def run(model: str, num_poses: int,
min_pose_detection_confidence: float,
min_pose_presence_confidence: float, min_tracking_confidence: float,
min_pose_presence_confidence: float,
min_tracking_confidence: float,
output_segmentation_masks: bool,
camera_id: int, width: int, height: int) -> None:
"""Continuously run inference on images acquired from the camera.

Args:
model: Name of the pose landmarker model bundle.
num_poses: Max number of poses that can be detected by the landmarker.
min_pose_detection_confidence: The minimum confidence score for pose
detection to be considered successful.
min_pose_presence_confidence: The minimum confidence score of pose
presence score in the pose landmark detection.
min_tracking_confidence: The minimum confidence score for the pose
tracking to be considered successful.
output_segmentation_masks: Choose whether to visualize the segmentation
mask or not.
camera_id: The camera id to be passed to OpenCV.
width: The width of the frame captured from the camera.
height: The height of the frame captured from the camera.
"""

# Start capturing video input from the camera
cap = cv2.VideoCapture(camera_id)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

width: int, height: int) -> None:

picam2 = Picamera2()
picam2.preview_configuration.main.size = (width, height)
picam2.preview_configuration.main.format = "RGB888"
picam2.preview_configuration.align()
picam2.configure("preview")
picam2.start()

# Visualization parameters
row_size = 50 # pixels
left_margin = 24 # pixels
Expand Down Expand Up @@ -99,17 +86,14 @@ def save_result(result: vision.PoseLandmarkerResult,
detector = vision.PoseLandmarker.create_from_options(options)

# Continuously capture images from the camera and run inference
while cap.isOpened():
success, image = cap.read()
if not success:
sys.exit(
'ERROR: Unable to read from webcam. Please verify your webcam settings.'
)

image = cv2.flip(image, 1)
while True:
frame = picam2.capture_array()


frame = cv2.flip(frame, 1)

# Convert the image from BGR to RGB as required by the TFLite model.
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
rgb_image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_image)

# Run pose landmarker using the model.
Expand All @@ -118,7 +102,7 @@ def save_result(result: vision.PoseLandmarkerResult,
# Show the FPS
fps_text = 'FPS = {:.1f}'.format(FPS)
text_location = (left_margin, row_size)
current_frame = image
current_frame = frame
cv2.putText(current_frame, fps_text, text_location,
cv2.FONT_HERSHEY_DUPLEX,
font_size, text_color, font_thickness, cv2.LINE_AA)
Expand All @@ -142,7 +126,7 @@ def save_result(result: vision.PoseLandmarkerResult,
if (output_segmentation_masks and DETECTION_RESULT):
if DETECTION_RESULT.segmentation_masks is not None:
segmentation_mask = DETECTION_RESULT.segmentation_masks[0].numpy_view()
mask_image = np.zeros(image.shape, dtype=np.uint8)
mask_image = np.zeros(frame.shape, dtype=np.uint8)
mask_image[:] = mask_color
condition = np.stack((segmentation_mask,) * 3, axis=-1) > 0.1
visualized_mask = np.where(condition, mask_image, current_frame)
Expand All @@ -157,7 +141,6 @@ def save_result(result: vision.PoseLandmarkerResult,
break

detector.close()
cap.release()
cv2.destroyAllWindows()


Expand Down Expand Up @@ -198,12 +181,6 @@ def main():
'mask.',
required=False,
action='store_true')
# Finding the camera ID can be very reliant on platform-dependent methods.
# One common approach is to use the fact that camera IDs are usually indexed sequentially by the OS, starting from 0.
# Here, we use OpenCV and create a VideoCapture object for each potential ID with 'cap = cv2.VideoCapture(i)'.
# If 'cap' is None or not 'cap.isOpened()', it indicates the camera ID is not available.
parser.add_argument(
'--cameraId', help='Id of camera.', required=False, default=0)
parser.add_argument(
'--frameWidth',
help='Width of frame to capture from camera.',
Expand All @@ -218,8 +195,7 @@ def main():

run(args.model, int(args.numPoses), args.minPoseDetectionConfidence,
args.minPosePresenceConfidence, args.minTrackingConfidence,
args.outputSegmentationMasks,
int(args.cameraId), args.frameWidth, args.frameHeight)
args.outputSegmentationMasks, args.frameWidth, args.frameHeight)


if __name__ == '__main__':
Expand Down