Skip to content
Draft

draft #16534

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 68 additions & 25 deletions paddleocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,9 +727,10 @@ def ocr(
- The preprocess_image function is used to preprocess the input image by applying alpha color replacement, inversion, and binarization if specified.
"""
assert isinstance(img, (np.ndarray, list, str, bytes))
if isinstance(img, list) and det == True:
logger.error("When input a list of images, det must be false")
exit(0)
# 支持多张图片的批处理,移除原有限制
# if isinstance(img, list) and det == True:
# logger.error("When input a list of images, det must be false")
# exit(0)
if cls == True and self.use_angle_cls == False:
logger.warning(
"Since the angle classifier is not initialized, it will not be used during the forward process"
Expand All @@ -753,41 +754,83 @@ def preprocess_image(_image):
_image = binarize_img(_image)
return _image

if det and rec:
if det and rec: # 默认 注意我只测试了这一判断情况
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls, slice)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
# 确保img是单张图片(ndarray),而不是图片列表
if isinstance(img, list):
# 如果img是列表,说明是多张图片,需要预处理后再组成列表传入
processed_imgs = [preprocess_image(single_img) for single_img in img]
batch_results = self.__call__(processed_imgs, cls, slice)
# batch_results是一个列表,每个元素是(dt_boxes, rec_res, time_dict)
for dt_boxes, rec_res, _ in batch_results:
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
else:
# 单张图片处理
img = preprocess_image(img)
dt_boxes, rec_res, _ = self.__call__(img, cls, slice)
if not dt_boxes and not rec_res:
ocr_res.append(None)
continue
tmp_res = [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
ocr_res.append(tmp_res)
return ocr_res
elif det and not rec:
ocr_res = []
for img in imgs:
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if dt_boxes.size == 0:
ocr_res.append(None)
continue
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
# 确保img是单张图片(ndarray),而不是图片列表
if isinstance(img, list):
# 如果img是列表,说明是多张图片,需要预处理后再传入
processed_imgs = [preprocess_image(single_img) for single_img in img]
# 对于仅检测模式,需要分别处理每张图片
for processed_img in processed_imgs:
dt_boxes, elapse = self.text_detector(processed_img)
if dt_boxes.size == 0:
ocr_res.append(None)
continue
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
else:
# 单张图片处理
img = preprocess_image(img)
dt_boxes, elapse = self.text_detector(img)
if dt_boxes.size == 0:
ocr_res.append(None)
continue
tmp_res = [box.tolist() for box in dt_boxes]
ocr_res.append(tmp_res)
return ocr_res
else:
ocr_res = []
cls_res = []
for img in imgs:
if not isinstance(img, list):
# 确保img是单张图片(ndarray),而不是图片列表
if isinstance(img, list):
# 如果img是列表,说明是多张图片,需要预处理后再传入
processed_imgs = [preprocess_image(single_img) for single_img in img]
# 对于仅识别模式,需要分别处理每张图片
for processed_img in processed_imgs:
processed_img = [processed_img]
if self.use_angle_cls and cls:
processed_img, cls_res_tmp, elapse = self.text_classifier(processed_img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(processed_img)
ocr_res.append(rec_res)
else:
# 单张图片处理
img = preprocess_image(img)
img = [img]
if self.use_angle_cls and cls:
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res)
if self.use_angle_cls and cls:
img, cls_res_tmp, elapse = self.text_classifier(img)
if not rec:
cls_res.append(cls_res_tmp)
rec_res, elapse = self.text_recognizer(img)
ocr_res.append(rec_res)
if not rec:
return cls_res
return ocr_res
Expand Down
154 changes: 153 additions & 1 deletion tools/infer/predict_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,144 @@ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
dt_boxes = np.array(dt_boxes_new)
return dt_boxes

def predict_batch(self, img_list, batch_size=None):
"""
批量预测多张图片
Args:
img_list: 图片列表
batch_size: 批处理大小,如果为None则处理所有图片
Returns:
batch_dt_boxes: 批量检测框结果列表
elapse: 总耗时
"""
if batch_size is None:
batch_size = len(img_list)

all_dt_boxes = []
total_elapse = 0

# 分批处理
for i in range(0, len(img_list), batch_size):
batch_imgs = img_list[i:i+batch_size]
batch_dt_boxes, elapse = self._predict_single_batch(batch_imgs)
all_dt_boxes.extend(batch_dt_boxes)
total_elapse += elapse

return all_dt_boxes, total_elapse

def _predict_single_batch(self, img_list):
"""
处理单个批次的图片
"""
batch_size = len(img_list)
ori_imgs = [img.copy() for img in img_list]

st = time.time()

if self.args.benchmark:
self.autolog.times.start()

# 批量预处理
batch_data = []
batch_shape_lists = []

for img in img_list:
data = {"image": img}
data = transform(data, self.preprocess_op)
processed_img, shape_list = data
if processed_img is not None:
batch_data.append(processed_img)
batch_shape_lists.append(shape_list)

if not batch_data:
return [None] * batch_size, 0

# 将批量数据组合成一个batch
#batch_imgs = np.array(batch_data) # Error processing batch ['IMG_20241230_162405.HEIC', 'IMG_20241228_171324.HEIC']: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (2, 3) + inhomogeneous part. # {ndarray::(3,960,704)}{ndarray:(3,704,960)}
# 统一批量数据的尺寸 - 填充到最大尺寸
if len(batch_data) > 1:
# 找到最大的高度和宽度
max_h = max(img.shape[1] for img in batch_data) # shape: (C, H, W)
max_w = max(img.shape[2] for img in batch_data)

# 将所有图片填充到相同尺寸
padded_imgs = []
for img in batch_data:
c, h, w = img.shape
# 创建填充后的图片
padded_img = np.zeros((c, max_h, max_w), dtype=img.dtype)
padded_img[:, :h, :w] = img # 左上角对齐
padded_imgs.append(padded_img)

batch_imgs = np.array(padded_imgs)
else:
batch_imgs = np.array(batch_data)

batch_shape_lists = np.array(batch_shape_lists)

if self.args.benchmark:
self.autolog.times.stamp()


# 批量推理
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = batch_imgs
outputs = self.predictor.run(self.output_tensors, input_dict)
else:
self.input_tensor.copy_from_cpu(batch_imgs)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.args.benchmark:
self.autolog.times.stamp()

# 批量后处理
batch_dt_boxes = []

for batch_idx in range(len(batch_data)):
# 为每张图片提取对应的输出
single_outputs = {}
if self.det_algorithm == "EAST":
single_outputs["f_geo"] = outputs[0][batch_idx:batch_idx+1]
single_outputs["f_score"] = outputs[1][batch_idx:batch_idx+1]
elif self.det_algorithm == "SAST":
single_outputs["f_border"] = outputs[0][batch_idx:batch_idx+1]
single_outputs["f_score"] = outputs[1][batch_idx:batch_idx+1]
single_outputs["f_tco"] = outputs[2][batch_idx:batch_idx+1]
single_outputs["f_tvo"] = outputs[3][batch_idx:batch_idx+1]
elif self.det_algorithm in ["DB", "PSE", "DB++"]:
single_outputs["maps"] = outputs[0][batch_idx:batch_idx+1]
elif self.det_algorithm == "FCE":
for i, output in enumerate(outputs):
single_outputs["level_{}".format(i)] = output[batch_idx:batch_idx+1]
elif self.det_algorithm == "CT":
single_outputs["maps"] = outputs[0][batch_idx:batch_idx+1]
single_outputs["score"] = outputs[1][batch_idx:batch_idx+1]
else:
raise NotImplementedError

# 单张图片后处理
single_shape_list = batch_shape_lists[batch_idx:batch_idx+1]
post_result = self.postprocess_op(single_outputs, single_shape_list)
dt_boxes = post_result[0]["points"]

if self.args.det_box_type == "poly":
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_imgs[batch_idx].shape)
else: # 这个分支
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_imgs[batch_idx].shape)

batch_dt_boxes.append(dt_boxes)

if self.args.benchmark:
self.autolog.times.end(stamp=True)

et = time.time()
return batch_dt_boxes, et - st


def predict(self, img):
ori_im = img.copy()
data = {"image": img}
Expand Down Expand Up @@ -293,7 +431,21 @@ def predict(self, img):
et = time.time()
return dt_boxes, et - st

def __call__(self, img, use_slice=False):
def __call__(self, img, use_slice=False, batch_size=None):
"""
文本检测调用接口
Args:
img: 单张图片或图片列表
use_slice: 是否使用切片处理
batch_size: 批处理大小(仅当img为列表时有效)
Returns:
检测框结果和耗时
"""
# 支持批处理
if isinstance(img, list):
return self.predict_batch(img, batch_size)

# 原有的单张图片处理逻辑
# For image like poster with one side much greater than the other side,
# splitting recursively and processing with overlap to enhance performance.
MIN_BOUND_DISTANCE = 50
Expand Down
Loading
Loading