@@ -13,6 +13,7 @@ import * as mobilenet from '@tensorflow-models/mobilenet';
13
13
import Video from './../utils/Video' ;
14
14
15
15
import { imgToTensor } from '../utils/imageUtilities' ;
16
+ import { saveBlob } from '../utils/io' ;
16
17
import callCallback from '../utils/callcallback' ;
17
18
18
19
const IMAGE_SIZE = 224 ;
@@ -304,28 +305,49 @@ class Mobilenet {
304
305
Array . from ( filesOrPath ) . forEach ( ( file ) => {
305
306
if ( file . name . includes ( '.json' ) ) {
306
307
model = file ;
308
+ const fr = new FileReader ( ) ;
309
+ fr . onload = ( d ) => {
310
+ this . mapStringToIndex = JSON . parse ( d . target . result ) . ml5Specs . mapStringToIndex ;
311
+ } ;
312
+ fr . readAsText ( file ) ;
307
313
} else if ( file . name . includes ( '.bin' ) ) {
308
314
weights = file ;
309
315
}
310
316
} ) ;
311
317
this . customModel = await tf . loadModel ( tf . io . browserFiles ( [ model , weights ] ) ) ;
312
318
} else {
319
+ fetch ( filesOrPath )
320
+ . then ( r => r . json ( ) )
321
+ . then ( ( r ) => { this . mapStringToIndex = r . ml5Specs . mapStringToIndex ; } ) ;
313
322
this . customModel = await tf . loadModel ( filesOrPath ) ;
314
- }
315
- if ( callback ) {
316
- callback ( ) ;
323
+ if ( callback ) {
324
+ callback ( ) ;
325
+ }
317
326
}
318
327
return this . customModel ;
319
328
}
320
329
321
- async save ( destination = 'downloads://' , callback ) {
330
+ async save ( callback ) {
322
331
if ( ! this . customModel ) {
323
332
throw new Error ( 'No model found.' ) ;
324
333
}
325
- await this . customModel . model . save ( destination ) ;
326
- if ( callback ) {
327
- callback ( ) ;
328
- }
334
+ this . customModel . save ( tf . io . withSaveHandler ( async ( data ) => {
335
+ this . weightsManifest = {
336
+ modelTopology : data . modelTopology ,
337
+ weightsManifest : [ {
338
+ paths : [ './model.weights.bin' ] ,
339
+ weights : data . weightSpecs ,
340
+ } ] ,
341
+ ml5Specs : {
342
+ mapStringToIndex : this . mapStringToIndex ,
343
+ } ,
344
+ } ;
345
+ await saveBlob ( data . weightData , 'model.weights.bin' , 'application/octet-stream' ) ;
346
+ await saveBlob ( JSON . stringify ( this . weightsManifest ) , 'model.json' , 'text/plain' ) ;
347
+ if ( callback ) {
348
+ callback ( ) ;
349
+ }
350
+ } ) ) ;
329
351
}
330
352
331
353
infer ( input , endpoint ) {
0 commit comments