Skip to content

Commit 1b336bb

Browse files
Merge pull request opencv#16955 from themechanicalcoder:text_recognition
* add text recognition sample * fix pylint warning * made changes according to the c++ example * fix errors * add text recognition sample * update text detection sample
1 parent 0fb3b8d commit 1b336bb

File tree

1 file changed

+107
-22
lines changed

1 file changed

+107
-22
lines changed

samples/dnn/text_detection.py

Lines changed: 107 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,81 @@
1+
'''
2+
Text detection model: https://github.com/argman/EAST
3+
Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
4+
Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
5+
How to convert from pth to onnx:
6+
Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
7+
import torch
8+
from models.crnn import CRNN
9+
model = CRNN(32, 1, 37, 256)
10+
model.load_state_dict(torch.load('crnn.pth'))
11+
dummy_input = torch.randn(1, 1, 32, 100)
12+
torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
13+
'''
14+
15+
116
# Import required modules
17+
import numpy as np
218
import cv2 as cv
319
import math
420
import argparse
521

622
############ Add argument parser for command line arguments ############
7-
parser = argparse.ArgumentParser(description='Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2)')
8-
parser.add_argument('--input', help='Path to input image or video file. Skip this argument to capture frames from a camera.')
9-
parser.add_argument('--model', required=True,
10-
help='Path to a binary .pb file of model contains trained weights.')
23+
# Build the command-line parser. The description is assembled from adjacent
# string literals, so every piece must carry its own trailing separator;
# otherwise the sentences run together in the --help output.
parser = argparse.ArgumentParser(
    description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
                "The OCR model can be obtained from converting the pretrained CRNN model to .onnx format from the github repository https://github.com/meijieru/crnn.pytorch")
