Skip to content

Commit f471eec

Browse files
authored
Load and save a custom model created with FeatureExtractor (#219)
* load and save feature extractor model * update learning rate * add string an callback
1 parent 6d18b81 commit f471eec

File tree

4 files changed

+151
-41
lines changed

4 files changed

+151
-41
lines changed

dist/ml5.min.js

Lines changed: 106 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/ml5.min.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/FeatureExtractor/Mobilenet.js

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,49 +8,53 @@ A class that extract features from Mobilenet
88
*/
99

1010
import * as tf from '@tensorflow/tfjs';
11+
import * as mobilenet from '@tensorflow-models/mobilenet';
12+
1113
import Video from './../utils/Video';
12-
import { IMAGENET_CLASSES } from './../utils/IMAGENET_CLASSES';
14+
1315
import { imgToTensor } from '../utils/imageUtilities';
1416
import callCallback from '../utils/callcallback';
1517

16-
const IMAGESIZE = 224;
18+
const IMAGE_SIZE = 224;
1719
const DEFAULTS = {
1820
version: 1,
19-
alpha: 1.0,
21+
alpha: 0.25,
2022
topk: 3,
2123
learningRate: 0.0001,
2224
hiddenUnits: 100,
2325
epochs: 20,
2426
numClasses: 2,
2527
batchSize: 0.4,
28+
layer: 'conv_pw_13_relu',
2629
};
2730

2831
class Mobilenet {
2932
constructor(options, callback) {
30-
this.mobilenet = null;
31-
this.modelPath = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
33+
this.mobilenet = mobilenet;
3234
this.topKPredictions = 10;
3335
this.hasAnyTrainedClass = false;
3436
this.customModel = null;
3537
this.epochs = options.epochs || DEFAULTS.epochs;
38+
this.version = options.version || DEFAULTS.version;
3639
this.hiddenUnits = options.hiddenUnits || DEFAULTS.hiddenUnits;
3740
this.numClasses = options.numClasses || DEFAULTS.numClasses;
3841
this.learningRate = options.learningRate || DEFAULTS.learningRate;
3942
this.batchSize = options.batchSize || DEFAULTS.batchSize;
43+
this.layer = options.layer || DEFAULTS.layer;
44+
this.alpha = options.alpha || DEFAULTS.alpha;
4045
this.isPredicting = false;
4146
this.mapStringToIndex = [];
4247
this.usageType = null;
4348
this.ready = callCallback(this.loadModel(), callback);
44-
// this.then = this.ready.then;
4549
}
4650

4751
async loadModel() {
48-
this.mobilenet = await tf.loadModel(this.modelPath);
49-
const layer = this.mobilenet.getLayer('conv_pw_13_relu');
52+
this.mobilenet = await this.mobilenet.load(this.version, this.alpha);
53+
const layer = this.mobilenet.model.getLayer(this.layer);
54+
this.mobilenetFeatures = await tf.model({ inputs: this.mobilenet.model.inputs, outputs: layer.output });
5055
if (this.video) {
51-
tf.tidy(() => this.mobilenet.predict(imgToTensor(this.video))); // Warm up
56+
await this.mobilenet.classify(imgToTensor(this.video)); // Warm up
5257
}
53-
this.mobilenetFeatures = await tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });
5458
return this;
5559
}
5660

@@ -80,7 +84,7 @@ class Mobilenet {
8084
}
8185

8286
if (inputVideo) {
83-
const vid = new Video(inputVideo, IMAGESIZE);
87+
const vid = new Video(inputVideo, IMAGE_SIZE);
8488
this.video = await vid.loadVideo();
8589
}
8690

