Skip to content

Commit e92411f

Browse files
committed
Create Image regressor and classifier with mobilenet
1 parent 7440c02 commit e92411f

File tree

4 files changed

+343
-47
lines changed

4 files changed

+343
-47
lines changed

src/FeatureExtractor/Mobilenet.js

Lines changed: 316 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,324 @@
44
// https://opensource.org/licenses/MIT
55

66
/*
7-
A General Feature Extractor class
7+
A class that extract features from Mobilenet
88
*/
99

10-
class FeatureExtractor extends Video {
11-
constructor(model, videoOrCallback, optionsOrCallback = {}, cb = () => {}) {
12-
super(video, IMAGESIZE);
10+
import * as tf from '@tensorflow/tfjs';
11+
import Video from './../utils/Video';
12+
import { IMAGENET_CLASSES } from './../utils/IMAGENET_CLASSES';
13+
import { imgToTensor } from '../utils/imageUtilities';
14+
15+
const IMAGESIZE = 224;
16+
const DEFAULTS = {
17+
version: 1,
18+
alpha: 1.0,
19+
topk: 3,
20+
learningRate: 0.0001,
21+
hiddenUnits: 100,
22+
epochs: 20,
23+
numClasses: 2,
24+
batchSize: 0.4,
25+
};
26+
27+
class Mobilenet {
28+
constructor(options, callback) {
29+
this.mobilenet = null;
30+
this.modelPath = 'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json';
31+
this.topKPredictions = 10;
32+
this.modelLoaded = false;
33+
this.hasAnyTrainedClass = false;
34+
this.customModel = null;
35+
this.epochs = options.epochs || DEFAULTS.epochs;
36+
this.hiddenUnits = options.hiddenUnits || DEFAULTS.hiddenUnits;
37+
this.numClasses = options.numClasses || DEFAULTS.numClasses;
38+
this.learningRate = options.learningRate || DEFAULTS.learningRate;
39+
this.batchSize = options.batchSize || DEFAULTS.batchSize;
40+
this.isPredicting = false;
41+
this.mapStringToIndex = [];
42+
this.usageType = null;
43+
44+
this.loadModel().then((net) => {
45+
this.modelLoaded = true;
46+
this.mobilenetFeatures = net;
47+
callback();
48+
});
49+
}
50+
51+
async loadModel() {
52+
this.mobilenet = await tf.loadModel(this.modelPath);
53+
const layer = this.mobilenet.getLayer('conv_pw_13_relu');
54+
if (this.video) {
55+
tf.tidy(() => this.mobilenet.predict(imgToTensor(this.video))); // Warm up
56+
}
57+
return tf.model({ inputs: this.mobilenet.inputs, outputs: layer.output });
58+
}
59+
60+
asClassifier(video, callback) {
61+
this.usageType = 'classifier';
62+
return this.loadVideo(video, callback);
63+
}
64+
65+
asRegressor(video, callback) {
66+
this.usageType = 'regressor';
67+
return this.loadVideo(video, callback);
68+
}
69+
70+
loadVideo(video, callback = () => {}) {
71+
let inputVideo = null;
72+
73+
if (video instanceof HTMLVideoElement) {
74+
inputVideo = video;
75+
} else if (typeof video === 'object' && video.elt instanceof HTMLVideoElement) {
76+
inputVideo = video.elt;
77+
}
78+
79+
if (inputVideo) {
80+
const vid = new Video(inputVideo, IMAGESIZE);
81+
vid.loadVideo().then(async () => {
82+
this.video = vid.video;
83+
callback();
84+
});
85+
}
86+
87+
return this;
88+
}
89+
90+
addImage(inputOrLabel, labelOrCallback, cb = () => {}) {
91+
let imgToAdd;
92+
let label;
93+
let callback = cb;
94+
95+
if (inputOrLabel instanceof HTMLImageElement || inputOrLabel instanceof HTMLVideoElement) {
96+
imgToAdd = inputOrLabel;
97+
} else if (typeof inputOrLabel === 'object' && (inputOrLabel.elt instanceof HTMLImageElement || inputOrLabel.elt instanceof HTMLVideoElement)) {
98+
imgToAdd = inputOrLabel;
99+
} else if (typeof inputOrLabel === 'string' || typeof inputOrLabel === 'number') {
100+
imgToAdd = this.video;
101+
label = inputOrLabel;
102+
}
103+
104+
if (typeof labelOrCallback === 'string' || typeof labelOrCallback === 'number') {
105+
label = labelOrCallback;
106+
} else if (typeof labelOrCallback === 'function') {
107+
callback = labelOrCallback;
108+
}
109+
110+
if (typeof label === 'string') {
111+
if (!this.mapStringToIndex.includes(label)) {
112+
label = this.mapStringToIndex.push(label) - 1;
113+
} else {
114+
label = this.mapStringToIndex.indexOf(label);
115+
}
116+
}
117+
118+
if (this.modelLoaded) {
119+
tf.tidy(() => {
120+
const processedImg = imgToTensor(imgToAdd);
121+
const prediction = this.mobilenetFeatures.predict(processedImg);
122+
123+
let y;
124+
if (this.usageType === 'classifier') {
125+
y = tf.tidy(() => tf.oneHot(tf.tensor1d([label], 'int32'), this.numClasses));
126+
} else if (this.usageType === 'regressor') {
127+
y = tf.tidy(() => tf.tensor2d([[label]]));
128+
}
129+
130+
if (this.xs == null) {
131+
this.xs = tf.keep(prediction);
132+
this.ys = tf.keep(y);
133+
this.hasAnyTrainedClass = true;
134+
} else {
135+
const oldX = this.xs;
136+
this.xs = tf.keep(oldX.concat(prediction, 0));
137+
const oldY = this.ys;
138+
this.ys = tf.keep(oldY.concat(y, 0));
139+
oldX.dispose();
140+
oldY.dispose();
141+
y.dispose();
142+
}
143+
});
144+
if (callback) {
145+
callback();
146+
}
147+
} else {
148+
console.warn('The model is not loaded yet.');
149+
}
150+
}
151+
152+
async train(onProgress) {
153+
if (!this.hasAnyTrainedClass) {
154+
throw new Error('Add some examples before training!');
155+
}
156+
157+
this.isPredicting = false;
158+
159+
if (this.usageType === 'classifier') {
160+
this.loss = 'categoricalCrossentropy';
161+
this.customModel = tf.sequential({
162+
layers: [
163+
tf.layers.flatten({ inputShape: [7, 7, 256] }),
164+
tf.layers.dense({
165+
units: this.hiddenUnits,
166+
activation: 'relu',
167+
kernelInitializer: 'varianceScaling',
168+
useBias: true,
169+
}),
170+
tf.layers.dense({
171+
units: this.numClasses,
172+
kernelInitializer: 'varianceScaling',
173+
useBias: false,
174+
activation: 'softmax',
175+
}),
176+
],
177+
});
178+
} else if (this.usageType === 'regressor') {
179+
this.loss = 'meanSquaredError';
180+
this.customModel = tf.sequential({
181+
layers: [
182+
tf.layers.flatten({ inputShape: [7, 7, 256] }),
183+
tf.layers.dense({
184+
units: this.hiddenUnits,
185+
activation: 'relu',
186+
kernelInitializer: 'varianceScaling',
187+
useBias: true,
188+
}),
189+
tf.layers.dense({
190+
units: 1,
191+
useBias: false,
192+
kernelInitializer: 'Zeros',
193+
activation: 'linear',
194+
}),
195+
],
196+
});
197+
}
198+
199+
const optimizer = tf.train.adam(this.learningRate);
200+
this.customModel.compile({ optimizer, loss: this.loss });
201+
const batchSize = Math.floor(this.xs.shape[0] * this.batchSize);
202+
if (!(batchSize > 0)) {
203+
throw new Error('Batch size is 0 or NaN. Please choose a non-zero fraction.');
204+
}
205+
206+
this.customModel.fit(this.xs, this.ys, {
207+
batchSize,
208+
epochs: this.epochs,
209+
callbacks: {
210+
onBatchEnd: async (batch, logs) => {
211+
onProgress(logs.loss.toFixed(5));
212+
await tf.nextFrame();
213+
},
214+
onTrainEnd: () => onProgress(null),
215+
},
216+
});
217+
}
218+
219+
/* eslint max-len: ["error", { "code": 180 }] */
220+
async classify(inputOrCallback, cb = null) {
221+
if (this.usageType === 'classifier') {
222+
let imgToPredict;
223+
let callback;
224+
225+
if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
226+
imgToPredict = inputOrCallback;
227+
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
228+
imgToPredict = inputOrCallback.elt; // p5.js image element
229+
} else if (typeof inputOrCallback === 'function') {
230+
imgToPredict = this.video;
231+
callback = inputOrCallback;
232+
}
233+
234+
if (typeof cb === 'function') {
235+
callback = cb;
236+
}
237+
238+
this.isPredicting = true;
239+
const predictedClass = tf.tidy(() => {
240+
const processedImg = imgToTensor(imgToPredict);
241+
const activation = this.mobilenetFeatures.predict(processedImg);
242+
const predictions = this.customModel.predict(activation);
243+
return predictions.as1D().argMax();
244+
});
245+
let classId = (await predictedClass.data())[0];
246+
await tf.nextFrame();
247+
if (callback) {
248+
if (this.mapStringToIndex.length > 0) {
249+
classId = this.mapStringToIndex[classId];
250+
}
251+
callback(classId);
252+
}
253+
} else {
254+
console.warn('Mobilenet Feature Extraction has not been set to be a classifier.');
255+
}
256+
}
257+
258+
/* eslint max-len: ["error", { "code": 180 }] */
259+
async predict(inputOrCallback, cb = null) {
260+
if (this.usageType === 'regressor') {
261+
let imgToPredict;
262+
let callback;
263+
264+
if (inputOrCallback instanceof HTMLImageElement || inputOrCallback instanceof HTMLVideoElement) {
265+
imgToPredict = inputOrCallback;
266+
} else if (typeof inputOrCallback === 'object' && (inputOrCallback.elt instanceof HTMLImageElement || inputOrCallback.elt instanceof HTMLVideoElement)) {
267+
imgToPredict = inputOrCallback.elt; // p5.js image element
268+
} else if (typeof inputOrCallback === 'function') {
269+
imgToPredict = this.video;
270+
callback = inputOrCallback;
271+
}
272+
273+
if (typeof cb === 'function') {
274+
callback = cb;
275+
}
276+
277+
this.isPredicting = true;
278+
const predictedClass = tf.tidy(() => {
279+
const processedImg = imgToTensor(imgToPredict);
280+
const activation = this.mobilenetFeatures.predict(processedImg);
281+
const predictions = this.customModel.predict(activation);
282+
return predictions.as1D();
283+
});
284+
const prediction = (await predictedClass.data());
285+
predictedClass.dispose();
286+
await tf.nextFrame();
287+
if (callback) {
288+
callback(prediction[0]);
289+
}
290+
} else {
291+
console.warn('Mobilenet Feature Extraction has not been set to be a regressor.');
292+
}
293+
}
294+
295+
// Static Method: get top k classes for mobilenet
296+
static async getTopKClasses(logits, topK, callback) {
297+
const values = await logits.data();
298+
const valuesAndIndices = [];
299+
for (let i = 0; i < values.length; i += 1) {
300+
valuesAndIndices.push({ value: values[i], index: i });
301+
}
302+
valuesAndIndices.sort((a, b) => b.value - a.value);
303+
const topkValues = new Float32Array(topK);
304+
305+
const topkIndices = new Int32Array(topK);
306+
for (let i = 0; i < topK; i += 1) {
307+
topkValues[i] = valuesAndIndices[i].value;
308+
topkIndices[i] = valuesAndIndices[i].index;
309+
}
310+
const topClassesAndProbs = [];
311+
for (let i = 0; i < topkIndices.length; i += 1) {
312+
topClassesAndProbs.push({
313+
className: IMAGENET_CLASSES[topkIndices[i]],
314+
probability: topkValues[i],
315+
});
316+
}
317+
318+
await tf.nextFrame();
319+
320+
if (callback) {
321+
callback(topClassesAndProbs);
322+
}
323+
return topClassesAndProbs;
13324
}
14325
}
15326

16-
export default FeatureExtractor;
327+
export default Mobilenet;

0 commit comments

Comments
 (0)