Skip to content

Commit 3dac4fa

Browse files
committed
Merge branch 'image-classifier-promises' of https://github.com/tmcw/ml5-library into tmcw-image-classifier-promises
2 parents 82843f8 + 3997e10 commit 3dac4fa

File tree

22 files changed

+719
-548
lines changed

22 files changed

+719
-548
lines changed

karma.conf.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ module.exports = (config) => {
5959
os_version: 'High Sierra'
6060
},
6161
},
62-
reporters: ['progress'],
62+
reporters: ['mocha'],
6363
port: 9876,
6464
colors: true,
6565
logLevel: config.LOG_INFO,

package.json

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@
5252
"karma-browserstack-launcher": "~1.3.0",
5353
"karma-chrome-launcher": "2.2.0",
5454
"karma-jasmine": "1.1.1",
55+
"karma-mocha-reporter": "^2.2.5",
56+
"karma-safari-launcher": "1.0.0",
5557
"karma-webpack": "3.0.0",
5658
"npm-run-all": "4.1.2",
5759
"regenerator-runtime": "0.11.1",
@@ -86,6 +88,7 @@
8688
"dependencies": {
8789
"@tensorflow-models/mobilenet": "0.1.0",
8890
"@tensorflow-models/posenet": "^0.1.2",
89-
"@tensorflow/tfjs": "0.11.4"
91+
"@tensorflow/tfjs": "0.11.4",
92+
"events": "^3.0.0"
9093
}
9194
}

src/FeatureExtractor/Mobilenet.js

Lines changed: 109 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import * as tf from '@tensorflow/tfjs';
1111
import Video from './../utils/Video';
1212
import { IMAGENET_CLASSES } from './../utils/IMAGENET_CLASSES';
1313
import { imgToTensor } from '../utils/imageUtilities';
14+
import callCallback from '../utils/callcallback';
1415

1516
const IMAGESIZE = 224;
1617
const DEFAULTS = {
@@ -29,7 +30,6 @@ class Mobilenet {
2930
this.mobilenet = null;
3031
this.modelPath = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
3132
this.topKPredictions = 10;
32-
this.modelLoaded = false;
3333
this.hasAnyTrainedClass = false;
3434
this.customModel = null;
3535
this.epochs = options.epochs || DEFAULTS.epochs;
@@ -40,12 +40,8 @@ class Mobilenet {
4040
this.isPredicting = false;
4141
this.mapStringToIndex = [];
4242
this.usageType = null;
43-
44-
this.loadModel().then((net) => {
45-
this.modelLoaded = true;
46-
this.mobilenetFeatures = net;
47-
callback();
48-
});
43+
this.ready = callCallback(this.loadModel(), callback);
44+
this.then = this.ready.then;
4945
}
5046

5147
async loadModel() {
@@ -54,20 +50,21 @@ class Mobilenet {
5450
if (this.video) {
5551
tf.tidy(() => this.mobilenet.predict(imgToTensor(this.video))); // Warm up
5652
}
57-
return tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });
53+
this.mobilenetFeatures = await tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });
54+
return this;
5855
}
5956

6057
classification(video, callback) {
6158
this.usageType = 'classifier';
62-
return this.loadVideo(video, callback);
59+
return callCallback(this.loadVideo(video), callback);
6360
}
6461

6562
regression(video, callback) {
6663
this.usageType = 'regressor';
67-
return this.loadVideo(video, callback);
64+
return callCallback(this.loadVideo(video), callback);
6865
}
6966

