@@ -11,6 +11,7 @@ import * as tf from '@tensorflow/tfjs';
11
11
import Video from './../utils/Video' ;
12
12
import { IMAGENET_CLASSES } from './../utils/IMAGENET_CLASSES' ;
13
13
import { imgToTensor } from '../utils/imageUtilities' ;
14
+ import callCallback from '../utils/callcallback' ;
14
15
15
16
const IMAGESIZE = 224 ;
16
17
const DEFAULTS = {
@@ -29,7 +30,6 @@ class Mobilenet {
29
30
this . mobilenet = null ;
30
31
this . modelPath = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json' ;
31
32
this . topKPredictions = 10 ;
32
- this . modelLoaded = false ;
33
33
this . hasAnyTrainedClass = false ;
34
34
this . customModel = null ;
35
35
this . epochs = options . epochs || DEFAULTS . epochs ;
@@ -40,12 +40,8 @@ class Mobilenet {
40
40
this . isPredicting = false ;
41
41
this . mapStringToIndex = [ ] ;
42
42
this . usageType = null ;
43
-
44
- this . loadModel ( ) . then ( ( net ) => {
45
- this . modelLoaded = true ;
46
- this . mobilenetFeatures = net ;
47
- callback ( ) ;
48
- } ) ;
43
+ this . ready = callCallback ( this . loadModel ( ) , callback ) ;
44
+ this . then = this . ready . then ;
49
45
}
50
46
51
47
async loadModel ( ) {
@@ -54,20 +50,21 @@ class Mobilenet {
54
50
if ( this . video ) {
55
51
tf . tidy ( ( ) => this . mobilenet . predict ( imgToTensor ( this . video ) ) ) ; // Warm up
56
52
}
57
- return tf . model ( { inputs : this . mobilenet . inputs , outputs : layer . output } ) ;
53
+ this . mobilenetFeatures = await tf . model ( { inputs : this . mobilenet . inputs , outputs : layer . output } ) ;
54
+ return this ;
58
55
}
59
56
60
57
classification ( video , callback ) {
61
58
this . usageType = 'classifier' ;
62
- return this . loadVideo ( video , callback ) ;
59
+ return callCallback ( this . loadVideo ( video ) , callback ) ;
63
60
}
64
61
65
62
regression ( video , callback ) {
66
63
this . usageType = 'regressor' ;
67
- return this . loadVideo ( video , callback ) ;
64
+ return callCallback ( this . loadVideo ( video ) , callback ) ;
68
65
}
69
66
70
- loadVideo ( video , callback = ( ) => { } ) {
67
+ async loadVideo ( video ) {
71
68
let inputVideo = null ;
72
69
73
70
if ( video instanceof HTMLVideoElement ) {
@@ -78,16 +75,13 @@ class Mobilenet {
78
75
79
76
if ( inputVideo ) {
80
77
const vid = new Video ( inputVideo , IMAGESIZE ) ;
81
- vid . loadVideo ( ) . then ( async ( ) => {
82
- this . video = vid . video ;
83
- callback ( ) ;
84
- } ) ;
78
+ this . video = await vid . loadVideo ( ) ;
85
79
}
86
80
87
81
return this ;
88
82
}
89
83
90
- addImage ( inputOrLabel , labelOrCallback , cb = ( ) => { } ) {
84
+ async addImage ( inputOrLabel , labelOrCallback , cb ) {
91
85
let imgToAdd ;
92
86
let label ;
93
87
let callback = cb ;
@@ -115,38 +109,37 @@ class Mobilenet {
115
109
}
116
110
}
117
111
118
- if ( this . modelLoaded ) {
119
- tf . tidy ( ( ) => {
120
- const processedImg = imgToTensor ( imgToAdd ) ;
121
- const prediction = this . mobilenetFeatures . predict ( processedImg ) ;
122
-
123
- let y ;
124
- if ( this . usageType === 'classifier' ) {
125
- y = tf . tidy ( ( ) => tf . oneHot ( tf . tensor1d ( [ label ] , 'int32' ) , this . numClasses ) ) ;
126
- } else if ( this . usageType === 'regressor' ) {
127
- y = tf . tidy ( ( ) => tf . tensor2d ( [ [ label ] ] ) ) ;
128
- }
129
-
130
- if ( this . xs == null ) {
131
- this . xs = tf . keep ( prediction ) ;
132
- this . ys = tf . keep ( y ) ;
133
- this . hasAnyTrainedClass = true ;
134
- } else {
135
- const oldX = this . xs ;
136
- this . xs = tf . keep ( oldX . concat ( prediction , 0 ) ) ;
137
- const oldY = this . ys ;
138
- this . ys = tf . keep ( oldY . concat ( y , 0 ) ) ;
139
- oldX . dispose ( ) ;
140
- oldY . dispose ( ) ;
141
- y . dispose ( ) ;
142
- }
143
- } ) ;
144
- if ( callback ) {
145
- callback ( ) ;
112
+ return callCallback ( this . addImageInternal ( imgToAdd , label ) , callback ) ;
113
+ }
114
+
115
+ async addImageInternal ( imgToAdd , label ) {
116
+ await this . ready ;
117
+ tf . tidy ( ( ) => {
118
+ const processedImg = imgToTensor ( imgToAdd ) ;
119
+ const prediction = this . mobilenetFeatures . predict ( processedImg ) ;
120
+
121
+ let y ;
122
+ if ( this . usageType === 'classifier' ) {
123
+ y = tf . tidy ( ( ) => tf . oneHot ( tf . tensor1d ( [ label ] , 'int32' ) , this . numClasses ) ) ;
124
+ } else if ( this . usageType === 'regressor' ) {
125
+ y = tf . tensor2d ( [ [ label ] ] ) ;
146
126
}
147
- } else {
148
- console . warn ( 'The model is not loaded yet.' ) ;
149
- }
127
+
128
+ if ( this . xs == null ) {
129
+ this . xs = tf . keep ( prediction ) ;
130
+ this . ys = tf . keep ( y ) ;
131
+ this . hasAnyTrainedClass = true ;
132
+ } else {
133
+ const oldX = this . xs ;
134
+ this . xs = tf . keep ( oldX . concat ( prediction , 0 ) ) ;
135
+ const oldY = this . ys ;
136
+ this . ys = tf . keep ( oldY . concat ( y , 0 ) ) ;
137
+ oldX . dispose ( ) ;
138
+ oldY . dispose ( ) ;
139
+ y . dispose ( ) ;
140
+ }
141
+ } ) ;
142
+ return this ;
150
143
}
151
144
152
145
async train ( onProgress ) {
@@ -203,7 +196,7 @@ class Mobilenet {
203
196
throw new Error ( 'Batch size is 0 or NaN. Please choose a non-zero fraction.' ) ;
204
197
}
205
198
206
- this . customModel . fit ( this . xs , this . ys , {
199
+ return this . customModel . fit ( this . xs , this . ys , {
207
200
batchSize,
208
201
epochs : this . epochs ,
209
202
callbacks : {
@@ -217,83 +210,85 @@ class Mobilenet {
217
210
}
218
211
219
212
/* eslint max-len: ["error", { "code": 180 }] */
220
- async classify ( inputOrCallback , cb = null ) {
221
- if ( this . usageType === 'classifier' ) {
222
- let imgToPredict ;
223
- let callback ;
224
-
225
- if ( inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement ) {
226
- imgToPredict = inputOrCallback ;
227
- } else if ( typeof inputOrCallback === 'object' && ( inputOrCallback . elt instanceof HTMLImageElement || inputOrCallback . elt instanceof HTMLVideoElement ) ) {
228
- imgToPredict = inputOrCallback . elt ; // p5.js image element
229
- } else if ( typeof inputOrCallback === 'function' ) {
230
- imgToPredict = this . video ;
231
- callback = inputOrCallback ;
232
- }
213
+ async classify ( inputOrCallback , cb ) {
214
+ let imgToPredict ;
215
+ let callback ;
216
+
217
+ if ( inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement ) {
218
+ imgToPredict = inputOrCallback ;
219
+ } else if ( typeof inputOrCallback === 'object' && ( inputOrCallback . elt instanceof HTMLImageElement || inputOrCallback . elt instanceof HTMLVideoElement ) ) {
220
+ imgToPredict = inputOrCallback . elt ; // p5.js image element
221
+ } else if ( typeof inputOrCallback === 'function' ) {
222
+ imgToPredict = this . video ;
223
+ callback = inputOrCallback ;
224
+ }
233
225
234
- if ( typeof cb === 'function' ) {
235
- callback = cb ;
236
- }
226
+ if ( typeof cb === 'function' ) {
227
+ callback = cb ;
228
+ }
237
229
238
- this . isPredicting = true ;
239
- const predictedClass = tf . tidy ( ( ) => {
240
- const processedImg = imgToTensor ( imgToPredict ) ;
241
- const activation = this . mobilenetFeatures . predict ( processedImg ) ;
242
- const predictions = this . customModel . predict ( activation ) ;
243
- return predictions . as1D ( ) . argMax ( ) ;
244
- } ) ;
245
- let classId = ( await predictedClass . data ( ) ) [ 0 ] ;
246
- await tf . nextFrame ( ) ;
247
- if ( callback ) {
248
- if ( this . mapStringToIndex . length > 0 ) {
249
- classId = this . mapStringToIndex [ classId ] ;
250
- }
251
- callback ( classId ) ;
252
- }
253
- } else {
254
- console . warn ( 'Mobilenet Feature Extraction has not been set to be a classifier.' ) ;
230
+ return callCallback ( this . classifyInternal ( imgToPredict ) , callback ) ;
231
+ }
232
+
233
+ async classifyInternal ( imgToPredict ) {
234
+ if ( this . usageType === 'classifier' ) {
235
+ throw new Error ( 'Mobilenet Feature Extraction has not been set to be a classifier.' ) ;
236
+ }
237
+
238
+ this . isPredicting = true ;
239
+ const predictedClass = tf . tidy ( ( ) => {
240
+ const processedImg = imgToTensor ( imgToPredict ) ;
241
+ const activation = this . mobilenetFeatures . predict ( processedImg ) ;
242
+ const predictions = this . customModel . predict ( activation ) ;
243
+ return predictions . as1D ( ) . argMax ( ) ;
244
+ } ) ;
245
+ let classId = ( await predictedClass . data ( ) ) [ 0 ] ;
246
+ await tf . nextFrame ( ) ;
247
+ if ( this . mapStringToIndex . length > 0 ) {
248
+ classId = this . mapStringToIndex [ classId ] ;
255
249
}
250
+ return classId ;
256
251
}
257
252
258
253
/* eslint max-len: ["error", { "code": 180 }] */
259
- async predict ( inputOrCallback , cb = null ) {
260
- if ( this . usageType === 'regressor' ) {
261
- let imgToPredict ;
262
- let callback ;
263
-
264
- if ( inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement ) {
265
- imgToPredict = inputOrCallback ;
266
- } else if ( typeof inputOrCallback === 'object' && ( inputOrCallback . elt instanceof HTMLImageElement || inputOrCallback . elt instanceof HTMLVideoElement ) ) {
267
- imgToPredict = inputOrCallback . elt ; // p5.js image element
268
- } else if ( typeof inputOrCallback === 'function' ) {
269
- imgToPredict = this . video ;
270
- callback = inputOrCallback ;
271
- }
254
+ async predict ( inputOrCallback , cb ) {
255
+ let imgToPredict ;
256
+ let callback ;
257
+ if ( inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement ) {
258
+ imgToPredict = inputOrCallback ;
259
+ } else if ( typeof inputOrCallback === 'object' && ( inputOrCallback . elt instanceof HTMLImageElement || inputOrCallback . elt instanceof HTMLVideoElement ) ) {
260
+ imgToPredict = inputOrCallback . elt ; // p5.js image element
261
+ } else if ( typeof inputOrCallback === 'function' ) {
262
+ imgToPredict = this . video ;
263
+ callback = inputOrCallback ;
264
+ }
272
265
273
- if ( typeof cb === 'function' ) {
274
- callback = cb ;
275
- }
266
+ if ( typeof cb === 'function' ) {
267
+ callback = cb ;
268
+ }
269
+ return callCallback ( this . predictInternal ( imgToPredict ) , callback ) ;
270
+ }
276
271
277
- this . isPredicting = true ;
278
- const predictedClass = tf . tidy ( ( ) => {
279
- const processedImg = imgToTensor ( imgToPredict ) ;
280
- const activation = this . mobilenetFeatures . predict ( processedImg ) ;
281
- const predictions = this . customModel . predict ( activation ) ;
282
- return predictions . as1D ( ) ;
283
- } ) ;
284
- const prediction = ( await predictedClass . data ( ) ) ;
285
- predictedClass . dispose ( ) ;
286
- await tf . nextFrame ( ) ;
287
- if ( callback ) {
288
- callback ( prediction [ 0 ] ) ;
289
- }
290
- } else {
291
- console . warn ( 'Mobilenet Feature Extraction has not been set to be a regressor.' ) ;
272
+ async predictInternal ( imgToPredict ) {
273
+ if ( this . usageType !== 'regressor' ) {
274
+ throw new Error ( 'Mobilenet Feature Extraction has not been set to be a regressor.' ) ;
292
275
}
276
+
277
+ this . isPredicting = true ;
278
+ const predictedClass = tf . tidy ( ( ) => {
279
+ const processedImg = imgToTensor ( imgToPredict ) ;
280
+ const activation = this . mobilenetFeatures . predict ( processedImg ) ;
281
+ const predictions = this . customModel . predict ( activation ) ;
282
+ return predictions . as1D ( ) ;
283
+ } ) ;
284
+ const prediction = await predictedClass . data ( ) ;
285
+ predictedClass . dispose ( ) ;
286
+ await tf . nextFrame ( ) ;
287
+ return prediction [ 0 ] ;
293
288
}
294
289
295
290
// Static Method: get top k classes for mobilenet
296
- static async getTopKClasses ( logits , topK , callback ) {
291
+ static async getTopKClasses ( logits , topK , callback = ( ) => { } ) {
297
292
const values = await logits . data ( ) ;
298
293
const valuesAndIndices = [ ] ;
299
294
for ( let i = 0 ; i < values . length ; i += 1 ) {
@@ -317,9 +312,7 @@ class Mobilenet {
317
312
318
313
await tf . nextFrame ( ) ;
319
314
320
- if ( callback ) {
321
- callback ( topClassesAndProbs ) ;
322
- }
315
+ callback ( undefined , topClassesAndProbs ) ;
323
316
return topClassesAndProbs ;
324
317
}
325
318
}
0 commit comments