|
15 | 15 | from cmapy import cmap |
16 | 16 | from matplotlib import colormaps |
17 | 17 | import copy |
| 18 | +from pathlib import Path |
| 19 | +import urllib.request |
| 20 | + |
| 21 | +import mediapipe as mp |
| 22 | +from mediapipe.tasks import python |
| 23 | +from mediapipe.tasks.python import vision |
| 24 | + |
| 25 | + |
class ImageSegmenter:
    """Produce a soft per-pixel mask for video frames with MediaPipe Tasks.

    Each frame is downscaled to a small working resolution before inference
    (keeping the aspect ratio of the original frame), and the resulting mask
    is scaled back up to the original frame size and blurred.

    The instance owns a native MediaPipe segmenter; release it with
    ``close()`` or by using the instance as a context manager.
    """

    # Model file looked up relative to the current working directory.
    DEFAULT_MODEL_PATH = 'selfie_segmenter_landscape.tflite'
    # Width, in pixels, of the downscaled frame handed to the model.
    DEFAULT_TARGET_WIDTH = 256

    def __init__(self, width, height, model_path=DEFAULT_MODEL_PATH,
                 target_width=DEFAULT_TARGET_WIDTH):
        """Load the segmentation model and precompute the working size.

        Args:
            width: Width in pixels of the frames passed to ``segment``.
            height: Height in pixels of the frames passed to ``segment``.
            model_path: Path of the ``.tflite`` model asset to load.
            target_width: Width the frame is downscaled to before inference;
                the working height follows from the frame's aspect ratio.

        Raises:
            ValueError: If ``width``, ``height`` or ``target_width`` is not
                positive.
            FileNotFoundError: If ``model_path`` does not point to a file.
        """
        # Assigned first so close() stays safe if anything below raises.
        self.segmenter = None

        self.orig_w, self.orig_h = width, height
        self.target_w = target_width
        self.target_h = self._scaled_height(width, height, target_width)

        # Fail early with a clear message instead of an error raised from
        # inside the model loader.
        if not Path(model_path).is_file():
            raise FileNotFoundError(
                f"Segmentation model not found: {model_path}")

        base_options = python.BaseOptions(model_asset_path=str(model_path))
        options = vision.ImageSegmenterOptions(
            base_options=base_options,
            output_category_mask=True
        )
        self.segmenter = vision.ImageSegmenter.create_from_options(options)

    @staticmethod
    def _scaled_height(width, height, target_width):
        """Return the height matching ``target_width`` at the frame's aspect ratio.

        The result is truncated to an integer and clamped to at least 1 so
        that extremely wide frames never yield an empty working image.
        """
        if width <= 0 or height <= 0 or target_width <= 0:
            raise ValueError(
                "width, height and target_width must be positive, got "
                f"{width}x{height} with target_width={target_width}")
        return max(1, int(target_width * height / width))

    def __enter__(self):
        """Return the segmenter itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the underlying segmenter; exceptions are not suppressed."""
        self.close()
        return False

    def segment(self, frame):
        """Compute a blurred mask for ``frame`` at the original resolution.

        Args:
            frame: Image array to segment. It is converted with
                ``cv2.COLOR_BGR2RGB``, so a BGR frame is expected.

        Returns:
            A ``float32`` array of shape ``(orig_h, orig_w)``. Pixels whose
            category index is 0 start out as 1.0 and all others as 0.0; the
            upscale and blur then soften the edges.
            NOTE(review): category 0 is presumably the foreground for this
            model -- confirm against the model documentation.
        """
        # Inference runs on a small RGB copy of the frame.
        mp_frame = cv2.resize(frame, (self.target_w, self.target_h), interpolation=cv2.INTER_AREA)
        mp_frame = cv2.cvtColor(mp_frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=mp_frame)

        segmentation_result = self.segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask.numpy_view()
        # The comparison plus astype() yields a fresh array, so the mask no
        # longer references the segmentation result's buffer.
        mask = (category_mask == 0).astype(np.float32)

        # Upscale mask back to original resolution
        mask_upscaled = cv2.resize(mask, (self.orig_w, self.orig_h), interpolation=cv2.INTER_LINEAR)

        # Smooth edges to reduce cutting/inconsistency
        mask_upscaled = cv2.GaussianBlur(mask_upscaled, (7, 7), 0)

        return mask_upscaled

    def close(self):
        """Release the underlying segmenter; safe to call more than once."""
        # getattr() covers instances whose __init__ never ran to completion.
        segmenter = getattr(self, 'segmenter', None)
        if segmenter:
            segmenter.close()
        self.segmenter = None
| 61 | + |
18 | 62 |
|
19 | 63 | class RealCam: |
20 | 64 | def __init__(self, src, frame_width, frame_height, frame_rate, codec): |
@@ -126,10 +170,18 @@ def __init__(self, args) -> None: |
126 | 170 | print(self.__dict__) |
127 | 171 | sys.exit(0) |
128 | 172 |
|
129 | | - # slow model loading |
130 | | - import mediapipe as mp |
131 | | - self.classifier = mp.solutions.selfie_segmentation.SelfieSegmentation( |
132 | | - model_selection=args.select_model) |
| 173 | + if not Path('selfie_segmenter_landscape.tflite').exists(): |
| 174 | + filename = "selfie_segmenter_landscape.tflite" |
| 175 | + url = "https://storage.googleapis.com/mediapipe-models/image_segmenter/selfie_segmenter_landscape/float16/latest/selfie_segmenter_landscape.tflite" |
| 176 | + try: |
| 177 | + print("Downloading model...") |
| 178 | + with urllib.request.urlopen(url) as response: |
| 179 | + with open(filename, 'wb') as out_file: |
| 180 | + out_file.write(response.read()) |
| 181 | + except Exception: |
| 182 | + print("Cannot download MediaPipe model") |
| 183 | + |
| 184 | + self.classifier = ImageSegmenter(self.real_width, self.real_height) |
133 | 185 |
|
134 | 186 |
|
135 | 187 | def resize_image(self, img, keep_aspect): |
@@ -264,7 +316,7 @@ def next_frame(): |
264 | 316 | (self.width, self.height)) |
265 | 317 |
|
266 | 318 | def compose_frame(self, frame): |
267 | | - mask = copy.copy(self.classifier.process(frame).segmentation_mask) |
| 319 | + mask = copy.copy(self.classifier.segment(frame)) |
268 | 320 |
|
269 | 321 | if self.threshold < 1: |
270 | 322 | cv2.threshold(mask, self.threshold, 1, cv2.THRESH_BINARY, dst=mask) |
@@ -800,6 +852,7 @@ def sigint_handler(cam, signal, frame): |
800 | 852 |
|
def sigquit_handler(cam, signal, frame):
    """Release the segmenter held by ``cam`` and terminate the process.

    Args:
        cam: Object expected to expose the segmenter as ``cam.classifier``.
        signal: Unused here; presumably the signal number passed by the
            signal machinery.
        frame: Unused here; presumably the interrupted stack frame.

    Raises:
        SystemExit: Always, with exit status 0.
    """
    print("\nKilling fake cam process")
    # The classifier may be absent if the handler fires before it was built.
    classifier = getattr(cam, "classifier", None)
    try:
        if classifier is not None:
            classifier.close()
    except Exception as exc:
        # A cleanup failure must not stop the process from exiting.
        print(f"Failed to release segmenter: {exc}")
    sys.exit(0)
804 | 857 |
|
805 | 858 |
|
|
0 commit comments