70-
loadVideo(video, callback = () => {}) {
67+
async loadVideo(video) {
7168
let inputVideo = null;
7269

7370
if (video instanceof HTMLVideoElement) {
@@ -78,16 +75,13 @@ class Mobilenet {
7875

7976
if (inputVideo) {
8077
const vid = new Video(inputVideo, IMAGESIZE);
81-
vid.loadVideo().then(async () => {
82-
this.video = vid.video;
83-
callback();
84-
});
78+
this.video = await vid.loadVideo();
8579
}
8680

8781
return this;
8882
}
8983

90-
addImage(inputOrLabel, labelOrCallback, cb = () => {}) {
84+
async addImage(inputOrLabel, labelOrCallback, cb) {
9185
let imgToAdd;
9286
let label;
9387
let callback = cb;
@@ -115,38 +109,37 @@ class Mobilenet {
115109
}
116110
}
117111

118-
if (this.modelLoaded) {
119-
tf.tidy(() => {
120-
const processedImg = imgToTensor(imgToAdd);
121-
const prediction = this.mobilenetFeatures.predict(processedImg);
122-
123-
let y;
124-
if (this.usageType === 'classifier') {
125-
y = tf.tidy(() => tf.oneHot(tf.tensor1d([label], 'int32'), this.numClasses));
126-
} else if (this.usageType === 'regressor') {
127-
y = tf.tidy(() => tf.tensor2d([[label]]));
128-
}
129-
130-
if (this.xs == null) {
131-
this.xs = tf.keep(prediction);
132-
this.ys = tf.keep(y);
133-
this.hasAnyTrainedClass = true;
134-
} else {
135-
const oldX = this.xs;
136-
this.xs = tf.keep(oldX.concat(prediction, 0));
137-
const oldY = this.ys;
138-
this.ys = tf.keep(oldY.concat(y, 0));
139-
oldX.dispose();
140-
oldY.dispose();
141-
y.dispose();
142-
}
143-
});
144-
if (callback) {
145-
callback();
112+
return callCallback(this.addImageInternal(imgToAdd, label), callback);
113+
}
114+
115+
async addImageInternal(imgToAdd, label) {
116+
await this.ready;
117+
tf.tidy(() => {
118+
const processedImg = imgToTensor(imgToAdd);
119+
const prediction = this.mobilenetFeatures.predict(processedImg);
120+
121+
let y;
122+
if (this.usageType === 'classifier') {
123+
y = tf.tidy(() => tf.oneHot(tf.tensor1d([label], 'int32'), this.numClasses));
124+
} else if (this.usageType === 'regressor') {
125+
y = tf.tensor2d([[label]]);
146126
}
147-
} else {
148-
console.warn('The model is not loaded yet.');
149-
}
127+
128+
if (this.xs == null) {
129+
this.xs = tf.keep(prediction);
130+
this.ys = tf.keep(y);
131+
this.hasAnyTrainedClass = true;
132+
} else {
133+
const oldX = this.xs;
134+
this.xs = tf.keep(oldX.concat(prediction, 0));
135+
const oldY = this.ys;
136+
this.ys = tf.keep(oldY.concat(y, 0));
137+
oldX.dispose();
138+
oldY.dispose();
139+
y.dispose();
140+
}
141+
});
142+
return this;
150143
}
151144

152145
async train(onProgress) {
@@ -203,7 +196,7 @@ class Mobilenet {
203196
throw new Error('Batch size is 0 or NaN. Please choose a non-zero fraction.');
204197
}
205198

206-
this.customModel.fit(this.xs, this.ys, {
199+
return this.customModel.fit(this.xs, this.ys, {
207200
batchSize,
208201
epochs: this.epochs,
209202
callbacks: {
@@ -217,83 +210,85 @@ class Mobilenet {
217210
}
218211

219212
/* eslint max-len: ["error", { "code": 180 }] */
220-
async classify(inputOrCallback, cb = null) {
221-
if (this.usageType === 'classifier') {
222-
let imgToPredict;
223-
let callback;
224-
225-
if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
226-
imgToPredict = inputOrCallback;
227-
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
228-
imgToPredict = inputOrCallback.elt; // p5.js image element
229-
} else if (typeof inputOrCallback === 'function') {
230-
imgToPredict = this.video;
231-
callback = inputOrCallback;
232-
}
213+
async classify(inputOrCallback, cb) {
214+
let imgToPredict;
215+
let callback;
216+
217+
if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
218+
imgToPredict = inputOrCallback;
219+
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
220+
imgToPredict = inputOrCallback.elt; // p5.js image element
221+
} else if (typeof inputOrCallback === 'function') {
222+
imgToPredict = this.video;
223+
callback = inputOrCallback;
224+
}
233225

234-
if (typeof cb === 'function') {
235-
callback = cb;
236-
}
226+
if (typeof cb === 'function') {
227+
callback = cb;
228+
}
237229

238-
this.isPredicting = true;
239-
const predictedClass = tf.tidy(() => {
240-
const processedImg = imgToTensor(imgToPredict);
241-
const activation = this.mobilenetFeatures.predict(processedImg);
242-
const predictions = this.customModel.predict(activation);
243-
return predictions.as1D().argMax();
244-
});
245-
let classId = (await predictedClass.data())[0];
246-
await tf.nextFrame();
247-
if (callback) {
248-
if (this.mapStringToIndex.length > 0) {
249-
classId = this.mapStringToIndex[classId];
250-
}
251-
callback(classId);
252-
}
253-
} else {
254-
console.warn('Mobilenet Feature Extraction has not been set to be a classifier.');
230+
return callCallback(this.classifyInternal(imgToPredict), callback);
231+
}
232+
233+
async classifyInternal(imgToPredict) {
234+
if (this.usageType === 'classifier') {
235+
throw new Error('Mobilenet Feature Extraction has not been set to be a classifier.');
236+
}
237+
238+
this.isPredicting = true;
239+
const predictedClass = tf.tidy(() => {
240+
const processedImg = imgToTensor(imgToPredict);
241+
const activation = this.mobilenetFeatures.predict(processedImg);
242+
const predictions = this.customModel.predict(activation);
243+
return predictions.as1D().argMax();
244+
});
245+
let classId = (await predictedClass.data())[0];
246+
await tf.nextFrame();
247+
if (this.mapStringToIndex.length > 0) {
248+
classId = this.mapStringToIndex[classId];
255249
}
250+
return classId;
256251
}
257252

258253
/* eslint max-len: ["error", { "code": 180 }] */
259-
async predict(inputOrCallback, cb = null) {
260-
if (this.usageType === 'regressor') {
261-
let imgToPredict;
262-
let callback;
263-
264-
if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
265-
imgToPredict = inputOrCallback;
266-
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
267-
imgToPredict = inputOrCallback.elt; // p5.js image element
268-
} else if (typeof inputOrCallback === 'function') {
269-
imgToPredict = this.video;
270-
callback = inputOrCallback;
271-
}
254+
async predict(inputOrCallback, cb) {
255+
let imgToPredict;
256+
let callback;
257+
if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
258+
imgToPredict = inputOrCallback;
259+
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
260+
imgToPredict = inputOrCallback.elt; // p5.js image element
261+
} else if (typeof inputOrCallback === 'function') {
262+
imgToPredict = this.video;
263+
callback = inputOrCallback;
264+
}
272265

273-
if (typeof cb === 'function') {
274-
callback = cb;
275-
}
266+
if (typeof cb === 'function') {
267+
callback = cb;
268+
}
269+
return callCallback(this.predictInternal(imgToPredict), callback);
270+
}
276271

277-
this.isPredicting = true;
278-
const predictedClass = tf.tidy(() => {
279-
const processedImg = imgToTensor(imgToPredict);
280-
const activation = this.mobilenetFeatures.predict(processedImg);
281-
const predictions = this.customModel.predict(activation);
282-
return predictions.as1D();
283-
});
284-
const prediction = (await predictedClass.data());
285-
predictedClass.dispose();
286-
await tf.nextFrame();
287-
if (callback) {
288-
callback(prediction[0]);
289-
}
290-
} else {
291-
console.warn('Mobilenet Feature Extraction has not been set to be a regressor.');
272+
async predictInternal(imgToPredict) {
273+
if (this.usageType !== 'regressor') {
274+
throw new Error('Mobilenet Feature Extraction has not been set to be a regressor.');
292275
}
276+
277+
this.isPredicting = true;
278+
const predictedClass = tf.tidy(() => {
279+
const processedImg = imgToTensor(imgToPredict);
280+
const activation = this.mobilenetFeatures.predict(processedImg);
281+
const predictions = this.customModel.predict(activation);
282+
return predictions.as1D();
283+
});
284+
const prediction = await predictedClass.data();
285+
predictedClass.dispose();
286+
await tf.nextFrame();
287+
return prediction[0];
293288
}
294289

295290
// Static Method: get top k classes for mobilenet
296-
static async getTopKClasses(logits, topK, callback) {
291+
static async getTopKClasses(logits, topK, callback = () => {}) {
297292
const values = await logits.data();
298293
const valuesAndIndices = [];
299294
for (let i = 0; i < values.length; i += 1) {
@@ -317,9 +312,7 @@ class Mobilenet {
317312

318313
await tf.nextFrame();
319314

320-
if (callback) {
321-
callback(topClassesAndProbs);
322-
}
315+
callback(undefined, topClassesAndProbs);
323316
return topClassesAndProbs;
324317
}
325318
}

src/FeatureExtractor/index.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ General Feature Extractor Manager
1010
import Mobilenet from './Mobilenet';
1111

1212
/* eslint max-len: ["error", { "code": 180 }] */
13-
const featureExtractor = (model, optionsOrCallback, cb = () => {}) => {
13+
const featureExtractor = (model, optionsOrCallback, cb) => {
1414
let modelName;
1515
if (typeof model !== 'string') {
1616
throw new Error('Please specify a model to use. E.g: "MobileNet"');

0 commit comments

Comments
 (0)