diff --git a/paddleocr.py b/paddleocr.py
index bdf367cf1a4..b427600a7fe 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -727,9 +727,10 @@ def ocr(
             - The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
         """
         assert isinstance(img, (np.ndarray, list, str, bytes))
-        if isinstance(img, list) and det == True:
-            logger.error("When input a list of images, det must be false")
-            exit(0)
+        # Support batch processing of a list of images; the original restriction is removed
+        # if isinstance(img, list) and det == True:
+        #     logger.error("When input a list of images, det must be false")
+        #     exit(0)
         if cls == True and self.use_angle_cls == False:
             logger.warning(
                 "Since the angle classifier is not initialized, it will not be used during the forward process"
             )
@@ -753,41 +754,83 @@ def preprocess_image(_image):
                 _image = binarize_img(_image)
             return _image

-        if det and rec:
+        if det and rec:  # default path; note that I have only tested this branch
             ocr_res = []
             for img in imgs:
-                img = preprocess_image(img)
-                dt_boxes, rec_res, _ = self.__call__(img, cls, slice)
-                if not dt_boxes and not rec_res:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
-                ocr_res.append(tmp_res)
+                # Make sure img is a single image (ndarray) rather than a list of images
+                if isinstance(img, list):
+                    # img is a list of images: preprocess them and pass them on as a list
+                    processed_imgs = [preprocess_image(single_img) for single_img in img]
+                    batch_results = self.__call__(processed_imgs, cls, slice)
+                    # batch_results is a list whose elements are (dt_boxes, rec_res, time_dict)
+                    for dt_boxes, rec_res, _ in batch_results:
+                        if not dt_boxes and not rec_res:
+                            ocr_res.append(None)
+                            continue
+                        tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
+                        ocr_res.append(tmp_res)
+                else:
+                    # Single-image processing
+                    img = preprocess_image(img)
+                    dt_boxes, rec_res, _ = self.__call__(img, cls, slice)
+                    if not dt_boxes and not rec_res:
+                        ocr_res.append(None)
+                        continue
+                    tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
+                    ocr_res.append(tmp_res)
             return ocr_res
         elif det and not rec:
             ocr_res = []
             for img in imgs:
-                img = preprocess_image(img)
-                dt_boxes, elapse = self.text_detector(img)
-                if dt_boxes.size == 0:
-                    ocr_res.append(None)
-                    continue
-                tmp_res = [box.tolist() for box in dt_boxes]
-                ocr_res.append(tmp_res)
+                # Make sure img is a single image (ndarray) rather than a list of images
+                if isinstance(img, list):
+                    # img is a list of images: preprocess them first
+                    processed_imgs = [preprocess_image(single_img) for single_img in img]
+                    # In detection-only mode, process each image separately
+                    for processed_img in processed_imgs:
+                        dt_boxes, elapse = self.text_detector(processed_img)
+                        if dt_boxes.size == 0:
+                            ocr_res.append(None)
+                            continue
+                        tmp_res = [box.tolist() for box in dt_boxes]
+                        ocr_res.append(tmp_res)
+                else:
+                    # Single-image processing
+                    img = preprocess_image(img)
+                    dt_boxes, elapse = self.text_detector(img)
+                    if dt_boxes.size == 0:
+                        ocr_res.append(None)
+                        continue
+                    tmp_res = [box.tolist() for box in dt_boxes]
+                    ocr_res.append(tmp_res)
             return ocr_res
         else:
             ocr_res = []
             cls_res = []
             for img in imgs:
-                if not isinstance(img, list):
+                # Make sure img is a single image (ndarray) rather than a list of images
+                if isinstance(img, list):
+                    # img is a list of images: preprocess them first
+                    processed_imgs = [preprocess_image(single_img) for single_img in img]
+                    # In recognition-only mode, process each image separately
+                    for processed_img in processed_imgs:
+                        processed_img = [processed_img]
+                        if self.use_angle_cls and cls:
+                            processed_img, cls_res_tmp, elapse = self.text_classifier(processed_img)
+                            if not rec:
+                                cls_res.append(cls_res_tmp)
+                        rec_res, elapse = self.text_recognizer(processed_img)
+                        ocr_res.append(rec_res)
+                else:
+                    # Single-image processing
                     img = preprocess_image(img)
                     img = [img]
-                if self.use_angle_cls and cls:
-                    img, cls_res_tmp, elapse = self.text_classifier(img)
-                    if not rec:
-                        cls_res.append(cls_res_tmp)
-                rec_res, elapse = self.text_recognizer(img)
-                ocr_res.append(rec_res)
+                    if self.use_angle_cls and cls:
+                        img, cls_res_tmp, elapse = self.text_classifier(img)
+                        if not rec:
+                            cls_res.append(cls_res_tmp)
+                    rec_res, elapse = self.text_recognizer(img)
+                    ocr_res.append(rec_res)
         if not rec:
             return cls_res
         return ocr_res
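
With the list restriction removed, `ocr()` accepts a list of ndarrays even when `det=True`. The following is a minimal usage sketch, not part of the patch: the file names are placeholders, it assumes the list passes through `check_img` unchanged as the patched branches expect, and the result nesting follows the `det and rec` branch above (one entry per input image, each a list of `[box, (text, score)]` or `None`).

    import cv2
    from paddleocr import PaddleOCR

    ocr_engine = PaddleOCR(use_angle_cls=True, lang="en")
    imgs = [cv2.imread("page_1.jpg"), cv2.imread("page_2.jpg")]  # placeholder paths

    # A list of images can now be passed with det=True.
    results = ocr_engine.ocr(imgs, det=True, rec=True, cls=True)
    for page_idx, page in enumerate(results):
        if page is None:
            print(f"page {page_idx}: no text found")
            continue
        for box, (text, score) in page:
            print(page_idx, text, score)
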
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 02bce45b737..67889be41a5 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -227,6 +227,144 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
         dt_boxes = np.array(dt_boxes_new)
         return dt_boxes

+    def predict_batch(self, img_list, batch_size=None):
+        """
+        Run detection on a list of images in batches.
+        Args:
+            img_list: list of images
+            batch_size: batch size; if None, all images are processed in a single batch
+        Returns:
+            batch_dt_boxes: list of detection-box results, one entry per image
+            elapse: total elapsed time
+        """
+        if batch_size is None:
+            batch_size = len(img_list)
+
+        all_dt_boxes = []
+        total_elapse = 0
+
+        # Process the images batch by batch
+        for i in range(0, len(img_list), batch_size):
+            batch_imgs = img_list[i:i+batch_size]
+            batch_dt_boxes, elapse = self._predict_single_batch(batch_imgs)
+            all_dt_boxes.extend(batch_dt_boxes)
+            total_elapse += elapse
+
+        return all_dt_boxes, total_elapse
+
+    def _predict_single_batch(self, img_list):
+        """
+        Process a single batch of images.
+        """
+        batch_size = len(img_list)
+        ori_imgs = [img.copy() for img in img_list]
+
+        st = time.time()
+
+        if self.args.benchmark:
+            self.autolog.times.start()
+
+        # Batch preprocessing
+        batch_data = []
+        batch_shape_lists = []
+
+        for img in img_list:
+            data = {"image": img}
+            data = transform(data, self.preprocess_op)
+            processed_img, shape_list = data
+            if processed_img is not None:
+                batch_data.append(processed_img)
+                batch_shape_lists.append(shape_list)
+
+        if not batch_data:
+            return [None] * batch_size, 0
+
+        # Stack the preprocessed images into a single batch.
+        # batch_imgs = np.array(batch_data)  # Error processing batch ['IMG_20241230_162405.HEIC', 'IMG_20241228_171324.HEIC']: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (2, 3) + inhomogeneous part.
+        # {ndarray: (3, 960, 704)} {ndarray: (3, 704, 960)}
+        # Unify the tensor sizes within the batch by padding to the maximum size
+        if len(batch_data) > 1:
+            # Find the maximum height and width
+            max_h = max(img.shape[1] for img in batch_data)  # shape: (C, H, W)
+            max_w = max(img.shape[2] for img in batch_data)
+
+            # Pad all images to the same size
+            padded_imgs = []
+            for img in batch_data:
+                c, h, w = img.shape
+                # Create the padded image
+                padded_img = np.zeros((c, max_h, max_w), dtype=img.dtype)
+                padded_img[:, :h, :w] = img  # align to the top-left corner
+                padded_imgs.append(padded_img)
+
+            batch_imgs = np.array(padded_imgs)
+        else:
+            batch_imgs = np.array(batch_data)
+
+        batch_shape_lists = np.array(batch_shape_lists)
+
+        if self.args.benchmark:
+            self.autolog.times.stamp()
+
+        # Batch inference
+        if self.use_onnx:
+            input_dict = {}
+            input_dict[self.input_tensor.name] = batch_imgs
+            outputs = self.predictor.run(self.output_tensors, input_dict)
+        else:
+            self.input_tensor.copy_from_cpu(batch_imgs)
+            self.predictor.run()
+            outputs = []
+            for output_tensor in self.output_tensors:
+                output = output_tensor.copy_to_cpu()
+                outputs.append(output)
+            if self.args.benchmark:
+                self.autolog.times.stamp()
+
+        # Per-image post-processing
+        batch_dt_boxes = []
+
+        for batch_idx in range(len(batch_data)):
+            # Slice out the outputs belonging to this image
+            single_outputs = {}
+            if self.det_algorithm == "EAST":
+                single_outputs["f_geo"] = outputs[0][batch_idx:batch_idx+1]
+                single_outputs["f_score"] = outputs[1][batch_idx:batch_idx+1]
+            elif self.det_algorithm == "SAST":
+                single_outputs["f_border"] = outputs[0][batch_idx:batch_idx+1]
+                single_outputs["f_score"] = outputs[1][batch_idx:batch_idx+1]
+                single_outputs["f_tco"] = outputs[2][batch_idx:batch_idx+1]
+                single_outputs["f_tvo"] = outputs[3][batch_idx:batch_idx+1]
+            elif self.det_algorithm in ["DB", "PSE", "DB++"]:
+                single_outputs["maps"] = outputs[0][batch_idx:batch_idx+1]
+            elif self.det_algorithm == "FCE":
+                for i, output in enumerate(outputs):
+                    single_outputs["level_{}".format(i)] = output[batch_idx:batch_idx+1]
+            elif self.det_algorithm == "CT":
+                single_outputs["maps"] = outputs[0][batch_idx:batch_idx+1]
+                single_outputs["score"] = outputs[1][batch_idx:batch_idx+1]
+            else:
+                raise NotImplementedError
+
+            # Post-process this single image
+            single_shape_list = batch_shape_lists[batch_idx:batch_idx+1]
+            post_result = self.postprocess_op(single_outputs, single_shape_list)
+            dt_boxes = post_result[0]["points"]
+
+            if self.args.det_box_type == "poly":
+                dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[batch_idx].shape)
+            else:  # this is the branch taken in my tests
+                dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[batch_idx].shape)
+
+            batch_dt_boxes.append(dt_boxes)
+
+        if self.args.benchmark:
+            self.autolog.times.end(stamp=True)
+
+        et = time.time()
+        return batch_dt_boxes, et - st
+
     def predict(self, img):
         ori_im = img.copy()
         data = {"image": img}
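
The commented-out `np.array(batch_data)` call fails because the DB preprocessing keeps each image's aspect ratio, so the batch can contain tensors such as (3, 960, 704) and (3, 704, 960) that cannot be stacked directly. The patch solves this by zero-padding every tensor to the largest height and width in the batch. Below is a standalone sketch of that idea, not part of the patch; the helper name and shapes are illustrative only.

    import numpy as np

    def pad_to_batch(tensors):
        """Zero-pad a list of (C, H, W) arrays to a common size and stack them."""
        max_h = max(t.shape[1] for t in tensors)
        max_w = max(t.shape[2] for t in tensors)
        batch = np.zeros((len(tensors), tensors[0].shape[0], max_h, max_w), dtype=tensors[0].dtype)
        for i, t in enumerate(tensors):
            c, h, w = t.shape
            batch[i, :, :h, :w] = t  # top-left aligned, as in the patch above
        return batch

    batch = pad_to_batch([np.ones((3, 960, 704), np.float32), np.ones((3, 704, 960), np.float32)])
    print(batch.shape)  # (2, 3, 960, 960)

Note that zero padding may slightly perturb the detector's score map near the padded border, which is one reason the patch keeps the original per-image shape_list for post-processing.
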
@@ -293,7 +431,21 @@ def predict(self, img):
         et = time.time()
         return dt_boxes, et - st

-    def __call__(self, img, use_slice=False):
+    def __call__(self, img, use_slice=False, batch_size=None):
+        """
+        Text detection entry point.
+        Args:
+            img: a single image or a list of images
+            use_slice: whether to use sliced inference
+            batch_size: batch size (only used when img is a list)
+        Returns:
+            detection boxes and elapsed time
+        """
+        # Batch mode
+        if isinstance(img, list):
+            return self.predict_batch(img, batch_size)
+
+        # Original single-image logic
         # For image like poster with one side much greater than the other side,
         # splitting recursively and processing with overlap to enhance performance.
         MIN_BOUND_DISTANCE = 50
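
With this dispatch, `TextDetector` can be called with either a single image or a list. A hedged usage sketch, not part of the patch: it assumes the detector model directory is supplied through the usual `--det_model_dir` flag parsed by `tools/infer/utility.py`, and the image paths are placeholders.

    import cv2
    import tools.infer.utility as utility
    from tools.infer.predict_det import TextDetector

    args = utility.parse_args()
    text_detector = TextDetector(args)

    imgs = [cv2.imread("doc_a.jpg"), cv2.imread("doc_b.jpg")]  # placeholder paths

    # List input routes through predict_batch(); batch_size caps the per-run batch.
    boxes_per_img, elapse = text_detector(imgs, batch_size=2)
    for idx, dt_boxes in enumerate(boxes_per_img):
        n = 0 if dt_boxes is None else len(dt_boxes)
        print(f"image {idx}: {n} boxes, total detection time {elapse:.3f}s")

    # Single-image input is unchanged.
    dt_boxes, elapse = text_detector(imgs[0])
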
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index bcb0758eb66..90ad099e597 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -72,8 +72,110 @@ def draw_crop_rec_res(self, output_dir, img_crop_list, rec_res):
             )
             logger.debug(f"{bno}, {rec_res[bno]}")
         self.crop_image_res_index += bbox_num
-
-    def __call__(self, img, cls=True, slice={}):
+
+    def _batch_process(self, img_list, cls=True, slice={}, batch_size=None):
+        """
+        Process a list of images in batch mode.
+        Args:
+            img_list: list of images
+            cls: whether to use the angle classifier
+            slice: slice parameters (not supported in batch mode)
+            batch_size: batch size
+        Returns:
+            a list of per-image results
+        """
+        if slice:
+            logger.warning("The slice parameter is not supported in batch mode and will be ignored")
+
+        if batch_size is None:
+            batch_size = min(8, len(img_list))  # default batch size
+
+        all_results = []
+        total_time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
+
+        start_all = time.time()
+
+        # Batch text detection
+        logger.info(f"Start batch detection of {len(img_list)} images, batch size: {batch_size}")
+        batch_dt_boxes, det_elapse = self.text_detector(img_list, batch_size=batch_size)
+        total_time_dict["det"] = det_elapse
+
+        # Run the remaining stages for each image
+        for img_idx, (img, dt_boxes) in enumerate(zip(img_list, batch_dt_boxes)):
+            if dt_boxes is None or len(dt_boxes) == 0:
+                all_results.append((None, None, {"det": 0, "rec": 0, "cls": 0, "all": 0}))
+                continue
+
+            ori_im = img.copy()
+            img_crop_list = []
+
+            dt_boxes = sorted_boxes(dt_boxes)
+
+            # Crop the text regions
+            for bno in range(len(dt_boxes)):
+                tmp_box = copy.deepcopy(dt_boxes[bno])
+                if self.args.det_box_type == "quad":
+                    img_crop = get_rotate_crop_image(ori_im, tmp_box)
+                else:
+                    img_crop = get_minarea_rect_crop(ori_im, tmp_box)
+                img_crop_list.append(img_crop)
+
+            # Angle classification
+            cls_time = 0
+            if self.use_angle_cls and cls and len(img_crop_list) > 0:
+                cls_start = time.time()
+                img_crop_list, angle_list, elapse = self.text_classifier(img_crop_list)
+                cls_time = elapse
+                logger.debug(f"image {img_idx + 1} cls num: {len(img_crop_list)}, elapsed: {elapse}")
+
+            # Text recognition
+            rec_time = 0
+            if len(img_crop_list) > 0:
+                rec_start = time.time()
+                if len(img_crop_list) > 1000:
+                    logger.debug(f"image {img_idx + 1} rec crops num: {len(img_crop_list)}, time and memory cost may be large.")
+
+                rec_res, elapse = self.text_recognizer(img_crop_list)
+                rec_time = elapse
+                logger.debug(f"image {img_idx + 1} rec_res num: {len(rec_res)}, elapsed: {elapse}")
+
+                # Filter results by recognition score
+                filter_boxes, filter_rec_res = [], []
+                for box, rec_result in zip(dt_boxes, rec_res):
+                    text, score = rec_result[0], rec_result[1]
+                    if score >= self.drop_score:
+                        filter_boxes.append(box)
+                        filter_rec_res.append(rec_result)
+
+                single_time_dict = {"det": det_elapse / len(img_list), "rec": rec_time, "cls": cls_time, "all": 0}
+                all_results.append((filter_boxes, filter_rec_res, single_time_dict))
+            else:
+                single_time_dict = {"det": det_elapse / len(img_list), "rec": 0, "cls": cls_time, "all": 0}
+                all_results.append((None, None, single_time_dict))
+
+            total_time_dict["rec"] += rec_time
+            total_time_dict["cls"] += cls_time
+
+        end_all = time.time()
+        total_time_dict["all"] = end_all - start_all
+
+        logger.info(f"Batch processing finished, total time: {total_time_dict['all']:.3f}s")
+        return all_results
+
+    def __call__(self, img, cls=True, slice={}, batch_size=None):
+        """
+        Text system entry point; supports both a single image and a list of images.
+        Args:
+            img: a single image or a list of images
+            cls: whether to use the angle classifier
+            slice: slice parameters
+            batch_size: batch size
+        """
+        # Batch mode
+        if isinstance(img, list):
+            return self._batch_process(img, cls, slice, batch_size)
+
+        # Original single-image logic
         time_dict = {"det": 0, "rec": 0, "cls": 0, "all": 0}
         if img is None:
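
For completeness, a hedged end-to-end sketch of driving the patched `TextSystem` with a list of images, not part of the patch: it assumes the detection and recognition model directories are set through the usual `tools/infer/utility.py` CLI flags, and the image paths are placeholders. Each returned entry is the `(boxes, rec_res, time_dict)` tuple produced by `_batch_process` above.

    import cv2
    import tools.infer.utility as utility
    from tools.infer.predict_system import TextSystem

    args = utility.parse_args()
    text_sys = TextSystem(args)

    img_list = [cv2.imread(p) for p in ["invoice_1.jpg", "invoice_2.jpg"]]  # placeholder paths

    # List input routes through _batch_process().
    results = text_sys(img_list, cls=True, batch_size=2)
    for idx, (boxes, rec_res, time_dict) in enumerate(results):
        if rec_res is None:
            print(f"image {idx}: no text detected")
            continue
        for text, score in rec_res:
            print(idx, text, round(score, 3))
        print(f"image {idx} timing: {time_dict}")
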