1
+ '''
2
+ Text detection model: https://github.com/argman/EAST
3
+ Download link: https://www.dropbox.com/s/r2ingd0l3zt8hxs/frozen_east_text_detection.tar.gz?dl=1
4
+ Text recognition model taken from here: https://github.com/meijieru/crnn.pytorch
5
+ How to convert from pth to onnx:
6
+ Using classes from here: https://github.com/meijieru/crnn.pytorch/blob/master/models/crnn.py
7
+ import torch
8
+ from models.crnn import CRNN
9
+ model = CRNN(32, 1, 37, 256)
10
+ model.load_state_dict(torch.load('crnn.pth'))
11
+ dummy_input = torch.randn(1, 1, 32, 100)
12
+ torch.onnx.export(model, dummy_input, "crnn.onnx", verbose=True)
13
+ '''
14
+
15
+
1
16
# Import required modules
17
+ import numpy as np
2
18
import cv2 as cv
3
19
import math
4
20
import argparse
5
21
6
22
############ Add argument parser for command line arguments ############
# NOTE: arguments are parsed at module level; main() reads the resulting
# module-level `args` object.
parser = argparse.ArgumentParser(
    # Adjacent string literals are concatenated verbatim, so each fragment
    # must carry its own trailing separator.
    description="Use this script to run TensorFlow implementation (https://github.com/argman/EAST) of "
                "EAST: An Efficient and Accurate Scene Text Detector (https://arxiv.org/abs/1704.03155v2). "
                "The OCR model can be obtained from converting the pretrained CRNN model to .onnx format "
                "from the github repository https://github.com/meijieru/crnn.pytorch")
parser.add_argument('--input',
                    help='Path to input image or video file. Skip this argument to capture frames from a camera.')
parser.add_argument('--model', '-m', required=True,
                    help='Path to a binary .pb file contains trained detector network.')
parser.add_argument('--ocr', default="crnn.onnx",
                    help="Path to a binary .pb or .onnx file contains trained recognition network")
parser.add_argument('--width', type=int, default=320,
                    help='Preprocess input image by resizing to a specific width. It should be multiple by 32.')
parser.add_argument('--height', type=int, default=320,
                    help='Preprocess input image by resizing to a specific height. It should be multiple by 32.')
parser.add_argument('--thr', type=float, default=0.5,
                    help='Confidence threshold.')
parser.add_argument('--nms', type=float, default=0.4,
                    help='Non-maximum suppression threshold.')
args = parser.parse_args()
20
42
43
+
21
44
############ Utility functions ############
22
- def decode (scores , geometry , scoreThresh ):
45
+
46
def fourPointsTransform(frame, vertices, outputSize=(100, 32)):
    """Warp the quadrilateral described by `vertices` into an upright patch.

    Args:
        frame: source image handed to cv.warpPerspective.
        vertices: four (x, y) corner points. They are mapped, in order, onto
            the bottom-left, top-left, top-right and bottom-right corners of
            the output patch (image coordinates, y pointing down).
        outputSize: (width, height) of the returned patch. The default of
            (100, 32) is the same size main() requests when it builds the
            recognizer input blob.

    Returns:
        The perspective-warped patch of size `outputSize`.
    """
    # Coerce to float32 so plain Python lists / float64 arrays are accepted
    # as well; cv.getPerspectiveTransform expects 32-bit float points.
    vertices = np.asarray(vertices, dtype=np.float32)
    width, height = outputSize
    targetVertices = np.array([
        [0, height - 1],
        [0, 0],
        [width - 1, 0],
        [width - 1, height - 1]], dtype=np.float32)

    perspective = cv.getPerspectiveTransform(vertices, targetVertices)
    return cv.warpPerspective(frame, perspective, (width, height))
58
+
59
+
60
def decodeText(scores, alphabet="0123456789abcdefghijklmnopqrstuvwxyz"):
    """Greedily decode recognition-network scores into a string.

    For every step `t` the best class is taken as argmax of `scores[t][0]`.
    Class 0 is treated as the blank/background symbol; any other class `c`
    maps to `alphabet[c - 1]`. Blanks are dropped and runs of the same
    symbol on adjacent steps are collapsed into one character.

    Args:
        scores: sequence/array indexed as `scores[t][0]`, each entry holding
            the per-class scores of one step.
        alphabet: characters for classes 1..len(alphabet). Defaults to the
            36-character digits + lowercase set originally hard-coded here.

    Returns:
        The decoded text ('' when `scores` has no steps).
    """
    blank = 0
    # Best class per step; same argmax the per-step loop used to compute.
    bestClasses = [int(np.argmax(step[0])) for step in scores]

    chars = []
    previous = blank
    for current in bestClasses:
        # Emit a character only when it is not blank and does not repeat
        # the class of the immediately preceding step.
        if current != blank and current != previous:
            chars.append(alphabet[current - 1])
        previous = current
    return ''.join(chars)
