Skip to content

Commit 7b3af68

Browse files
authored
saving feature extrctor model with ml5Specs (#233)
1 parent 148b7d6 commit 7b3af68

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

src/FeatureExtractor/Mobilenet.js

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import * as mobilenet from '@tensorflow-models/mobilenet';
1313
import Video from './../utils/Video';
1414

1515
import { imgToTensor } from '../utils/imageUtilities';
16+
import { saveBlob } from '../utils/io';
1617
import callCallback from '../utils/callcallback';
1718

1819
const IMAGE_SIZE = 224;
@@ -304,28 +305,49 @@ class Mobilenet {
304305
Array.from(filesOrPath).forEach((file) => {
305306
if (file.name.includes('.json')) {
306307
model = file;
308+
const fr = new FileReader();
309+
fr.onload = (d) => {
310+
this.mapStringToIndex = JSON.parse(d.target.result).ml5Specs.mapStringToIndex;
311+
};
312+
fr.readAsText(file);
307313
} else if (file.name.includes('.bin')) {
308314
weights = file;
309315
}
310316
});
311317
this.customModel = await tf.loadModel(tf.io.browserFiles([model, weights]));
312318
} else {
319+
fetch(filesOrPath)
320+
.then(r => r.json())
321+
.then((r) => { this.mapStringToIndex = r.ml5Specs.mapStringToIndex; });
313322
this.customModel = await tf.loadModel(filesOrPath);
314-
}
315-
if (callback) {
316-
callback();
323+
if (callback) {
324+
callback();
325+
}
317326
}
318327
return this.customModel;
319328
}
320329

321-
async save(destination = 'downloads://', callback) {
330+
async save(callback) {
322331
if (!this.customModel) {
323332
throw new Error('No model found.');
324333
}
325-
await this.customModel.model.save(destination);
326-
if (callback) {
327-
callback();
328-
}
334+
this.customModel.save(tf.io.withSaveHandler(async (data) => {
335+
this.weightsManifest = {
336+
modelTopology: data.modelTopology,
337+
weightsManifest: [{
338+
paths: ['./model.weights.bin'],
339+
weights: data.weightSpecs,
340+
}],
341+
ml5Specs: {
342+
mapStringToIndex: this.mapStringToIndex,
343+
},
344+
};
345+
await saveBlob(data.weightData, 'model.weights.bin', 'application/octet-stream');
346+
await saveBlob(JSON.stringify(this.weightsManifest), 'model.json', 'text/plain');
347+
if (callback) {
348+
callback();
349+
}
350+
}));
329351
}
330352

331353
infer(input, endpoint) {

src/utils/io.js

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,16 @@ const saveFile = (name, data) => {
1414
document.body.removeChild(downloadElt);
1515
};
1616

17+
const saveBlob = async (data, name, type) => {
18+
const link = document.createElement('a');
19+
link.style.display = 'none';
20+
document.body.appendChild(link);
21+
const blob = new Blob([data], { type });
22+
link.href = URL.createObjectURL(blob);
23+
link.download = name;
24+
link.click();
25+
};
26+
1727
const loadFile = async (path, callback) => fetch(path)
1828
.then(response => response.json())
1929
.then((json) => {
@@ -28,5 +38,6 @@ const loadFile = async (path, callback) => fetch(path)
2838

2939
export {
3040
saveFile,
41+
saveBlob,
3142
loadFile,
3243
};

0 commit comments

Comments
 (0)