44import cv2
55import numpy as np
66from lada .lib import Mask , Box , Image , Detection , Detections , DETECTION_CLASSES
7- from lada .lib import mask_utils
7+ from lada .lib import box_utils
88from lada .centerface .centerface import CenterFace
99
10- def scale_box (box , mask_scale = 1.0 ) -> Box :
11- s = mask_scale - 1.0
12- t , l , b , r = box
13- w , h = r - l + 1 , b - t + 1
14- t -= h * s
15- b += h * s
16- l -= w * s
17- r += w * s
18- return int (t ), int (l ), int (b ), int (r )
19-
2010def convert_to_boxes (dets ) -> list [Box ]:
2111 boxes = []
2212 for i , det in enumerate (dets ):
@@ -50,17 +40,15 @@ def create_mask(frame: Image, box: Box) -> Mask:
5040
5141 return mask
5242
53- def get_nsfw_frame (dets : list [Box ], frame , random_extend_masks : bool , mask_scale : float ) -> Detections | None :
43+ def get_nsfw_frame (dets : list [Box ], frame : Image , random_extend_masks : bool ) -> Detections | None :
5444 if len (dets ) == 0 :
5545 return None
5646 detections = []
5747 for box in dets :
58- box = scale_box (box , mask_scale )
59- mask = create_mask (frame , box )
6048
6149 if random_extend_masks :
62- mask = mask_utils . apply_random_mask_extensions ( mask )
63- box = mask_utils . get_box ( mask )
50+ box = box_utils . random_scale_box ( frame , box , scale_range = ( 1.2 , 1.5 ) )
51+ mask = create_mask ( frame , box )
6452
6553 t , l , b , r = box
6654 width , height = r - l + 1 , b - t + 1
@@ -72,14 +60,13 @@ def get_nsfw_frame(dets: list[Box], frame, random_extend_masks: bool, mask_scale
7260 return Detections (frame , detections )
7361
7462class FaceDetector :
75- def __init__ (self , model : CenterFace , random_extend_masks = False , conf = 0.2 , mask_scale = 1.3 ):
63+ def __init__ (self , model : CenterFace , random_extend_masks = False , conf = 0.2 ):
7664 self .model = model
7765 self .random_extend_masks = random_extend_masks
7866 self .conf = conf
79- self .mask_scale = mask_scale
8067
8168 def detect (self , file_path : str ) -> Detections | None :
8269 image = cv2 .imread (file_path , cv2 .IMREAD_COLOR_RGB )
8370 dets , _ = self .model (image , threshold = self .conf )
8471 dets_boxes = convert_to_boxes (dets )
85- return get_nsfw_frame (dets_boxes , cv2 .cvtColor (image , cv2 .COLOR_RGB2BGR ), random_extend_masks = self .random_extend_masks , mask_scale = self . mask_scale )
72+ return get_nsfw_frame (dets_boxes , cv2 .cvtColor (image , cv2 .COLOR_RGB2BGR ), random_extend_masks = self .random_extend_masks )
0 commit comments