diff --git a/detection/blazeface_paddle/test_blazeface.py b/detection/blazeface_paddle/test_blazeface.py index fa8a8f103..3c72f3062 100644 --- a/detection/blazeface_paddle/test_blazeface.py +++ b/detection/blazeface_paddle/test_blazeface.py @@ -17,13 +17,11 @@ import requests import logging import imghdr -import pickle import tarfile from functools import partial import cv2 import numpy as np -from sklearn.metrics.pairwise import cosine_similarity from tqdm import tqdm from prettytable import PrettyTable from PIL import Image, ImageDraw, ImageFont @@ -42,78 +40,49 @@ def str2bool(v): parser = argparse.ArgumentParser(add_help=add_help) - parser.add_argument( - "--det_model", - type=str, - default="BlazeFace", - help="The detection model.") - parser.add_argument( - "--use_gpu", - type=str2bool, - default=True, - help="Whether use GPU to predict. Default by True.") - parser.add_argument( - "--enable_mkldnn", - type=str2bool, - default=True, - help="Whether use MKLDNN to predict, valid only when --use_gpu is False. Default by False." - ) - parser.add_argument( - "--cpu_threads", - type=int, - default=1, - help="The num of threads with CPU, valid only when --use_gpu is False. Default by 1." - ) - parser.add_argument( - "--input", - type=str, - help="The path or directory of image(s) or video to be predicted.") - parser.add_argument( - "--output", type=str, default="./output/", help="The directory of prediction result.") - parser.add_argument( - "--det_thresh", - type=float, - default=0.8, - help="The threshold of detection postprocess. 
Default by 0.8.") + parser.add_argument("--det_model", type=str, default="BlazeFace", help="The detection model.") + parser.add_argument("--use_gpu", type=str2bool, default=True, help="Whether use GPU to predict.") + parser.add_argument("--enable_mkldnn", type=str2bool, default=True, + help="Whether use MKLDNN to predict, valid only when --use_gpu is False.") + parser.add_argument("--cpu_threads", type=int, default=1, + help="The num of threads with CPU, valid only when --use_gpu is False.") + parser.add_argument("--input", type=str, help="The path or directory of image(s) or video to be predicted.") + parser.add_argument("--output", type=str, default="./output/", help="The directory of prediction result.") + parser.add_argument("--det_thresh", type=float, default=0.8, help="The threshold of detection postprocess.") return parser def print_config(args): - args = vars(args) table = PrettyTable(['Param', 'Value']) - for param in args: - table.add_row([param, args[param]]) + for param, value in vars(args).items(): + table.add_row([param, value]) width = len(str(table).split("\n")[0]) - print("{}".format("-" * width)) + print("-" * width) print("PaddleFace".center(width)) print(table) print("Powered by PaddlePaddle!".rjust(width)) - print("{}".format("-" * width)) + print("-" * width) def download_with_progressbar(url, save_path): - """Download from url with progressbar. 
- """ + """Download from url with progressbar.""" if os.path.isfile(save_path): os.remove(save_path) response = requests.get(url, stream=True) total_size_in_bytes = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1 Kibibyte + block_size = 1024 progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) with open(save_path, "wb") as file: for data in response.iter_content(block_size): progress_bar.update(len(data)) file.write(data) progress_bar.close() - if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes or not os.path.isfile( - save_path): - raise Exception( - f"Something went wrong while downloading model/image from {url}") + if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes or not os.path.isfile(save_path): + raise Exception(f"Something went wrong while downloading model/image from {url}") def check_model_file(model): - """Check the model files exist and download and untar when no exist. - """ + """Check the model files exist and download and untar when no exist.""" model_map = { "ArcFace": "arcface_iresnet50_v1.0_infer", "BlazeFace": "blazeface_fpn_ssh_1000e_v1.0_infer", @@ -123,82 +92,77 @@ def check_model_file(model): if os.path.isdir(model): model_file_path = os.path.join(model, "inference.pdmodel") params_file_path = os.path.join(model, "inference.pdiparams") - if not os.path.exists(model_file_path) or not os.path.exists( - params_file_path): + if not os.path.exists(model_file_path) or not os.path.exists(params_file_path): raise Exception( - f"The specifed model directory error. The drectory must include 'inference.pdmodel' and 'inference.pdiparams'." - ) - + f"The specified model directory error. 
The directory must include 'inference.pdmodel' and 'inference.pdiparams'.") elif model in model_map: - storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, - model) + storage_directory = partial(os.path.join, BASE_INFERENCE_MODEL_DIR, model) url = BASE_DOWNLOAD_URL.format(model_map[model]) - tar_file_name_list = [ - "inference.pdiparams", "inference.pdiparams.info", - "inference.pdmodel" - ] model_file_path = storage_directory("inference.pdmodel") params_file_path = storage_directory("inference.pdiparams") - if not os.path.exists(model_file_path) or not os.path.exists( - params_file_path): + + if not os.path.exists(model_file_path) or not os.path.exists(params_file_path): tmp_path = storage_directory(url.split("/")[-1]) logging.info(f"Download {url} to {tmp_path}") os.makedirs(storage_directory(), exist_ok=True) download_with_progressbar(url, tmp_path) + + tar_file_name_list = ["inference.pdiparams", "inference.pdiparams.info", "inference.pdmodel"] with tarfile.open(tmp_path, "r") as tarObj: for member in tarObj.getmembers(): - filename = None for tar_file_name in tar_file_name_list: if tar_file_name in member.name: - filename = tar_file_name - if filename is None: - continue - file = tarObj.extractfile(member) - with open(storage_directory(filename), "wb") as f: - f.write(file.read()) + file = tarObj.extractfile(member) + with open(storage_directory(tar_file_name), "wb") as f: + f.write(file.read()) + break os.remove(tmp_path) - if not os.path.exists(model_file_path) or not os.path.exists( - params_file_path): - raise Exception( - f"Something went wrong while downloading and unzip the model[{model}] files!" - ) + + if not os.path.exists(model_file_path) or not os.path.exists(params_file_path): + raise Exception(f"Something went wrong while downloading and unzip the model[{model}] files!") else: raise Exception( - f"The specifed model name error. Support 'BlazeFace' for detection. 
And support local directory that include model files ('inference.pdmodel' and 'inference.pdiparams')." - ) + f"The specified model name error. Support 'BlazeFace' for detection. " + f"And support local directory that include model files ('inference.pdmodel' and 'inference.pdiparams').") return model_file_path, params_file_path def normalize_image(img, scale=None, mean=None, std=None, order='chw'): + """Optimized image normalization with vectorized operations.""" if isinstance(scale, str): scale = eval(scale) scale = np.float32(scale if scale is not None else 1.0 / 255.0) - mean = mean if mean is not None else [0.485, 0.456, 0.406] - std = std if std is not None else [0.229, 0.224, 0.225] + mean = np.array(mean if mean is not None else [0.485, 0.456, 0.406], dtype=np.float32) + std = np.array(std if std is not None else [0.229, 0.224, 0.225], dtype=np.float32) - shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) - mean = np.array(mean).reshape(shape).astype('float32') - std = np.array(std).reshape(shape).astype('float32') + if order == 'chw': + mean = mean.reshape(3, 1, 1) + std = std.reshape(3, 1, 1) + else: + mean = mean.reshape(1, 1, 3) + std = std.reshape(1, 1, 3) if isinstance(img, Image.Image): img = np.array(img) assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage" - return (img.astype('float32') * scale - mean) / std + return (img.astype(np.float32) * scale - mean) / std def to_CHW_image(img): + """Convert HWC image to CHW format.""" if isinstance(img, Image.Image): img = np.array(img) return img.transpose((2, 0, 1)) class ColorMap(object): + """Optimized color map with precomputed colors.""" def __init__(self, num): super().__init__() - self.get_color_map_list(num) + self.color_list = self._get_color_map_list(num) self.color_map = {} self.ptr = 0 @@ -206,63 +170,61 @@ def __getitem__(self, key): return self.color_map[key] def update(self, keys): + """Update color map with new keys.""" for key in keys: if key not in self.color_map: i = 
self.ptr % len(self.color_list) self.color_map[key] = self.color_list[i] self.ptr += 1 - def get_color_map_list(self, num_classes): - color_map = num_classes * [0, 0, 0] - for i in range(0, num_classes): - j = 0 + @staticmethod + def _get_color_map_list(num_classes): + """Generate color map using bit manipulation.""" + color_map = np.zeros((num_classes, 3), dtype=np.int32) + for i in range(num_classes): lab = i + j = 0 while lab: - color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j)) - color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j)) - color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j)) + color_map[i, 0] |= ((lab >> 0) & 1) << (7 - j) + color_map[i, 1] |= ((lab >> 1) & 1) << (7 - j) + color_map[i, 2] |= ((lab >> 2) & 1) << (7 - j) j += 1 lab >>= 3 - self.color_list = [ - color_map[i:i + 3] for i in range(0, len(color_map), 3) - ] + return [tuple(color_map[i]) for i in range(num_classes)] class ImageReader(object): + """Optimized image reader with better error handling.""" + SUPPORTED_TYPES = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'} + def __init__(self, inputs): super().__init__() self.idx = 0 + if isinstance(inputs, np.ndarray): self.image_list = [inputs] else: - imgtype_list = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'} - self.image_list = [] - if os.path.isfile(inputs): - if imghdr.what(inputs) not in imgtype_list: - raise Exception( - f"Error type of input path, only support: {imgtype_list}" - ) - self.image_list.append(inputs) - elif os.path.isdir(inputs): - tmp_file_list = os.listdir(inputs) - warn_tag = False - for file_name in tmp_file_list: - file_path = os.path.join(inputs, file_name) - if not os.path.isfile(file_path): - warn_tag = True - continue - if imghdr.what(file_path) in imgtype_list: - self.image_list.append(file_path) - else: - warn_tag = True - if warn_tag: - logging.warning( - f"The directory of input contine directory or not supported file type, only support: {imgtype_list}" - ) - else: - raise Exception( - f"The file of 
input path not exist! Please check input: {inputs}" - ) + self.image_list = self._collect_images(inputs) + + def _collect_images(self, inputs): + """Collect valid image paths.""" + image_list = [] + + if os.path.isfile(inputs): + if imghdr.what(inputs) not in self.SUPPORTED_TYPES: + raise Exception(f"Error type of input path, only support: {self.SUPPORTED_TYPES}") + image_list.append(inputs) + elif os.path.isdir(inputs): + for file_name in os.listdir(inputs): + file_path = os.path.join(inputs, file_name) + if os.path.isfile(file_path) and imghdr.what(file_path) in self.SUPPORTED_TYPES: + image_list.append(file_path) + if not image_list: + logging.warning(f"No supported images found in directory: {inputs}") + else: + raise Exception(f"The file of input path not exist! Please check input: {inputs}") + + return image_list def __iter__(self): return self @@ -272,49 +234,49 @@ def __next__(self): raise StopIteration data = self.image_list[self.idx] + self.idx += 1 + if isinstance(data, np.ndarray): - self.idx += 1 return data, "tmp.png" - path = data - _, file_name = os.path.split(path) - img = cv2.imread(path) + + img = cv2.imread(data) if img is None: - logging.warning(f"Error in reading image: {path}! Ignored.") - self.idx += 1 + logging.warning(f"Error in reading image: {data}! 
Skipping.") return self.__next__() - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - self.idx += 1 - return img, file_name + + _, file_name = os.path.split(data) + return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), file_name def __len__(self): return len(self.image_list) class VideoReader(object): + """Optimized video reader.""" + SUPPORTED_TYPES = {"mp4"} + def __init__(self, inputs): super().__init__() - videotype_list = {"mp4"} - if os.path.splitext(inputs)[-1][1:] not in videotype_list: - raise Exception( - f"The input file is not supported, only support: {videotype_list}" - ) + ext = os.path.splitext(inputs)[-1][1:] + if ext not in self.SUPPORTED_TYPES: + raise Exception(f"The input file is not supported, only support: {self.SUPPORTED_TYPES}") if not os.path.isfile(inputs): - raise Exception( - f"The file of input path not exist! Please check input: {inputs}" - ) + raise Exception(f"The file of input path not exist! Please check input: {inputs}") + self.capture = cv2.VideoCapture(inputs) self.file_name = os.path.split(inputs)[-1] def get_info(self): - info = {} + """Get video information.""" width = int(self.capture.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(self.capture.get(cv2.CAP_PROP_FRAME_HEIGHT)) - fourcc = cv2.VideoWriter_fourcc(* 'mp4v') - info["file_name"] = self.file_name - info["fps"] = 30 - info["shape"] = (width, height) - info["fourcc"] = cv2.VideoWriter_fourcc(* 'mp4v') - return info + + return { + "file_name": self.file_name, + "fps": 30, + "shape": (width, height), + "fourcc": cv2.VideoWriter_fourcc(*'mp4v') + } def __iter__(self): return self @@ -327,36 +289,34 @@ def __next__(self): class ImageWriter(object): + """Optimized image writer.""" def __init__(self, output_dir): super().__init__() if output_dir is None: - raise Exception( - "Please specify the directory of saving prediction results by --output." 
- ) - if not os.path.exists(output_dir): - os.makedirs(output_dir) + raise Exception("Please specify the directory of saving prediction results by --output.") + os.makedirs(output_dir, exist_ok=True) self.output_dir = output_dir def write(self, image, file_name): + """Write image to disk.""" path = os.path.join(self.output_dir, file_name) cv2.imwrite(path, cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) class VideoWriter(object): + """Optimized video writer.""" def __init__(self, output_dir, video_info): super().__init__() if output_dir is None: - raise Exception( - "Please specify the directory of saving prediction results by --output." - ) - if not os.path.exists(output_dir): - os.makedirs(output_dir) + raise Exception("Please specify the directory of saving prediction results by --output.") + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, video_info["file_name"]) - fourcc = cv2.VideoWriter_fourcc(* 'mp4v') self.writer = cv2.VideoWriter(output_path, video_info["fourcc"], - video_info["fps"], video_info["shape"]) + video_info["fps"], video_info["shape"]) def write(self, frame, file_name): + """Write frame to video.""" self.writer.write(frame) def __del__(self): @@ -365,38 +325,39 @@ def __del__(self): class BasePredictor(object): + """Base predictor with optimized initialization.""" def __init__(self, predictor_config): super().__init__() self.predictor_config = predictor_config - self.predictor, self.input_names, self.output_names = self.load_predictor( + self.predictor, self.input_names, self.output_names = self._load_predictor( predictor_config["model_file"], predictor_config["params_file"]) - def load_predictor(self, model_file, params_file): + def _load_predictor(self, model_file, params_file): + """Load predictor with optimized configuration.""" config = Config(model_file, params_file) + if self.predictor_config["use_gpu"]: config.enable_use_gpu(200, 0) config.switch_ir_optim(True) else: config.disable_gpu() - 
config.set_cpu_math_library_num_threads(self.predictor_config[ - "cpu_threads"]) + config.set_cpu_math_library_num_threads(self.predictor_config["cpu_threads"]) if self.predictor_config["enable_mkldnn"]: try: - # cache 10 different shapes for mkldnn to avoid memory leak config.set_mkldnn_cache_capacity(10) config.enable_mkldnn() - except Exception as e: - logging.error( - "The current environment does not support `mkldnn`, so disable mkldnn." - ) + except Exception: + logging.error("The current environment does not support `mkldnn`, so disable mkldnn.") + config.disable_glog_info() config.enable_memory_optim() - # use zero copy config.switch_use_feed_fetch_ops(False) + predictor = create_predictor(config) input_names = predictor.get_input_names() output_names = predictor.get_output_names() + return predictor, input_names, output_names def preprocess(self): @@ -410,53 +371,65 @@ def predict(self, img): class Detector(BasePredictor): + """Optimized detector with cached values.""" def __init__(self, det_config, predictor_config): super().__init__(predictor_config) self.det_config = det_config - self.target_size = self.det_config["target_size"] - self.thresh = self.det_config["thresh"] + self.target_size = det_config["target_size"] + self.thresh = det_config["thresh"] + + # Cache normalization parameters + self.norm_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 3) + self.norm_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 3) + self.scale = 1.0 / 255.0 + + # Pre-compute target dimensions + self.resize_h, self.resize_w = self.target_size def preprocess(self, img): - resize_h, resize_w = self.target_size + """Optimized preprocessing with vectorized operations.""" img_shape = img.shape - img_scale_x = resize_w / img_shape[1] - img_scale_y = resize_h / img_shape[0] - img = cv2.resize( - img, None, None, fx=img_scale_x, fy=img_scale_y, interpolation=1) - img = normalize_image( - img, - scale=1. 
/ 255., - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225], - order='hwc') - img_info = {} - img_info["im_shape"] = np.array( - img.shape[:2], dtype=np.float32)[np.newaxis, :] - img_info["scale_factor"] = np.array( - [img_scale_y, img_scale_x], dtype=np.float32)[np.newaxis, :] - - img = img.transpose((2, 0, 1)).copy() - img_info["image"] = img[np.newaxis, :, :, :] + img_scale_x = self.resize_w / img_shape[1] + img_scale_y = self.resize_h / img_shape[0] + + # Single resize operation + img = cv2.resize(img, (self.resize_w, self.resize_h), interpolation=cv2.INTER_LINEAR) + + # Vectorized normalization + img = (img.astype(np.float32) * self.scale - self.norm_mean) / self.norm_std + + # Prepare output dictionary + img_info = { + "im_shape": np.array([[img.shape[0], img.shape[1]]], dtype=np.float32), + "scale_factor": np.array([[img_scale_y, img_scale_x]], dtype=np.float32), + "image": img.transpose((2, 0, 1))[np.newaxis, :, :, :] + } + return img_info def postprocess(self, np_boxes): - expect_boxes = (np_boxes[:, 1] > self.thresh) & (np_boxes[:, 0] > -1) - return np_boxes[expect_boxes, :] + """Optimized postprocessing with vectorized filtering.""" + mask = (np_boxes[:, 1] > self.thresh) & (np_boxes[:, 0] > -1) + return np_boxes[mask] def predict(self, img): + """Run detection inference.""" inputs = self.preprocess(img) + for input_name in self.input_names: input_tensor = self.predictor.get_input_handle(input_name) input_tensor.copy_from_cpu(inputs[input_name]) + self.predictor.run() + output_tensor = self.predictor.get_output_handle(self.output_names[0]) np_boxes = output_tensor.copy_to_cpu() - # boxes_num = self.detector.get_output_handle(self.detector_output_names[1]) - # np_boxes_num = boxes_num.copy_to_cpu() - box_list = self.postprocess(np_boxes) - return box_list + + return self.postprocess(np_boxes) + class FaceDetector(object): + """Optimized face detector with better resource management.""" def __init__(self, args, print_info=True): super().__init__() if 
print_info: @@ -467,25 +440,36 @@ def __init__(self, args, print_info=True): "SourceHanSansCN-Medium.otf") self.args = args + # Initialize detector + model_file_path, params_file_path = check_model_file(args.det_model) + predictor_config = { "use_gpu": args.use_gpu, "enable_mkldnn": args.enable_mkldnn, - "cpu_threads": args.cpu_threads + "cpu_threads": args.cpu_threads, + "model_file": model_file_path, + "params_file": params_file_path } - - model_file_path, params_file_path = check_model_file( - args.det_model) + det_config = {"thresh": args.det_thresh, "target_size": [640, 640]} - predictor_config["model_file"] = model_file_path - predictor_config["params_file"] = params_file_path self.det_predictor = Detector(det_config, predictor_config) self.color_map = ColorMap(100) + + # Cache font objects for different sizes + self.font_cache = {} def preprocess(self, img): - img = img.astype(np.float32, copy=False) - return img + """Lightweight preprocessing - just ensure float32.""" + return img.astype(np.float32, copy=False) + + def _get_font(self, size): + """Get cached font or create new one.""" + if size not in self.font_cache: + self.font_cache[size] = ImageFont.truetype(self.font_path, size) + return self.font_cache[size] def draw(self, img, box_list, labels): + """Optimized drawing with cached fonts and colors.""" self.color_map.update(labels) im = Image.fromarray(img) draw = ImageDraw.Draw(im) @@ -493,43 +477,40 @@ def draw(self, img, box_list, labels): for i, dt in enumerate(box_list): bbox, score = dt[2:], dt[1] label = labels[i] - color = tuple(self.color_map[label]) + color = self.color_map[label] xmin, ymin, xmax, ymax = bbox + # Get appropriate font size font_size = max(int((xmax - xmin) // 6), 10) - font = ImageFont.truetype(self.font_path, font_size) + font = self._get_font(font_size) - text = "{} {:.4f}".format(label, score) + # Prepare text + text = f"{label} {score:.4f}" th = sum(font.getmetrics()) tw = font.getsize(text)[0] start_y = max(0, ymin - th) 
- draw.rectangle( - [(xmin, start_y), (xmin + tw + 1, start_y + th)], fill=color) - draw.text( - (xmin + 1, start_y), - text, - fill=(255, 255, 255), - font=font, - anchor="la") - draw.rectangle( - [(xmin, ymin), (xmax, ymax)], width=2, outline=color) + # Draw text background and text + draw.rectangle([(xmin, start_y), (xmin + tw + 1, start_y + th)], fill=color) + draw.text((xmin + 1, start_y), text, fill=(255, 255, 255), font=font, anchor="la") + + # Draw bounding box + draw.rectangle([(xmin, ymin), (xmax, ymax)], width=2, outline=color) + return np.array(im) def predict_np_img(self, img): + """Predict on numpy image.""" input_img = self.preprocess(img) - box_list = None - np_feature = None - if hasattr(self, "det_predictor"): - box_list = self.det_predictor.predict(input_img) - return box_list, np_feature + box_list = self.det_predictor.predict(input_img) + return box_list, None def init_reader_writer(self, input_data): + """Initialize appropriate reader and writer.""" if isinstance(input_data, np.ndarray): self.input_reader = ImageReader(input_data) - if hasattr(self, "det_predictor"): - self.output_writer = ImageWriter(self.args.output) + self.output_writer = ImageWriter(self.args.output) elif isinstance(input_data, str): if input_data.endswith('mp4'): self.input_reader = VideoReader(input_data) @@ -537,19 +518,18 @@ def init_reader_writer(self, input_data): self.output_writer = VideoWriter(self.args.output, info) else: self.input_reader = ImageReader(input_data) - if hasattr(self, "det_predictor"): - self.output_writer = ImageWriter(self.args.output) + self.output_writer = ImageWriter(self.args.output) else: raise Exception( - f"The input data error. Only support path of image or video(.mp4) and dirctory that include images." - ) + f"The input data error. Only support path of image or video(.mp4) and directory that include images.") def predict(self, input_data, print_info=False): """Predict input_data. 
Args: - input_data (str | NumPy.array): The path of image, or the derectory including images, or the image data in NumPy.array format. - print_info (bool, optional): Wheather to print the prediction results. Defaults to False. + input_data (str | NumPy.array): The path of image, or the directory including images, + or the image data in NumPy.array format. + print_info (bool, optional): Whether to print the prediction results. Defaults to False. Yields: dict: { @@ -559,33 +539,38 @@ } """ self.init_reader_writer(input_data) + for img, file_name in self.input_reader: if img is None: logging.warning(f"Error in reading img {file_name}! Ignored.") continue + box_list, np_feature = self.predict_np_img(img) labels = ["face"] * len(box_list) - if box_list is not None: + + if box_list is not None and len(box_list) > 0: result = self.draw(img, box_list, labels=labels) self.output_writer.write(result, file_name) + if print_info: logging.info(f"File: {file_name}, predict label(s): {labels}") + yield { "box_list": box_list, "features": np_feature, "labels": labels } - logging.info(f"Predict complete!") + + logging.info("Predict complete!") -# for CLI def main(args=None): + """CLI entry point.""" logging.basicConfig(level=logging.INFO) - args = parser().parse_args() + if args is None: + args = parser().parse_args() predictor = FaceDetector(args) - res = predictor.predict(args.input, print_info=True) - for _ in res: + + for _ in predictor.predict(args.input, print_info=True): pass