Skip to content

Commit ee0ee5c

Browse files
authored
Add a newer implementation of MediaPipe (#221)
1 parent 305d4c1 commit ee0ee5c

File tree

1 file changed

+58
-5
lines changed

1 file changed

+58
-5
lines changed

lfbw/lfbw.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,50 @@
1515
from cmapy import cmap
1616
from matplotlib import colormaps
1717
import copy
18+
from pathlib import Path
19+
import urllib.request
20+
21+
import mediapipe as mp
22+
from mediapipe.tasks import python
23+
from mediapipe.tasks.python import vision
24+
25+
26+
class ImageSegmenter:
    """Wrap a MediaPipe Tasks image segmenter and return full-size float masks.

    Frames are downscaled to a fixed working width of 256 pixels (height
    follows the original aspect ratio) before inference, and the resulting
    mask is scaled back up to the original frame size.

    The instance owns a native segmenter object; release it with ``close()``
    or use the instance as a context manager.
    """

    def __init__(self, width, height,
                 model_path='selfie_segmenter_landscape.tflite'):
        """Create the segmenter for frames of ``width`` x ``height`` pixels.

        Args:
            width: Width in pixels of the frames passed to ``segment``.
            height: Height in pixels of the frames passed to ``segment``.
            model_path: Path to the ``.tflite`` model file. The default is
                resolved relative to the current working directory.

        Raises:
            ValueError: If ``width`` or ``height`` is not positive.
        """
        if width <= 0 or height <= 0:
            raise ValueError(
                f"Frame dimensions must be positive, got {width}x{height}")

        # Assigned before anything that can fail so that close() is always
        # safe to call, even when model creation raises.
        self.segmenter = None

        self.orig_w, self.orig_h = width, height
        self.target_w = 256
        # Clamp to one pixel: very wide frames would otherwise round the
        # working height down to 0, which is not a valid resize target.
        self.target_h = max(1, int(256 * self.orig_h / self.orig_w))

        base_options = python.BaseOptions(model_asset_path=str(model_path))
        options = vision.ImageSegmenterOptions(
            base_options=base_options,
            output_category_mask=True
        )
        self.segmenter = vision.ImageSegmenter.create_from_options(options)

    def __enter__(self):
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the segmenter on leaving the ``with`` block."""
        self.close()

    def segment(self, frame):
        """Segment one frame and return a mask at the original resolution.

        Args:
            frame: Image array in BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB`` before inference).

        Returns:
            A ``float32`` array of shape ``(orig_h, orig_w)``. Pixels whose
            category index is 0 map to 1.0 and all others to 0.0 before
            upscaling; interpolation and blurring then produce intermediate
            values along the edges.
            NOTE(review): category 0 is presumably the person class for this
            model -- confirm against the model card.

        Raises:
            RuntimeError: If the segmenter has already been closed.
        """
        if self.segmenter is None:
            raise RuntimeError("ImageSegmenter has been closed")

        mp_frame = cv2.resize(frame, (self.target_w, self.target_h), interpolation=cv2.INTER_AREA)
        mp_frame = cv2.cvtColor(mp_frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mp_frame)

        segmentation_result = self.segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask.numpy_view()
        # The comparison allocates a new array, so the returned mask does not
        # alias the buffer exposed by numpy_view().
        mask = (category_mask == 0).astype(np.float32)

        # Upscale mask back to original resolution
        mask_upscaled = cv2.resize(mask, (self.orig_w, self.orig_h), interpolation=cv2.INTER_LINEAR)

        # Smooth edges to reduce cutting/inconsistency
        mask_upscaled = cv2.GaussianBlur(mask_upscaled, (7, 7), 0)

        return mask_upscaled

    def close(self):
        """Release the underlying segmenter. Safe to call more than once."""
        segmenter = getattr(self, 'segmenter', None)
        # Clear the reference first so a failing close() is not retried.
        self.segmenter = None
        if segmenter is not None:
            segmenter.close()
61+
1862

1963
class RealCam:
2064
def __init__(self, src, frame_width, frame_height, frame_rate, codec):
@@ -126,10 +170,18 @@ def __init__(self, args) -> None:
126170
print(self.__dict__)
127171
sys.exit(0)
128172

129-
# slow model loading
130-
import mediapipe as mp
131-
self.classifier = mp.solutions.selfie_segmentation.SelfieSegmentation(
132-
model_selection=args.select_model)
173+
if not Path('selfie_segmenter_landscape.tflite').exists():
174+
filename = "selfie_segmenter_landscape.tflite"
175+
url = "https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_segmenter_landscape/float16/latest/selfie_segmenter_landscape.tflite"
176+
try:
177+
print("Downloading model...")
178+
with urllib.request.urlopen(url) as response:
179+
with open(filename, 'wb') as out_file:
180+
out_file.write(response.read())
181+
except Exception:
182+
print("Cannot download MediaPipe model")
183+
184+
self.classifier = ImageSegmenter(self.real_width, self.real_height)
133185

134186

135187
def resize_image(self, img, keep_aspect):
@@ -264,7 +316,7 @@ def next_frame():
264316
(self.width, self.height))
265317

266318
def compose_frame(self, frame):
267-
mask = copy.copy(self.classifier.process(frame).segmentation_mask)
319+
mask = copy.copy(self.classifier.segment(frame))
268320

269321
if self.threshold < 1:
270322
cv2.threshold(mask, self.threshold, 1, cv2.THRESH_BINARY, dst=mask)
@@ -800,6 +852,7 @@ def sigint_handler(cam, signal, frame):
800852

801853
def sigquit_handler(cam, signal, frame):
    """Signal handler that releases the segmenter and terminates the process.

    Args:
        cam: Object whose ``classifier`` attribute exposes ``close()``.
        signal: Signal number supplied by the ``signal`` module (unused).
        frame: Current stack frame supplied by the ``signal`` module (unused).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    print("\nKilling fake cam process")
    try:
        cam.classifier.close()
    except Exception as exc:
        # Cleanup is best-effort: a failing or not-yet-created classifier
        # must not prevent the process from terminating.
        print(f"Could not release segmenter: {exc}")
    finally:
        sys.exit(0)
804857

805858

0 commit comments

Comments
 (0)