diff --git a/easyocr/recognition.py b/easyocr/recognition.py index 530ef9517e2..4fc817ade60 100644 --- a/easyocr/recognition.py +++ b/easyocr/recognition.py @@ -100,6 +100,9 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\ ignore_idx, char_group_idx, decoder = 'greedy', beamWidth= 5, device = 'cpu'): model.eval() result = [] + preds_str = [] + preds_max_prob = [] + with torch.no_grad(): for image_tensors in test_loader: batch_size = image_tensors.size(0) @@ -144,6 +147,7 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\ else: preds_max_prob.append(np.array([0])) + for pred, pred_max_prob in zip(preds_str, preds_max_prob): confidence_score = custom_mean(pred_max_prob) result.append([pred, confidence_score]) @@ -217,8 +221,11 @@ def get_text(character, imgH, imgW, recognizer, converter, image_list,\ num_workers=int(workers), collate_fn=AlignCollate_contrast, pin_memory=True) result2 = recognizer_predict(recognizer, converter, test_loader, batch_max_length,\ ignore_idx, char_group_idx, decoder, beamWidth, device = device) - + recognizer.eval() result = [] + + if not test_loader: + return result for i, zipped in enumerate(zip(coord, result1)): box, pred1 = zipped if i in low_confident_idx: