@@ -16,18 +16,17 @@ const DEFAULTS = {
16
16
IOUThreshold : 0.4 ,
17
17
classProbThreshold : 0.4 ,
18
18
URL : 'https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json' ,
19
-
19
+ imageSize : 416 ,
20
20
} ;
21
21
22
- class YOLO {
22
+ class YOLOBase {
23
23
constructor ( options ) {
24
24
this . filterBoxesThreshold = options . filterBoxesThreshold || DEFAULTS . filterBoxesThreshold ;
25
25
this . IOUThreshold = options . IOUThreshold || DEFAULTS . IOUThreshold ;
26
26
this . classProbThreshold = options . classProbThreshold || DEFAULTS . classProbThreshold ;
27
27
this . modelURL = options . url || DEFAULTS . URL ;
28
28
this . model = null ;
29
- this . inputWidth = 416 ;
30
- this . inputHeight = 416 ;
29
+ this . imageSize = options . imageSize || DEFAULTS . imageSize ;
31
30
this . classNames = CLASS_NAMES ;
32
31
this . anchors = [
33
32
[ 0.57273 , 0.677385 ] ,
@@ -36,8 +35,6 @@ class YOLO {
36
35
[ 7.88282 , 3.52778 ] ,
37
36
[ 9.77052 , 9.16828 ] ,
38
37
] ;
39
- // this.scaleX;
40
- // this.scaleY;
41
38
this . anchorsLength = this . anchors . length ;
42
39
this . classesLength = this . classNames . length ;
43
40
this . init ( ) ;
@@ -71,11 +68,16 @@ class YOLO {
71
68
} ) ;
72
69
}
73
70
74
- // takes HTMLCanvasElement || HTMLImageElement ||HTMLVideoElement || ImageData as input
75
- // outs results obj
76
- async detect ( input ) {
71
+ /**
72
+ * Infers through the model.
73
+ * TODO : Optionally takes an endpoint to return an intermediate activation.
74
+ * @param img The image to classify. Can be a tensor or a DOM element image,
75
+ * video, or canvas.
76
+ * img: tf.Tensor3D|ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement
77
+ */
78
+ async detect ( img ) {
77
79
const predictions = tf . tidy ( ( ) => {
78
- const data = this . preProccess ( input ) ;
80
+ const data = this . preProccess ( img ) ;
79
81
const preds = this . model . predict ( data ) ;
80
82
return preds ;
81
83
} ) ;
@@ -99,63 +101,73 @@ class YOLO {
99
101
tf . disposeconstiables ( ) ;
100
102
}
101
103
102
- // should be called after loadModel()
103
- cache ( ) {
104
- tf . tidy ( ( ) => {
105
- const dummy = tf . zeros ( [ 0 , 416 , 416 , 3 ] ) ;
106
- this . model . predict ( dummy ) ;
107
- } ) ;
104
+ // should be called after load()
105
+ async cache ( ) {
106
+ const dummy = tf . zeros ( [ 416 , 416 , 3 ] ) ;
107
+ await this . detect ( dummy ) ;
108
+ dummy . dispose ( ) ;
108
109
}
109
110
110
- preProccess ( input ) {
111
- const img = tf . fromPixels ( input ) ;
112
- const [ w , h ] = [ img . shape [ 1 ] , img . shape [ 0 ] ] ;
113
- this . imgWidth = w ;
114
- this . imgHeight = h ;
111
+ preProccess ( img ) {
112
+ let image ;
113
+ if ( ! ( img instanceof tf . Tensor ) ) {
114
+ if ( img instanceof HTMLImageElement || img instanceof HTMLVideoElement ) {
115
+ image = tf . fromPixels ( img ) ;
116
+ } else if ( typeof img === 'object' && ( img . elt instanceof HTMLImageElement || img . elt instanceof HTMLVideoElement ) ) {
117
+ image = tf . fromPixels ( img . elt ) ; // Handle p5.js image and video.
118
+ }
119
+ } else {
120
+ image = img ;
121
+ }
115
122
116
- const img1 = tf . image . resizeBilinear ( img , [ this . inputHeight , this . inputWidth ] ) . toFloat ( ) . div ( tf . scalar ( 255 ) ) . expandDims ( 0 ) ;
123
+ [ this . imgWidth , this . imgHeight ] = [ image . shape [ 1 ] , image . shape [ 0 ] ] ;
117
124
125
+ // Normalize the image from [0, 255] to [0, 1].
126
+ const normalized = image . toFloat ( ) . div ( tf . scalar ( 255 ) ) ;
127
+ let resized = normalized ;
128
+ if ( normalized . shape [ 0 ] !== this . imageSize || normalized . shape [ 1 ] !== this . imageSize ) {
129
+ const alignCorners = true ;
130
+ resized = tf . image . resizeBilinear ( normalized , [ this . imageSize , this . imageSize ] , alignCorners ) ;
131
+ }
132
+ // Reshape to a single-element batch so we can pass it to predict.
133
+ const batched = resized . reshape ( [ 1 , this . imageSize , this . imageSize , 3 ] ) ;
118
134
// Scale Stuff
119
- this . scaleX = this . imgHeight / this . inputHeight ;
120
- this . scaleY = this . imgWidth / this . inputWidth ;
121
- return img1 ;
135
+ // this.scaleX = this.imgHeight / this.inputHeight;
136
+ // this.scaleY = this.imgWidth / this.inputWidth;
137
+ return batched ;
122
138
}
123
139
140
+
141
+ /**
142
+ * Post-processing for the YOLO output.
143
+ * TODO : make this more modular in preparation for yolov3-tiny
144
+ * @param rawPrediction a 4D tensor 13*13*425
145
+ */
124
146
async postProccess ( rawPrediction ) {
125
- const results = {
126
- totalDetections : 0 ,
127
- detections : [ ] ,
128
- } ;
129
147
const [ boxes , boxScores , classes , Indices ] = tf . tidy ( ( ) => {
130
- const rawPrediction1 = tf . reshape ( rawPrediction , [ 13 , 13 , this . anchorsLength , this . classesLength + 5 ] ) ;
148
+ const reshaped = tf . reshape ( rawPrediction , [ 13 , 13 , this . anchorsLength , this . classesLength + 5 ] ) ;
131
149
// Box Coords
132
- const boxxy = tf . sigmoid ( rawPrediction1 . slice ( [ 0 , 0 , 0 , 0 ] , [ 13 , 13 , this . anchorsLength , 2 ] ) ) ;
133
- const boxwh = tf . exp ( rawPrediction1 . slice ( [ 0 , 0 , 0 , 2 ] , [ 13 , 13 , this . anchorsLength , 2 ] ) ) ;
150
+ const boxxy = tf . sigmoid ( reshaped . slice ( [ 0 , 0 , 0 , 0 ] , [ 13 , 13 , this . anchorsLength , 2 ] ) ) ;
151
+ const boxwh = tf . exp ( reshaped . slice ( [ 0 , 0 , 0 , 2 ] , [ 13 , 13 , this . anchorsLength , 2 ] ) ) ;
134
152
// ObjectnessScore
135
- const boxConfidence = tf . sigmoid ( rawPrediction1 . slice ( [ 0 , 0 , 0 , 4 ] , [ 13 , 13 , this . anchorsLength , 1 ] ) ) ;
153
+ const boxConfidence = tf . sigmoid ( reshaped . slice ( [ 0 , 0 , 0 , 4 ] , [ 13 , 13 , this . anchorsLength , 1 ] ) ) ;
136
154
// ClassProb
137
- const boxClassProbs = tf . softmax ( rawPrediction1 . slice ( [ 0 , 0 , 0 , 5 ] , [ 13 , 13 , this . anchorsLength , this . classesLength ] ) ) ;
155
+ const boxClassProbs = tf . softmax ( reshaped . slice ( [ 0 , 0 , 0 , 5 ] , [ 13 , 13 , this . anchorsLength , this . classesLength ] ) ) ;
138
156
139
157
// from boxes with xy wh to x1,y1 x2,y2
158
+ // xy:bounding box center wh:width/Height
140
159
// Mainly for NMS + rescaling
141
- /*
142
- x1 = x + (h/2)
143
- y1 = y - (w/2)
144
- x2 = x - (h/2)
145
- y2 = y + (w/2)
146
- */
147
- // BoxScale
148
- const boxXY1 = tf . div ( tf . add ( boxxy , this . ConvIndex ) , this . ConvDims ) ;
149
-
150
- const boxWH1 = tf . div ( tf . mul ( boxwh , this . AnchorsTensor ) , this . ConvDims ) ;
151
-
152
- const Div = tf . div ( boxWH1 , tf . scalar ( 2 ) ) ;
153
-
154
- const boxMins = tf . sub ( boxXY1 , Div ) ;
155
- const boxMaxes = tf . add ( boxXY1 , Div ) ;
156
-
160
+ // x1 = x - (w/2)
161
+ // y1 = y - (h/2)
162
+ // x2 = x + (w/2)
163
+ // y2 = y + (h/2)
164
+
165
+ const boxxy1 = tf . div ( tf . add ( boxxy , this . ConvIndex ) , this . ConvDims ) ;
166
+ const boxwh1 = tf . div ( tf . mul ( boxwh , this . AnchorsTensor ) , this . ConvDims ) ;
167
+ const div = tf . div ( boxwh1 , tf . scalar ( 2 ) ) ;
168
+ const boxMins = tf . sub ( boxxy1 , div ) ;
169
+ const boxMaxes = tf . add ( boxxy1 , div ) ;
157
170
const size = [ boxMins . shape [ 0 ] , boxMins . shape [ 1 ] , boxMins . shape [ 2 ] , 1 ] ;
158
-
159
171
// main box tensor
160
172
const finalboxes = tf . concat ( [
161
173
boxMins . slice ( [ 0 , 0 , 0 , 1 ] , size ) ,
@@ -166,44 +178,48 @@ class YOLO {
166
178
167
179
// Filterboxes by objectness threshold
168
180
// not filtering / getting a mask really
169
-
170
181
const boxConfidence1 = boxConfidence . squeeze ( [ 3 ] ) ;
171
182
const objectnessMask = tf . greaterEqual ( boxConfidence1 , tf . scalar ( this . filterBoxesThreshold ) ) ;
172
183
173
184
// Filterboxes by class probability threshold
174
185
const boxScores1 = tf . mul ( boxConfidence1 , tf . max ( boxClassProbs , 3 ) ) ;
175
186
const boxClassProbMask = tf . greaterEqual ( boxScores1 , tf . scalar ( this . classProbThreshold ) ) ;
176
187
177
- // getting classes indices
188
+ // getting classes indices
178
189
const classes1 = tf . argMax ( boxClassProbs , - 1 ) ;
179
190
180
- // Final Mask each elem that survived both filters (0x0 0x1 1x0 = fail ) 1x1 = survived
191
+ // Final Mask each elem that survived both filters (0x0 0x1 1x0 = fail ) 1x1 = survived
181
192
const finalMask = boxClassProbMask . mul ( objectnessMask ) ;
182
193
183
194
const indices = finalMask . flatten ( ) . toInt ( ) . mul ( this . indicesTensor ) ;
184
195
return [ finalboxes , boxScores1 , classes1 , indices ] ;
185
196
} ) ;
186
197
187
198
// we started at one in the range so we remove 1 now
199
+ // this is where a major bottleneck happens
200
+ // this can be replaced with tf.boolean_mask() if tfjs team implements it
201
+ // this is also why we have 2 tf.tidy()'s
202
+ // more info : https://github.com/ModelDepot/tfjs-yolo-tiny/issues/6
188
203
189
204
const indicesArr = Array . from ( await Indices . data ( ) ) . filter ( i => i > 0 ) . map ( i => i - 1 ) ;
190
205
191
206
if ( indicesArr . length === 0 ) {
192
207
boxes . dispose ( ) ;
193
208
boxScores . dispose ( ) ;
194
209
classes . dispose ( ) ;
195
- return results ;
210
+ return [ ] ;
196
211
}
212
+
197
213
const [ filteredBoxes , filteredScores , filteredclasses ] = tf . tidy ( ( ) => {
198
214
const indicesTensor = tf . tensor1d ( indicesArr , 'int32' ) ;
199
215
const filteredBoxes1 = boxes . gather ( indicesTensor ) ;
200
216
const filteredScores1 = boxScores . flatten ( ) . gather ( indicesTensor ) ;
201
217
const filteredclasses1 = classes . flatten ( ) . gather ( indicesTensor ) ;
202
- // Img Rescale
218
+ // Image Rescale
203
219
const Height = tf . scalar ( this . imgHeight ) ;
204
220
const Width = tf . scalar ( this . imgWidth ) ;
205
- // 4
206
221
const ImageDims = tf . stack ( [ Height , Width , Height , Width ] ) . reshape ( [ 1 , 4 ] ) ;
222
+
207
223
const filteredBoxes2 = filteredBoxes1 . mul ( ImageDims ) ;
208
224
return [ filteredBoxes2 , filteredScores1 , filteredclasses1 ] ;
209
225
} ) ;
@@ -240,17 +256,16 @@ class YOLO {
240
256
} ) ;
241
257
242
258
// final phase
243
-
259
+ const detections = [ ] ;
244
260
// add any output you want
245
261
for ( let id = 0 ; id < selectedBoxes . length ; id += 1 ) {
246
262
const classProb = selectedBoxes [ id ] [ 0 ] ;
247
263
const classProbRounded = Math . round ( classProb * 1000 ) / 10 ;
248
264
const className = this . classNames [ selectedBoxes [ id ] [ 2 ] ] ;
249
265
const classIndex = selectedBoxes [ id ] [ 2 ] ;
250
266
const [ y1 , x1 , y2 , x2 ] = selectedBoxes [ id ] [ 1 ] ;
251
- // Need to get this out
252
267
// TODO : add a hsla color for later visualization
253
- const resultObj = {
268
+ const detection = {
254
269
id,
255
270
className,
256
271
classIndex,
@@ -261,14 +276,11 @@ class YOLO {
261
276
x2,
262
277
y2,
263
278
} ;
264
- results . detections . push ( resultObj ) ;
279
+ detections . push ( detection ) ;
265
280
}
266
- // Misc
267
- results . totalDetections = results . detections . length ;
268
- results . scaleX = this . scaleX ;
269
- results . scaleY = this . scaleY ;
270
- return results ;
281
+ return detections ;
271
282
}
272
283
}
273
284
285
+ const YOLO = options => new YOLOBase ( options ) ;
274
286
export default YOLO ;
0 commit comments