@@ -8,49 +8,53 @@ A class that extract features from Mobilenet
8
8
*/
9
9
10
10
import * as tf from '@tensorflow/tfjs' ;
11
+ import * as mobilenet from '@tensorflow-models/mobilenet' ;
12
+
11
13
import Video from './../utils/Video' ;
12
- import { IMAGENET_CLASSES } from './../utils/IMAGENET_CLASSES' ;
14
+
13
15
import { imgToTensor } from '../utils/imageUtilities' ;
14
16
import callCallback from '../utils/callcallback' ;
15
17
16
- const IMAGESIZE = 224 ;
18
+ const IMAGE_SIZE = 224 ;
17
19
const DEFAULTS = {
18
20
version : 1 ,
19
- alpha : 1.0 ,
21
+ alpha : 0.25 ,
20
22
topk : 3 ,
21
23
learningRate : 0.0001 ,
22
24
hiddenUnits : 100 ,
23
25
epochs : 20 ,
24
26
numClasses : 2 ,
25
27
batchSize : 0.4 ,
28
+ layer : 'conv_pw_13_relu' ,
26
29
} ;
27
30
28
31
class Mobilenet {
29
32
constructor ( options , callback ) {
30
- this . mobilenet = null ;
31
- this . modelPath = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json' ;
33
+ this . mobilenet = mobilenet ;
32
34
this . topKPredictions = 10 ;
33
35
this . hasAnyTrainedClass = false ;
34
36
this . customModel = null ;
35
37
this . epochs = options . epochs || DEFAULTS . epochs ;
38
+ this . version = options . version || DEFAULTS . version ;
36
39
this . hiddenUnits = options . hiddenUnits || DEFAULTS . hiddenUnits ;
37
40
this . numClasses = options . numClasses || DEFAULTS . numClasses ;
38
41
this . learningRate = options . learningRate || DEFAULTS . learningRate ;
39
42
this . batchSize = options . batchSize || DEFAULTS . batchSize ;
43
+ this . layer = options . layer || DEFAULTS . layer ;
44
+ this . alpha = options . alpha || DEFAULTS . alpha ;
40
45
this . isPredicting = false ;
41
46
this . mapStringToIndex = [ ] ;
42
47
this . usageType = null ;
43
48
this . ready = callCallback ( this . loadModel ( ) , callback ) ;
44
- // this.then = this.ready.then;
45
49
}
46
50
47
51
async loadModel ( ) {
48
- this . mobilenet = await tf . loadModel ( this . modelPath ) ;
49
- const layer = this . mobilenet . getLayer ( 'conv_pw_13_relu' ) ;
52
+ this . mobilenet = await this . mobilenet . load ( this . version , this . alpha ) ;
53
+ const layer = this . mobilenet . model . getLayer ( this . layer ) ;
54
+ this . mobilenetFeatures = await tf . model ( { inputs : this . mobilenet . model . inputs , outputs : layer . output } ) ;
50
55
if ( this . video ) {
51
- tf . tidy ( ( ) => this . mobilenet . predict ( imgToTensor ( this . video ) ) ) ; // Warm up
56
+ await this . mobilenet . classify ( imgToTensor ( this . video ) ) ; // Warm up
52
57
}
53
- this . mobilenetFeatures = await tf . model ( { inputs : this . mobilenet . inputs , outputs : layer . output } ) ;
54
58
return this ;
55
59
}
56
60
@@ -80,7 +84,7 @@ class Mobilenet {
80
84
}
81
85
82
86
if ( inputVideo ) {
83
- const vid = new Video ( inputVideo , IMAGESIZE ) ;
87
+ const vid = new Video ( inputVideo , IMAGE_SIZE ) ;
84
88
this . video = await vid . loadVideo ( ) ;
85
89
}
86
90
@@ -121,10 +125,9 @@ class Mobilenet {
121
125
async addImageInternal ( imgToAdd , label ) {
122
126
await this . ready ;
123
127
tf . tidy ( ( ) => {
124
- const imageResize = ( imgToAdd === this . video ) ? null : [ IMAGESIZE , IMAGESIZE ] ;
128
+ const imageResize = ( imgToAdd === this . video ) ? null : [ IMAGE_SIZE , IMAGE_SIZE ] ;
125
129
const processedImg = imgToTensor ( imgToAdd , imageResize ) ;
126
130
const prediction = this . mobilenetFeatures . predict ( processedImg ) ;
127
-
128
131
let y ;
129
132
if ( this . usageType === 'classifier' ) {
130
133
y = tf . tidy ( ( ) => tf . oneHot ( tf . tensor1d ( [ label ] , 'int32' ) , this . numClasses ) ) ;
@@ -244,7 +247,7 @@ class Mobilenet {
244
247
await tf . nextFrame ( ) ;
245
248
this . isPredicting = true ;
246
249
const predictedClass = tf . tidy ( ( ) => {
247
- const imageResize = ( imgToPredict === this . video ) ? null : [ IMAGESIZE , IMAGESIZE ] ;
250
+ const imageResize = ( imgToPredict === this . video ) ? null : [ IMAGE_SIZE , IMAGE_SIZE ] ;
248
251
const processedImg = imgToTensor ( imgToPredict , imageResize ) ;
249
252
const activation = this . mobilenetFeatures . predict ( processedImg ) ;
250
253
const predictions = this . customModel . predict ( activation ) ;
@@ -283,7 +286,7 @@ class Mobilenet {
283
286
await tf . nextFrame ( ) ;
284
287
this . isPredicting = true ;
285
288
const predictedClass = tf . tidy ( ( ) => {
286
- const imageResize = ( imgToPredict === this . video ) ? null : [ IMAGESIZE , IMAGESIZE ] ;
289
+ const imageResize = ( imgToPredict === this . video ) ? null : [ IMAGE_SIZE , IMAGE_SIZE ] ;
287
290
const processedImg = imgToTensor ( imgToPredict , imageResize ) ;
288
291
const activation = this . mobilenetFeatures . predict ( processedImg ) ;
289
292
const predictions = this . customModel . predict ( activation ) ;
@@ -294,33 +297,35 @@ class Mobilenet {
294
297
return prediction [ 0 ] ;
295
298
}
296
299
297
- // Static Method: get top k classes for mobilenet
298
- static async getTopKClasses ( logits , topK , callback = ( ) => { } ) {
299
- const values = await logits . data ( ) ;
300
- const valuesAndIndices = [ ] ;
301
- for ( let i = 0 ; i < values . length ; i += 1 ) {
302
- valuesAndIndices . push ( { value : values [ i ] , index : i } ) ;
300
+ async load ( filesOrPath = null , callback ) {
301
+ if ( typeof filesOrPath !== 'string' ) {
302
+ let model = null ;
303
+ let weights = null ;
304
+ Array . from ( filesOrPath ) . forEach ( ( file ) => {
305
+ if ( file . name . includes ( '.json' ) ) {
306
+ model = file ;
307
+ } else if ( file . name . includes ( '.bin' ) ) {
308
+ weights = file ;
309
+ }
310
+ } ) ;
311
+ this . customModel = await tf . loadModel ( tf . io . browserFiles ( [ model , weights ] ) ) ;
312
+ } else {
313
+ this . customModel = await tf . loadModel ( filesOrPath ) ;
314
+ }
315
+ if ( callback ) {
316
+ callback ( ) ;
303
317
}
304
- valuesAndIndices . sort ( ( a , b ) => b . value - a . value ) ;
305
- const topkValues = new Float32Array ( topK ) ;
318
+ return this . customModel ;
319
+ }
306
320
307
- const topkIndices = new Int32Array ( topK ) ;
308
- for ( let i = 0 ; i < topK ; i += 1 ) {
309
- topkValues [ i ] = valuesAndIndices [ i ] . value ;
310
- topkIndices [ i ] = valuesAndIndices [ i ] . index ;
321
+ async save ( destination = 'downloads://' , callback ) {
322
+ if ( ! this . customModel ) {
323
+ throw new Error ( 'No model found.' ) ;
311
324
}
312
- const topClassesAndProbs = [ ] ;
313
- for ( let i = 0 ; i < topkIndices . length ; i += 1 ) {
314
- topClassesAndProbs . push ( {
315
- className : IMAGENET_CLASSES [ topkIndices [ i ] ] ,
316
- probability : topkValues [ i ] ,
317
- } ) ;
325
+ await this . customModel . model . save ( destination ) ;
326
+ if ( callback ) {
327
+ callback ( ) ;
318
328
}
319
-
320
- await tf . nextFrame ( ) ;
321
-
322
- callback ( undefined , topClassesAndProbs ) ;
323
- return topClassesAndProbs ;
324
329
}
325
330
}
326
331
0 commit comments