27+
parser.add_argument('--input',
28+
help='Path to input image or video file. Skip this argument to capture frames from a camera.')
29+
parser.add_argument('--model', '-m', required=True,
30+
help='Path to a binary .pb file contains trained detector network.')
31+
parser.add_argument('--ocr', default="crnn.onnx",
32+
help="Path to a binary .pb or .onnx file contains trained recognition network", )
1133
parser.add_argument('--width', type=int, default=320,
1234
help='Preprocess input image by resizing to a specific width. It should be multiple by 32.')
13-
parser.add_argument('--height',type=int, default=320,
35+
parser.add_argument('--height', type=int, default=320,
1436
help='Preprocess input image by resizing to a specific height. It should be multiple by 32.')
15-
parser.add_argument('--thr',type=float, default=0.5,
37+
parser.add_argument('--thr', type=float, default=0.5,
1638
help='Confidence threshold.')
17-
parser.add_argument('--nms',type=float, default=0.4,
39+
parser.add_argument('--nms', type=float, default=0.4,
1840
help='Non-maximum suppression threshold.')
1941
args = parser.parse_args()
2042

43+
2144
############ Utility functions ############
22-
def decode(scores, geometry, scoreThresh):
45+
46+
def fourPointsTransform(frame, vertices, outputSize=(100, 32)):
    """Warp the quadrilateral spanned by ``vertices`` into an upright patch.

    Args:
        frame: Source image handed to ``cv.warpPerspective``.
        vertices: Four (x, y) corner points. They are mapped, in order, onto
            the bottom-left, top-left, top-right and bottom-right corners of
            the output patch.
        outputSize: ``(width, height)`` of the returned patch. Defaults to
            ``(100, 32)``.

    Returns:
        The perspective-corrected patch of size ``outputSize``.
    """
    # cv.getPerspectiveTransform expects float32 points, so coerce here
    # instead of relying on the caller's dtype.
    vertices = np.asarray(vertices, dtype="float32")
    width, height = outputSize
    targetVertices = np.array([
        [0, height - 1],
        [0, 0],
        [width - 1, 0],
        [width - 1, height - 1]], dtype="float32")

    rotationMatrix = cv.getPerspectiveTransform(vertices, targetVertices)
    return cv.warpPerspective(frame, rotationMatrix, (width, height))
58+
59+
60+
def decodeText(scores, alphabet="0123456789abcdefghijklmnopqrstuvwxyz"):
    """Greedily decode per-step class scores into a string.

    For every step the highest-scoring class is taken. Class 0 is treated as
    the background (blank) symbol; class ``c > 0`` maps to ``alphabet[c - 1]``.
    A step is emitted only if it is not blank and differs from the class of
    the immediately preceding step, so runs of the same letter collapse to
    one and a blank between two equal letters keeps both.

    Args:
        scores: Array indexed as ``scores[step][0][class]``.
        alphabet: Characters for classes ``1..len(alphabet)``.

    Returns:
        The decoded text with repeats and blanks removed.
    """
    decoded = []
    previous = None  # class index chosen at the previous step
    for step in scores:
        best = int(np.argmax(step[0]))
        # Compare class indices rather than characters so no in-band
        # sentinel character is needed for the blank symbol.
        if best != 0 and best != previous:
            decoded.append(alphabet[best - 1])
        previous = best
    return ''.join(decoded)
76+
77+
78+
def decodeBoundingBoxes(scores, geometry, scoreThresh):
2379
detections = []
2480
confidences = []
2581

@@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
47103
score = scoresData[x]
48104

49105
# If score is lower than threshold score, move to next x
50-
if(score < scoreThresh):
106+
if (score < scoreThresh):
51107
continue
52108

53109
# Calculate offset
@@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
66122

67123
# Find points for rectangle
68124
p1 = (-sinA * h + offset[0], -cosA * h + offset[1])
69-
p3 = (-cosA * w + offset[0], sinA * w + offset[1])
70-
center = (0.5*(p1[0]+p3[0]), 0.5*(p1[1]+p3[1]))
71-
detections.append((center, (w,h), -1*angle * 180.0 / math.pi))
125+
p3 = (-cosA * w + offset[0], sinA * w + offset[1])
126+
center = (0.5 * (p1[0] + p3[0]), 0.5 * (p1[1] + p3[1]))
127+
detections.append((center, (w, h), -1 * angle * 180.0 / math.pi))
72128
confidences.append(float(score))
73129

74130
# Return detections and confidences
75131
return [detections, confidences]
76132

133+
77134
def main():
78135
# Read and store arguments
79136
confThreshold = args.thr
80137
nmsThreshold = args.nms
81138
inpWidth = args.width
82139
inpHeight = args.height
83-
model = args.model
140+
modelDetector = args.model
141+
modelRecognition = args.ocr
84142

85143
# Load network
86-
net = cv.dnn.readNet(model)
144+
detector = cv.dnn.readNet(modelDetector)
145+
recognizer = cv.dnn.readNet(modelRecognition)
87146

88147
# Create a new named window
89148
kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
@@ -95,6 +154,7 @@ def main():
95154
# Open a video file or an image file or a camera stream
96155
cap = cv.VideoCapture(args.input if args.input else 0)
97156

157+
tickmeter = cv.TickMeter()
98158
while cv.waitKey(1) < 0:
99159
# Read frame
100160
hasFrame, frame = cap.read()
@@ -111,36 +171,61 @@ def main():
111171
# Create a 4D blob from frame.
112172
blob = cv.dnn.blobFromImage(frame, 1.0, (inpWidth, inpHeight), (123.68, 116.78, 103.94), True, False)
113173

114-
# Run the model
115-
net.setInput(blob)
116-
outs = net.forward(outNames)
117-
t, _ = net.getPerfProfile()
118-
label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
174+
# Run the detection model
175+
detector.setInput(blob)
176+
177+
tickmeter.start()
178+
outs = detector.forward(outNames)
179+
tickmeter.stop()
119180

120181
# Get scores and geometry
121182
scores = outs[0]
122183
geometry = outs[1]
123-
[boxes, confidences] = decode(scores, geometry, confThreshold)
184+
[boxes, confidences] = decodeBoundingBoxes(scores, geometry, confThreshold)
124185

125186
# Apply NMS
126-
indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold,nmsThreshold)
187+
indices = cv.dnn.NMSBoxesRotated(boxes, confidences, confThreshold, nmsThreshold)
127188
for i in indices:
128189
# get 4 corners of the rotated rect
129190
vertices = cv.boxPoints(boxes[i[0]])
130191
# scale the bounding box coordinates based on the respective ratios
131192
for j in range(4):
132193
vertices[j][0] *= rW
133194
vertices[j][1] *= rH
195+
196+
197+
# get cropped image using perspective transform
198+
if modelRecognition:
199+
cropped = fourPointsTransform(frame, vertices)
200+
cropped = cv.cvtColor(cropped, cv.COLOR_BGR2GRAY)
201+
202+
# Create a 4D blob from cropped image
203+
blob = cv.dnn.blobFromImage(cropped, size=(100, 32), mean=127.5, scalefactor=1 / 127.5)
204+
recognizer.setInput(blob)
205+
206+
# Run the recognition model
207+
tickmeter.start()
208+
result = recognizer.forward()
209+
tickmeter.stop()
210+
211+
# decode the result into text
212+
wordRecognized = decodeText(result)
213+
cv.putText(frame, wordRecognized, (int(vertices[1][0]), int(vertices[1][1])), cv.FONT_HERSHEY_SIMPLEX,
214+
0.5, (255, 0, 0))
215+
134216
for j in range(4):
135217
p1 = (vertices[j][0], vertices[j][1])
136218
p2 = (vertices[(j + 1) % 4][0], vertices[(j + 1) % 4][1])
137219
cv.line(frame, p1, p2, (0, 255, 0), 1)
138220

139221
# Put efficiency information
222+
label = 'Inference time: %.2f ms' % (tickmeter.getTimeMilli())
140223
cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0))
141224

142225
# Display the frame
143-
cv.imshow(kWinName,frame)
226+
cv.imshow(kWinName, frame)
227+
tickmeter.reset()
228+
144229

145230
if __name__ == "__main__":
146231
main()

0 commit comments

Comments
 (0)