@@ -20,12 +20,11 @@ const DEFAULTS = {
20
20
} ;
21
21
22
22
class YOLOBase {
23
- constructor ( options ) {
23
+ constructor ( options = DEFAULTS ) {
24
24
this . filterBoxesThreshold = options . filterBoxesThreshold || DEFAULTS . filterBoxesThreshold ;
25
25
this . IOUThreshold = options . IOUThreshold || DEFAULTS . IOUThreshold ;
26
26
this . classProbThreshold = options . classProbThreshold || DEFAULTS . classProbThreshold ;
27
27
this . modelURL = options . url || DEFAULTS . URL ;
28
- this . model = null ;
29
28
this . imageSize = options . imageSize || DEFAULTS . imageSize ;
30
29
this . classNames = CLASS_NAMES ;
31
30
this . anchors = [
@@ -35,20 +34,30 @@ class YOLOBase {
35
34
[ 7.88282 , 3.52778 ] ,
36
35
[ 9.77052 , 9.16828 ] ,
37
36
] ;
38
- this . anchorsLength = this . anchors . length ;
39
- this . classesLength = this . classNames . length ;
40
37
this . init ( ) ;
41
38
}
42
39
43
40
init ( ) {
44
- const Aten = tf . tensor2d ( this . anchors ) ;
45
- this . anchorsTensor = tf . reshape ( Aten , [ 1 , 1 , this . anchorsLength , 2 ] ) ;
46
- Aten . dispose ( ) ;
41
+ const outputWidth = 13 ;
42
+ const outputHeight = 13 ;
43
+
44
+ [ this . convIndex , this . convDims , this . anchorsTensor ] = tf . tidy ( ( ) => {
45
+ const Atensor = tf . tensor2d ( this . anchors ) ;
46
+ const anchorsTensor = tf . reshape ( Atensor , [ 1 , 1 , Atensor . shape [ 0 ] , 2 ] ) ;
47
+
48
+ let convIndex = tf . range ( 0 , outputWidth ) ;
49
+ const convHeightIndex = tf . tile ( convIndex , [ outputHeight ] ) ;
50
+ let convWidthindex = tf . tile ( tf . expandDims ( convIndex , 0 ) , [ outputWidth , 1 ] ) ;
51
+ convWidthindex = tf . transpose ( convWidthindex ) . flatten ( ) ;
52
+ convIndex = tf . transpose ( tf . stack ( [ convHeightIndex , convWidthindex ] ) ) ;
53
+ convIndex = tf . reshape ( convIndex , [ outputWidth , outputHeight , 1 , 2 ] ) ;
54
+ const convDims = tf . reshape ( tf . tensor1d ( [ outputWidth , outputHeight ] ) , [ 1 , 1 , 1 , 2 ] ) ;
55
+ return [ convIndex , convDims , anchorsTensor ] ;
56
+ } ) ;
47
57
}
48
58
49
59
/**
50
60
* Infers through the model.
51
- * TODO : Optionally takes an endpoint to return an intermediate activation.
52
61
* @param img The image to classify. Can be a tensor or a DOM element image,
53
62
* video, or canvas.
54
63
* img: tf.Tensor3D|ImageData|HTMLImageElement|HTMLCanvasElement|HTMLVideoElement
@@ -60,6 +69,7 @@ class YOLOBase {
60
69
return preds ;
61
70
} ) ;
62
71
const results = await this . postProcess ( predictions ) ;
72
+ predictions . dispose ( ) ;
63
73
return results ;
64
74
}
65
75
@@ -122,8 +132,7 @@ class YOLOBase {
122
132
* @param rawPrediction a 4D tensor
123
133
*/
124
134
async postProcess ( rawPrediction ) {
125
- const [ boxes , boxScores , classes , Indices ] = tf . tidy ( ( ) => this . split ( rawPrediction . squeeze ( [ 0 ] ) , this . anchorsTensor ) ) ;
126
-
135
+ const [ boxes , boxScores , classes , Indices ] = tf . tidy ( ( ) => this . split ( rawPrediction . squeeze ( [ 0 ] ) ) ) ;
127
136
// we started at one in the range so we remove 1 now
128
137
const indicesArr = Array . from ( await Indices . data ( ) ) . filter ( i => i > 0 ) . map ( i => i - 1 ) ;
129
138
@@ -148,6 +157,7 @@ class YOLOBase {
148
157
// this for x y w h
149
158
const ImageDims = tf . stack ( [ Width , Height , Width , Height ] ) . reshape ( [ 1 , 4 ] ) ;
150
159
const filteredBoxes2 = filteredBoxes1 . mul ( ImageDims ) ;
160
+
151
161
return [ filteredBoxes2 , filteredScores1 , filteredclasses1 ] ;
152
162
} ) ;
153
163
boxes . dispose ( ) ;
@@ -218,46 +228,42 @@ class YOLOBase {
218
228
return detections ;
219
229
}
220
230
221
- split ( rawPrediction , AnchorsTensor ) {
231
+ split ( rawPrediction ) {
222
232
const [ outputWidth , outputHeight ] = [ rawPrediction . shape [ 0 ] , rawPrediction . shape [ 1 ] ] ;
223
- const reshaped = tf . reshape ( rawPrediction , [ outputWidth , outputHeight , this . anchorsLength , this . classesLength + 5 ] ) ;
224
- // Box xywh
225
- const boxxy = tf . sigmoid ( reshaped . slice ( [ 0 , 0 , 0 , 0 ] , [ outputWidth , outputHeight , this . anchorsLength , 2 ] ) ) ;
226
- const boxwh = tf . exp ( reshaped . slice ( [ 0 , 0 , 0 , 2 ] , [ outputWidth , outputHeight , this . anchorsLength , 2 ] ) ) ;
233
+ const anchorsLength = this . anchorsTensor . shape [ 2 ] ;
234
+ const classesLength = this . classNames . length ;
235
+ const reshaped = tf . reshape ( rawPrediction , [ outputWidth , outputHeight , anchorsLength , classesLength + 5 ] ) ;
236
+ // Box xy_wh
237
+ let boxxy = tf . sigmoid ( reshaped . slice ( [ 0 , 0 , 0 , 0 ] , [ outputWidth , outputHeight , anchorsLength , 2 ] ) ) ;
238
+ let boxwh = tf . exp ( reshaped . slice ( [ 0 , 0 , 0 , 2 ] , [ outputWidth , outputHeight , anchorsLength , 2 ] ) ) ;
227
239
// objectnessScore
228
- const boxConfidence = tf . sigmoid ( reshaped . slice ( [ 0 , 0 , 0 , 4 ] , [ outputWidth , outputHeight , this . anchorsLength , 1 ] ) ) ;
240
+ let boxConfidence = tf . sigmoid ( reshaped . slice ( [ 0 , 0 , 0 , 4 ] , [ outputWidth , outputHeight , anchorsLength , 1 ] ) ) ;
229
241
// classProb
230
- const boxClassProbs = tf . softmax ( reshaped . slice ( [ 0 , 0 , 0 , 5 ] , [ outputWidth , outputHeight , this . anchorsLength , this . classesLength ] ) ) ;
231
-
232
- let ConvIndex = tf . range ( 0 , outputWidth ) ;
233
- const ConvHeightIndex = tf . tile ( ConvIndex , [ outputHeight ] ) ;
234
- let ConvWidthindex = tf . tile ( tf . expandDims ( ConvIndex , 0 ) , [ outputWidth , 1 ] ) ;
235
- ConvWidthindex = tf . transpose ( ConvWidthindex ) . flatten ( ) ;
236
- ConvIndex = tf . transpose ( tf . stack ( [ ConvHeightIndex , ConvWidthindex ] ) ) ;
237
- ConvIndex = tf . reshape ( ConvIndex , [ outputWidth , outputHeight , 1 , 2 ] ) ;
238
- const ConvDims = tf . reshape ( tf . tensor1d ( [ outputWidth , outputHeight ] ) , [ 1 , 1 , 1 , 2 ] ) ;
242
+ const boxClassProbs = tf . softmax ( reshaped . slice ( [ 0 , 0 , 0 , 5 ] , [ outputWidth , outputHeight , anchorsLength , classesLength ] ) ) ;
239
243
240
- const boxxy1 = tf . div ( tf . add ( boxxy , ConvIndex ) , ConvDims ) ;
241
- const boxwh1 = tf . div ( tf . mul ( boxwh , AnchorsTensor ) , ConvDims ) ;
244
+ boxxy = tf . div ( tf . add ( boxxy , this . convIndex ) , this . convDims ) ;
245
+ boxwh = tf . div ( tf . mul ( boxwh , this . anchorsTensor ) , this . convDims ) ;
242
246
243
- const finalboxes = tf . concat ( [ boxxy1 , boxwh1 ] , 3 ) . reshape ( [ ( outputWidth * outputHeight * this . anchorsLength ) , 4 ] ) ;
247
+ const finalboxes = tf . concat ( [ boxxy , boxwh ] , 3 ) . reshape ( [ ( outputWidth * outputHeight * anchorsLength ) , 4 ] ) ;
244
248
245
249
// filter boxes by objectness threshold
246
- const boxConfidence1 = boxConfidence . squeeze ( [ 3 ] ) ;
247
- const objectnessMask = tf . greaterEqual ( boxConfidence1 , tf . scalar ( this . filterBoxesThreshold ) ) ;
250
+ boxConfidence = boxConfidence . squeeze ( [ 3 ] ) ;
251
+ const objectnessMask = tf . greaterEqual ( boxConfidence , tf . scalar ( this . filterBoxesThreshold ) ) ;
248
252
249
253
// filter boxes by class probability threshold
250
- const boxScores1 = tf . mul ( boxConfidence1 , tf . max ( boxClassProbs , 3 ) ) ;
251
- const boxClassProbMask = tf . greaterEqual ( boxScores1 , tf . scalar ( this . classProbThreshold ) ) ;
254
+ const boxScores = tf . mul ( boxConfidence , tf . max ( boxClassProbs , 3 ) ) ;
255
+ const boxClassProbMask = tf . greaterEqual ( boxScores , tf . scalar ( this . classProbThreshold ) ) ;
252
256
253
257
// classes indices
254
- const classes1 = tf . argMax ( boxClassProbs , - 1 ) ;
258
+ const classes = tf . argMax ( boxClassProbs , - 1 ) ;
255
259
256
- const indicesTensor = tf . range ( 1 , ( outputWidth * outputHeight * this . anchorsLength ) + 1 , 1 , 'int32' ) ;
260
+ const indicesTensor = tf . range ( 1 , ( outputWidth * outputHeight * anchorsLength ) + 1 , 1 , 'int32' ) ;
257
261
// Final Mask each elem that survived both filters
258
262
const finalMask = boxClassProbMask . mul ( objectnessMask ) ;
263
+
259
264
const indices = finalMask . flatten ( ) . toInt ( ) . mul ( indicesTensor ) ;
260
- return [ finalboxes , boxScores1 , classes1 , indices ] ;
265
+
266
+ return [ finalboxes , boxScores , classes , indices ] ;
261
267
}
262
268
}
263
269
0 commit comments