76
+
77
+
78
+ def decodeBoundingBoxes (scores , geometry , scoreThresh ):
23
79
detections = []
24
80
confidences = []
25
81
@@ -47,7 +103,7 @@ def decode(scores, geometry, scoreThresh):
47
103
score = scoresData [x ]
48
104
49
105
# If score is lower than threshold score, move to next x
50
- if (score < scoreThresh ):
106
+ if (score < scoreThresh ):
51
107
continue
52
108
53
109
# Calculate offset
@@ -66,24 +122,27 @@ def decode(scores, geometry, scoreThresh):
66
122
67
123
# Find points for rectangle
68
124
p1 = (- sinA * h + offset [0 ], - cosA * h + offset [1 ])
69
- p3 = (- cosA * w + offset [0 ], sinA * w + offset [1 ])
70
- center = (0.5 * (p1 [0 ]+ p3 [0 ]), 0.5 * (p1 [1 ]+ p3 [1 ]))
71
- detections .append ((center , (w ,h ), - 1 * angle * 180.0 / math .pi ))
125
+ p3 = (- cosA * w + offset [0 ], sinA * w + offset [1 ])
126
+ center = (0.5 * (p1 [0 ] + p3 [0 ]), 0.5 * (p1 [1 ] + p3 [1 ]))
127
+ detections .append ((center , (w , h ), - 1 * angle * 180.0 / math .pi ))
72
128
confidences .append (float (score ))
73
129
74
130
# Return detections and confidences
75
131
return [detections , confidences ]
76
132
133
+
77
134
def main ():
78
135
# Read and store arguments
79
136
confThreshold = args .thr
80
137
nmsThreshold = args .nms
81
138
inpWidth = args .width
82
139
inpHeight = args .height
83
- model = args .model
140
+ modelDetector = args .model
141
+ modelRecognition = args .ocr
84
142
85
143
# Load network
86
- net = cv .dnn .readNet (model )
144
+ detector = cv .dnn .readNet (modelDetector )
145
+ recognizer = cv .dnn .readNet (modelRecognition )
87
146
88
147
# Create a new named window
89
148
kWinName = "EAST: An Efficient and Accurate Scene Text Detector"
@@ -95,6 +154,7 @@ def main():
95
154
# Open a video file or an image file or a camera stream
96
155
cap = cv .VideoCapture (args .input if args .input else 0 )
97
156
157
+ tickmeter = cv .TickMeter ()
98
158
while cv .waitKey (1 ) < 0 :
99
159
# Read frame
100
160
hasFrame , frame = cap .read ()
@@ -111,36 +171,61 @@ def main():
111
171
# Create a 4D blob from frame.
112
172
blob = cv .dnn .blobFromImage (frame , 1.0 , (inpWidth , inpHeight ), (123.68 , 116.78 , 103.94 ), True , False )
113
173
114
- # Run the model
115
- net .setInput (blob )
116
- outs = net .forward (outNames )
117
- t , _ = net .getPerfProfile ()
118
- label = 'Inference time: %.2f ms' % (t * 1000.0 / cv .getTickFrequency ())
174
+ # Run the detection model
175
+ detector .setInput (blob )
176
+
177
+ tickmeter .start ()
178
+ outs = detector .forward (outNames )
179
+ tickmeter .stop ()
119
180
120
181
# Get scores and geometry
121
182
scores = outs [0 ]
122
183
geometry = outs [1 ]
123
- [boxes , confidences ] = decode (scores , geometry , confThreshold )
184
+ [boxes , confidences ] = decodeBoundingBoxes (scores , geometry , confThreshold )
124
185
125
186
# Apply NMS
126
- indices = cv .dnn .NMSBoxesRotated (boxes , confidences , confThreshold ,nmsThreshold )
187
+ indices = cv .dnn .NMSBoxesRotated (boxes , confidences , confThreshold , nmsThreshold )
127
188
for i in indices :
128
189
# get 4 corners of the rotated rect
129
190
vertices = cv .boxPoints (boxes [i [0 ]])
130
191
# scale the bounding box coordinates based on the respective ratios
131
192
for j in range (4 ):
132
193
vertices [j ][0 ] *= rW
133
194
vertices [j ][1 ] *= rH
195
+
196
+
197
+ # get cropped image using perspective transform
198
+ if modelRecognition :
199
+ cropped = fourPointsTransform (frame , vertices )
200
+ cropped = cv .cvtColor (cropped , cv .COLOR_BGR2GRAY )
201
+
202
+ # Create a 4D blob from cropped image
203
+ blob = cv .dnn .blobFromImage (cropped , size = (100 , 32 ), mean = 127.5 , scalefactor = 1 / 127.5 )
204
+ recognizer .setInput (blob )
205
+
206
+ # Run the recognition model
207
+ tickmeter .start ()
208
+ result = recognizer .forward ()
209
+ tickmeter .stop ()
210
+
211
+ # decode the result into text
212
+ wordRecognized = decodeText (result )
213
+ cv .putText (frame , wordRecognized , (int (vertices [1 ][0 ]), int (vertices [1 ][1 ])), cv .FONT_HERSHEY_SIMPLEX ,
214
+ 0.5 , (255 , 0 , 0 ))
215
+
134
216
for j in range (4 ):
135
217
p1 = (vertices [j ][0 ], vertices [j ][1 ])
136
218
p2 = (vertices [(j + 1 ) % 4 ][0 ], vertices [(j + 1 ) % 4 ][1 ])
137
219
cv .line (frame , p1 , p2 , (0 , 255 , 0 ), 1 )
138
220
139
221
# Put efficiency information
222
+ label = 'Inference time: %.2f ms' % (tickmeter .getTimeMilli ())
140
223
cv .putText (frame , label , (0 , 15 ), cv .FONT_HERSHEY_SIMPLEX , 0.5 , (0 , 255 , 0 ))
141
224
142
225
# Display the frame
143
- cv .imshow (kWinName ,frame )
226
+ cv .imshow (kWinName , frame )
227
+ tickmeter .reset ()
228
+
144
229
145
230
if __name__ == "__main__" :
146
231
main ()
0 commit comments