@@ -121,10 +125,9 @@ class Mobilenet {
121125
async addImageInternal(imgToAdd, label) {
122126
await this.ready;
123127
tf.tidy(() => {
124-
const imageResize = (imgToAdd === this.video) ? null : [IMAGESIZE, IMAGESIZE];
128+
const imageResize = (imgToAdd === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE];
125129
const processedImg = imgToTensor(imgToAdd, imageResize);
126130
const prediction = this.mobilenetFeatures.predict(processedImg);
127-
128131
let y;
129132
if (this.usageType === 'classifier') {
130133
y = tf.tidy(() => tf.oneHot(tf.tensor1d([label], 'int32'), this.numClasses));
@@ -244,7 +247,7 @@ class Mobilenet {
244247
await tf.nextFrame();
245248
this.isPredicting = true;
246249
const predictedClass = tf.tidy(() => {
247-
const imageResize = (imgToPredict === this.video) ? null : [IMAGESIZE, IMAGESIZE];
250+
const imageResize = (imgToPredict === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE];
248251
const processedImg = imgToTensor(imgToPredict, imageResize);
249252
const activation = this.mobilenetFeatures.predict(processedImg);
250253
const predictions = this.customModel.predict(activation);
@@ -283,7 +286,7 @@ class Mobilenet {
283286
await tf.nextFrame();
284287
this.isPredicting = true;
285288
const predictedClass = tf.tidy(() => {
286-
const imageResize = (imgToPredict === this.video) ? null : [IMAGESIZE, IMAGESIZE];
289+
const imageResize = (imgToPredict === this.video) ? null : [IMAGE_SIZE, IMAGE_SIZE];
287290
const processedImg = imgToTensor(imgToPredict, imageResize);
288291
const activation = this.mobilenetFeatures.predict(processedImg);
289292
const predictions = this.customModel.predict(activation);
@@ -294,33 +297,35 @@ class Mobilenet {
294297
return prediction[0];
295298
}
296299

297-
// Static Method: get top k classes for mobilenet
298-
static async getTopKClasses(logits, topK, callback = () => {}) {
299-
const values = await logits.data();
300-
const valuesAndIndices = [];
301-
for (let i = 0; i < values.length; i += 1) {
302-
valuesAndIndices.push({ value: values[i], index: i });
300+
async load(filesOrPath = null, callback) {
301+
if (typeof filesOrPath !== 'string') {
302+
let model = null;
303+
let weights = null;
304+
Array.from(filesOrPath).forEach((file) => {
305+
if (file.name.includes('.json')) {
306+
model = file;
307+
} else if (file.name.includes('.bin')) {
308+
weights = file;
309+
}
310+
});
311+
this.customModel = await tf.loadModel(tf.io.browserFiles([model, weights]));
312+
} else {
313+
this.customModel = await tf.loadModel(filesOrPath);
314+
}
315+
if (callback) {
316+
callback();
303317
}
304-
valuesAndIndices.sort((a, b) => b.value - a.value);
305-
const topkValues = new Float32Array(topK);
318+
return this.customModel;
319+
}
306320

307-
const topkIndices = new Int32Array(topK);
308-
for (let i = 0; i < topK; i += 1) {
309-
topkValues[i] = valuesAndIndices[i].value;
310-
topkIndices[i] = valuesAndIndices[i].index;
321+
async save(destination = 'downloads://', callback) {
322+
if (!this.customModel) {
323+
throw new Error('No model found.');
311324
}
312-
const topClassesAndProbs = [];
313-
for (let i = 0; i < topkIndices.length; i += 1) {
314-
topClassesAndProbs.push({
315-
className: IMAGENET_CLASSES[topkIndices[i]],
316-
probability: topkValues[i],
317-
});
325+
await this.customModel.model.save(destination);
326+
if (callback) {
327+
callback();
318328
}
319-
320-
await tf.nextFrame();
321-
322-
callback(undefined, topClassesAndProbs);
323-
return topClassesAndProbs;
324329
}
325330
}
326331

src/FeatureExtractor/index.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ const featureExtractor = (model, optionsOrCallback, cb) => {
2121
let options = {};
2222
let callback = cb;
2323

24-
if (optionsOrCallback === 'object') {
24+
if (typeof optionsOrCallback === 'object') {
2525
options = optionsOrCallback;
2626
} else if (typeof optionsOrCallback === 'function') {
2727
callback = optionsOrCallback;

0 commit comments

Comments
 (0)