@@ -9,160 +9,272 @@ YOLO Object detection
9
9
Heavily derived from https://github.com/ModelDepot/tfjs-yolo-tiny (ModelDepot: modeldepot.io)
10
10
*/
11
11
12
- import * as tf from '@tensorflow/tfjs' ;
13
- import Video from '../utils/Video' ;
14
- import { imgToTensor } from '../utils/imageUtilities' ;
15
-
16
- import CLASS_NAMES from './../utils/COCO_CLASSES' ;
17
-
18
- import {
19
- nonMaxSuppression ,
20
- boxesToCorners ,
21
- head ,
22
- filterBoxes ,
23
- ANCHORS ,
24
- } from './postprocess' ;
25
-
26
- const URL = 'https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json' ;
12
+ import * as tf from "@tensorflow/tfjs" ;
13
+ import CLASS_NAMES from "./../utils/COCO_CLASSES" ;
27
14
28
15
const DEFAULTS = {
29
16
filterBoxesThreshold : 0.01 ,
30
17
IOUThreshold : 0.4 ,
31
18
classProbThreshold : 0.4 ,
19
+ URL = "https://raw.githubusercontent.com/ml5js/ml5-library/master/src/YOLO/model.json" ,
32
20
} ;
33
21
34
- // Size of the video
35
- const imageSize = 416 ;
36
-
37
- class YOLOBase extends Video {
38
- constructor ( video , options , callback ) {
39
- super ( video , imageSize ) ;
22
+ class YOLO {
40
23
24
+ constructor ( options ) {
41
25
this . filterBoxesThreshold = options . filterBoxesThreshold || DEFAULTS . filterBoxesThreshold ;
42
26
this . IOUThreshold = options . IOUThreshold || DEFAULTS . IOUThreshold ;
43
27
this . classProbThreshold = options . classProbThreshold || DEFAULTS . classProbThreshold ;
44
- this . modelReady = false ;
45
- this . isPredicting = false ;
46
- this . loadModel ( callback ) ;
47
- }
28
+ this . modelURL = options . url || DEFAULTS . URL ;
29
+ this . model = null ;
30
+ this . inputWidth = 416 ;
31
+ this . inputHeight = 416 ;
32
+ this . classNames = CLASS_NAMES ;
33
+ this . anchors = [
34
+ [ 0.57273 , 0.677385 ] ,
35
+ [ 1.87446 , 2.06253 ] ,
36
+ [ 3.33843 , 5.47434 ] ,
37
+ [ 7.88282 , 3.52778 ] ,
38
+ [ 9.77052 , 9.16828 ]
39
+ ] ;
40
+ this . scaleX ;
41
+ this . scaleY ;
42
+ this . anchorsLength = this . anchors . length ;
43
+ this . classesLength = this . Params . classNames . length ;
44
+ this . init ( ) ;
48
45
49
- async loadModel ( callback ) {
50
- return this . loadVideo ( ) . then ( async ( ) => {
51
- this . model = await tf . loadModel ( URL ) ;
52
- this . modelReady = true ;
53
- callback ( ) ;
54
- } ) ;
55
46
}
56
47
57
- async detect ( inputOrCallback , cb = null ) {
58
- if ( this . modelReady && this . video && ! this . predicting ) {
59
- let imgToPredict ;
60
- let callback = cb ;
61
- this . isPredicting = true ;
62
-
63
- if ( inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement ) {
64
- imgToPredict = inputOrCallback ;
65
- } else if ( typeof inputOrCallback === 'object' && ( inputOrCallback . elt instanceof HTMLImageElement || inputOrCallback . elt instanceof HTMLVideoElement ) ) {
66
- imgToPredict = inputOrCallback . elt ; // Handle p5.js image and video.
67
- } else if ( typeof inputOrCallback === 'function' ) {
68
- imgToPredict = this . video ;
69
- callback = inputOrCallback ;
70
- }
48
+ init ( ) {
49
+ // indices tensor to filter the elements later on
50
+ this . indicesTensor = tf . range ( 1 , 846 , 1 , "int32" ) ;
71
51
72
- const input = imgToTensor ( imgToPredict ) ;
52
+ // Grid To Split the raw predictions : Assumes Our Model output is 1 Tensor with 13x13x425
53
+ // gonna hard code all this stuff see if it works
54
+ // this can be done once at the initial phase
55
+ // TODO : make this more modular
73
56
74
- const [ allBoxes , boxConfidence , boxClassProbs ] = tf . tidy ( ( ) => {
75
- const activation = this . model . predict ( input ) ;
76
- const [ boxXY , boxWH , bConfidence , bClassProbs ] = head ( activation , ANCHORS , 80 ) ;
77
- const aBoxes = boxesToCorners ( boxXY , boxWH ) ;
78
- return [ aBoxes , bConfidence , bClassProbs ] ;
79
- } ) ;
57
+ [ this . ConvIndex , this . ConvDims , this . AnchorsTensor ] = tf . tidy ( ( ) => {
58
+ let ConvIndex = tf . range ( 0 , 13 ) ;
59
+ let ConvHeightIndex = tf . tile ( ConvIndex , [ 13 ] ) ;
80
60
81
- const [ boxes , scores , classes ] = await filterBoxes ( allBoxes , boxConfidence , boxClassProbs , this . filterBoxesThreshold ) ;
61
+ let ConvWidthindex = tf . tile ( tf . expandDims ( ConvIndex , 0 ) , [ 13 , 1 ] ) ;
62
+ ConvWidthindex = tf . transpose ( ConvWidthindex ) . flatten ( ) ;
82
63
83
- // If all boxes have been filtered out
84
- if ( boxes == null ) {
85
- return [ ] ;
86
- }
64
+ ConvIndex = tf . transpose ( tf . stack ( [ ConvHeightIndex , ConvWidthindex ] ) ) ;
65
+ ConvIndex = tf . reshape ( ConvIndex , [ 13 , 13 , 1 , 2 ] ) ;
87
66
88
- const width = tf . scalar ( imageSize ) ;
89
- const height = tf . scalar ( imageSize ) ;
90
- const imageDims = tf . stack ( [ height , width , height , width ] ) . reshape ( [ 1 , 4 ] ) ;
91
- const boxesModified = tf . mul ( boxes , imageDims ) ;
67
+ let ConvDims = tf . reshape ( tf . tensor1d ( [ 13 , 13 ] ) , [ 1 , 1 , 1 , 2 ] ) ;
68
+ //AnchorsTensor
69
+ let Aten = tf . tensor2d ( this . anchors ) ;
70
+ let AnchorsTensor = tf . reshape ( Aten , [ 1 , 1 , this . anchorsLength , 2 ] ) ;
92
71
93
- const [ preKeepBoxesArr , scoresArr ] = await Promise . all ( [
94
- boxesModified . data ( ) , scores . data ( ) ,
95
- ] ) ;
72
+ return [ ConvIndex , ConvDims , AnchorsTensor ] ;
73
+ } ) ;
74
+ }
96
75
97
- const [ keepIndx , boxesArr , keepScores ] = nonMaxSuppression (
98
- preKeepBoxesArr ,
99
- scoresArr ,
100
- this . IOUThreshold ,
101
- ) ;
76
+ // takes HTMLCanvasElement || HTMLImageElement ||HTMLVideoElement || ImageData as input
77
+ // outs results obj
78
+ async detect ( input ) {
79
+ const predictions = tf . tidy ( ( ) => {
80
+ const data = this . preProccess ( input ) ;
81
+ const preds = this . model . predict ( data ) ;
82
+ return preds ;
83
+ } )
84
+ const results = await this . postProccess ( predictions ) ;
85
+ return results
86
+ }
102
87
103
- const classesIndxArr = await classes . gather ( tf . tensor1d ( keepIndx , 'int32' ) ) . data ( ) ;
88
+ async loadModel ( ) {
89
+ try {
90
+ this . model = await tf . loadModel ( this . modelURL ) ;
91
+ return true ;
92
+ } catch ( e ) {
93
+ console . log ( e ) ;
94
+ return false ;
95
+ }
96
+ }
104
97
105
- const results = [ ] ;
106
98
107
- classesIndxArr . forEach ( ( classIndx , i ) => {
108
- const classProb = keepScores [ i ] ;
109
- if ( classProb < this . classProbThreshold ) {
110
- return ;
111
- }
99
+ //does not dispose of the model atm
100
+ dispose ( ) {
101
+ tf . disposeconstiables ( ) ;
102
+ }
103
+
104
+ // should be called after loadModel()
105
+ cache ( ) {
106
+ tf . tidy ( ( ) => {
107
+ const dummy = tf . zeros ( [ 0 , 416 , 416 , 3 ] )
108
+ const data = this . model . predict ( dummy )
109
+ } )
110
+ }
112
111
113
- const className = CLASS_NAMES [ classIndx ] ;
114
- let [ y , x , h , w ] = boxesArr [ i ] ;
115
112
116
- y = Math . max ( 0 , y ) ;
117
- x = Math . max ( 0 , x ) ;
118
- h = Math . min ( imageSize , h ) - y ;
119
- w = Math . min ( imageSize , w ) - x ;
120
113
121
- const resultObj = {
122
- className,
123
- classProb,
124
- x : x / imageSize ,
125
- y : y / imageSize ,
126
- w : w / imageSize ,
127
- h : h / imageSize ,
128
- } ;
114
+ preProccess ( input ) {
115
+ let img = tf . fromPixels ( input )
116
+ this . imgWidth = img . shape [ 1 ] ;
117
+ this . imgHeight = img . shape [ 0 ] ;
118
+ img = tf . image . resizeBilinear ( img , [ this . inputHeight , this . inputWidth ] )
119
+ . toFloat ( )
120
+ . div ( tf . scalar ( 255 ) )
121
+ . expandDims ( 0 ) ;
122
+ //Scale Stuff
123
+ this . scaleX = this . imgHeight / this . inputHeight ;
124
+ this . scaleY = this . imgWidth / this . inputWidth ;
125
+ return img
126
+ }
129
127
130
- results . push ( resultObj ) ;
131
- } ) ;
132
128
133
- await tf . nextFrame ( ) ;
134
- this . isPredicting = false ;
129
+ async postProccess ( rawPrediction ) {
130
+
131
+ let results = { totalDetections : 0 , detections : [ ] }
132
+
133
+ const [ boxes , BoxScores , Classes , Indices ] = tf . tidy ( ( ) => {
134
+
135
+ rawPrediction = tf . reshape ( rawPrediction , [ 13 , 13 , this . anchorsLength , this . classesLength + 5 ] ) ;
136
+ // Box Coords
137
+ let BoxXY = tf . sigmoid ( rawPrediction . slice ( [ 0 , 0 , 0 , 0 ] , [ 13 , 13 , this . anchorsLength , 2 ] ) )
138
+ let BoxWH = tf . exp ( rawPrediction . slice ( [ 0 , 0 , 0 , 2 ] , [ 13 , 13 , this . anchorsLength , 2 ] ) )
139
+ // ObjectnessScore
140
+ let BoxConfidence = tf . sigmoid ( rawPrediction . slice ( [ 0 , 0 , 0 , 4 ] , [ 13 , 13 , this . anchorsLength , 1 ] ) )
141
+ // ClassProb
142
+ let BoxClassProbs = tf . softmax ( rawPrediction . slice ( [ 0 , 0 , 0 , 5 ] , [ 13 , 13 , this . anchorsLength , this . classesLength ] ) ) ;
143
+
144
+ // from boxes with xy wh to x1,y1 x2,y2
145
+ // Mainly for NMS + rescaling
146
+ /*
147
+ x1 = x + (h/2)
148
+ y1 = y - (w/2)
149
+ x2 = x - (h/2)
150
+ y2 = y + (w/2)
151
+ */
152
+ // BoxScale
153
+ BoxXY = tf . div ( tf . add ( BoxXY , this . ConvIndex ) , this . ConvDims ) ;
154
+
155
+ BoxWH = tf . div ( tf . mul ( BoxWH , this . AnchorsTensor ) , this . ConvDims ) ;
156
+
157
+ const Div = tf . div ( BoxWH , tf . scalar ( 2 ) )
158
+
159
+ const BoxMins = tf . sub ( BoxXY , Div ) ;
160
+
161
+ const BoxMaxes = tf . add ( BoxXY , Div ) ;
162
+ const Size = [ BoxMins . shape [ 0 ] , BoxMins . shape [ 1 ] , BoxMins . shape [ 2 ] , 1 ] ;
163
+
164
+ // main box tensor
165
+ const boxes = tf . concat ( [ BoxMins . slice ( [ 0 , 0 , 0 , 1 ] , Size ) ,
166
+ BoxMins . slice ( [ 0 , 0 , 0 , 0 ] , Size ) ,
167
+ BoxMaxes . slice ( [ 0 , 0 , 0 , 1 ] , Size ) ,
168
+ BoxMaxes . slice ( [ 0 , 0 , 0 , 0 ] , Size )
169
+ ] , 3 )
170
+ . reshape ( [ 845 , 4 ] )
171
+
172
+
173
+ // Filterboxes by objectness threshold
174
+ // not filtering / getting a mask really
175
+
176
+ BoxConfidence = BoxConfidence . squeeze ( [ 3 ] )
177
+ const ObjectnessMask = tf . greaterEqual ( BoxConfidence , tf . scalar ( this . filterboxesThreshold ) )
178
+
179
+
180
+ // Filterboxes by class probability threshold
181
+ const BoxScores = tf . mul ( BoxConfidence , tf . max ( BoxClassProbs , 3 ) ) ;
182
+ const BoxClassProbMask = tf . greaterEqual ( BoxScores , tf . scalar ( this . classProbThreshold ) ) ;
183
+
184
+ // getting classes indices
185
+ const Classes = tf . argMax ( BoxClassProbs , - 1 )
186
+
187
+
188
+ // Final Mask each elem that survived both filters (0x0 0x1 1x0 = fail ) 1x1 = survived
189
+ const FinalMask = BoxClassProbMask . mul ( ObjectnessMask )
190
+
191
+ const Indices = FinalMask . flatten ( ) . toInt ( ) . mul ( this . IndicesTensor )
192
+ return [ boxes , BoxScores , Classes , Indices ]
193
+ } )
194
+
195
+ //we started at one in the range so we remove 1 now
196
+
197
+ let indicesArr = Array . from ( await Indices . data ( ) ) . filter ( i => i > 0 ) . map ( i => i - 1 ) ;
198
+
199
+ if ( indicesArr . length == 0 ) {
200
+ boxes . dispose ( )
201
+ BoxScores . dispose ( )
202
+ Classes . dispose ( )
203
+ return results
204
+ }
205
+ const indicesTensor = tf . tensor1d ( indicesArr , "int32" ) ;
206
+ let filteredBoxes = boxes . gather ( indicesTensor )
207
+ let filteredScores = BoxScores . flatten ( ) . gather ( indicesTensor )
208
+ let filteredClasses = Classes . flatten ( ) . gather ( indicesTensor )
209
+ boxes . dispose ( )
210
+ BoxScores . dispose ( )
211
+ Classes . dispose ( )
212
+ indicesTensor . dispose ( )
213
+
214
+ //Img Rescale
215
+ const Height = tf . scalar ( this . imgHeight ) ;
216
+ const Width = tf . scalar ( this . imgWidth )
217
+ const ImageDims = tf . stack ( [ Height , Width , Height , Width ] ) . reshape ( [ 1 , 4 ] ) ;
218
+ filteredBoxes = filteredBoxes . mul ( ImageDims )
219
+
220
+ // NonMaxSuppression
221
+ // GreedyNMS
222
+ const [ boxArr , scoreArr , classesArr ] = await Promise . all ( [ filteredBoxes . data ( ) , filteredScores . data ( ) , filteredClasses . data ( ) ] ) ;
223
+ filteredBoxes . dispose ( )
224
+ filteredScores . dispose ( )
225
+ filteredClasses . dispose ( )
226
+
227
+ let zipped = [ ] ;
228
+ for ( let i = 0 ; i < scoreArr . length ; i ++ ) {
229
+ // [Score,x,y,w,h,classindex]
230
+ zipped . push ( [ scoreArr [ i ] , [ boxArr [ 4 * i ] , boxArr [ 4 * i + 1 ] , boxArr [ 4 * i + 2 ] , boxArr [ 4 * i + 3 ] ] , classesArr [ i ] ] ) ;
231
+ }
135
232
136
- if ( callback ) {
137
- callback ( results ) ;
233
+ // Sort by descending order of scores (first index of zipped array)
234
+ const sorted = zipped . sort ( ( a , b ) => b [ 0 ] - a [ 0 ] ) ;
235
+ const selectedBoxes = [ ]
236
+ // Greedily go through boxes in descending score order and only
237
+ // return boxes that are below the IoU threshold.
238
+ sorted . forEach ( box => {
239
+ let Push = true ;
240
+ for ( let i = 0 ; i < selectedBoxes . length ; i ++ ) {
241
+ // Compare IoU of zipped[1], since that is the box coordinates arr
242
+ let w = Math . min ( box [ 1 ] [ 3 ] , selectedBoxes [ i ] [ 1 ] [ 3 ] ) - Math . max ( box [ 1 ] [ 1 ] , selectedBoxes [ i ] [ 1 ] [ 1 ] ) ;
243
+ let h = Math . min ( box [ 1 ] [ 2 ] , selectedBoxes [ i ] [ 1 ] [ 2 ] ) - Math . max ( box [ 1 ] [ 0 ] , selectedBoxes [ i ] [ 1 ] [ 0 ] ) ;
244
+ let Intersection = w < 0 || h < 0 ? 0 : w * h
245
+ let Union = ( box [ 1 ] [ 3 ] - box [ 1 ] [ 1 ] ) * ( box [ 1 ] [ 2 ] - box [ 1 ] [ 0 ] ) + ( selectedBoxes [ i ] [ 1 ] [ 3 ] - selectedBoxes [ i ] [ 1 ] [ 1 ] ) * ( selectedBoxes [ i ] [ 1 ] [ 2 ] - selectedBoxes [ i ] [ 1 ] [ 0 ] ) - Intersection
246
+ let Iou = Intersection / Union
247
+ if ( Iou > this . IOUThreshold ) {
248
+ Push = false ;
249
+ break ;
250
+ }
138
251
}
252
+ if ( Push ) selectedBoxes . push ( box ) ;
253
+ } ) ;
139
254
140
- return results ;
141
- }
142
- console . warn ( 'Model has not finished loading' ) ;
143
- return false ;
144
- }
145
- }
255
+ // final phase
146
256
147
- const YOLO = ( videoOrOptionsOrCallback , optionsOrCallback , cb = ( ) => { } ) => {
148
- let callback = cb ;
149
- let options = { } ;
150
- const video = videoOrOptionsOrCallback ;
257
+ // add any output you want
258
+ for ( let id = 0 ; id < selectedBoxes . length ; id ++ ) {
151
259
152
- if ( typeof videoOrOptionsOrCallback === 'object' ) {
153
- options = videoOrOptionsOrCallback ;
154
- } else if ( typeof videoOrOptionsOrCallback === 'function' ) {
155
- callback = videoOrOptionsOrCallback ;
156
- }
260
+ const classProb = selectedBoxes [ id ] [ 0 ] ;
261
+ const classProbRounded = Math . round ( classProb * 1000 ) / 10
262
+ const className = this . classNames [ selectedBoxes [ id ] [ 2 ] ] ;
263
+ const classIndex = selectedBoxes [ id ] [ 2 ] ;
264
+ const [ x1 , y1 , x2 , y2 ] = selectedBoxes [ id ] [ 1 ] ;
265
+ // Need to get this out
266
+ // TODO : add a hsla color for later visualization
267
+ const resultObj = { id, className, classIndex, classProb, classProbRounded, x1, y1, x2, y2 } ;
268
+ results . detections . push ( resultObj ) ;
269
+ }
270
+ // Misc
271
+ results . totalDetections = results . detections . length ;
272
+ results . scaleX = this . scaleX
273
+ results . scaleY = this . scaleY
157
274
158
- if ( typeof optionsOrCallback === 'object' ) {
159
- options = optionsOrCallback ;
160
- } else if ( typeof optionsOrCallback === 'function' ) {
161
- callback = optionsOrCallback ;
275
+ return results
162
276
}
163
277
164
- return new YOLOBase ( video , options , callback ) ;
165
- } ;
278
+ }
166
279
167
280
export default YOLO ;
168
-
0 commit comments