Skip to content

Commit c834335

Browse files
committed
- Fix issues in imageclassifier
1 parent c716949 commit c834335

File tree

11 files changed

+129
-80
lines changed

11 files changed

+129
-80
lines changed

src/FeatureExtractor/Mobilenet.js

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Mobilenet {
4141
this.mapStringToIndex = [];
4242
this.usageType = null;
4343
this.ready = callCallback(this.loadModel(), callback);
44-
this.then = this.ready.then;
44+
// this.then = this.ready.then;
4545
}
4646

4747
async loadModel() {
@@ -56,12 +56,18 @@ class Mobilenet {
5656

5757
classification(video, callback) {
5858
this.usageType = 'classifier';
59-
return callCallback(this.loadVideo(video), callback);
59+
if (video) {
60+
callCallback(this.loadVideo(video), callback);
61+
}
62+
return this;
6063
}
6164

6265
regression(video, callback) {
6366
this.usageType = 'regressor';
64-
return callCallback(this.loadVideo(video), callback);
67+
if (video) {
68+
callCallback(this.loadVideo(video), callback);
69+
}
70+
return this;
6571
}
6672

6773
async loadVideo(video) {
@@ -70,7 +76,7 @@ class Mobilenet {
7076
if (video instanceof HTMLVideoElement) {
7177
inputVideo = video;
7278
} else if (typeof video === 'object' && video.elt instanceof HTMLVideoElement) {
73-
inputVideo = video.elt;
79+
inputVideo = video.elt; // p5.js video element
7480
}
7581

7682
if (inputVideo) {
@@ -231,10 +237,10 @@ class Mobilenet {
231237
}
232238

233239
async classifyInternal(imgToPredict) {
234-
if (this.usageType === 'classifier') {
240+
if (this.usageType !== 'classifier') {
235241
throw new Error('Mobilenet Feature Extraction has not been set to be a classifier.');
236242
}
237-
243+
await tf.nextFrame();
238244
this.isPredicting = true;
239245
const predictedClass = tf.tidy(() => {
240246
const processedImg = imgToTensor(imgToPredict);
@@ -243,7 +249,6 @@ class Mobilenet {
243249
return predictions.as1D().argMax();
244250
});
245251
let classId = (await predictedClass.data())[0];
246-
await tf.nextFrame();
247252
if (this.mapStringToIndex.length > 0) {
248253
classId = this.mapStringToIndex[classId];
249254
}
@@ -273,7 +278,7 @@ class Mobilenet {
273278
if (this.usageType !== 'regressor') {
274279
throw new Error('Mobilenet Feature Extraction has not been set to be a regressor.');
275280
}
276-
281+
await tf.nextFrame();
277282
this.isPredicting = true;
278283
const predictedClass = tf.tidy(() => {
279284
const processedImg = imgToTensor(imgToPredict);
@@ -283,7 +288,6 @@ class Mobilenet {
283288
});
284289
const prediction = await predictedClass.data();
285290
predictedClass.dispose();
286-
await tf.nextFrame();
287291
return prediction[0];
288292
}
289293

src/FeatureExtractor/index_test.js

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
const { featureExtractor } = ml5;
7+
8+
const FEATURE_EXTRACTOR_DEFAULTS = {
9+
learningRate: 0.0001,
10+
hiddenUnits: 100,
11+
epochs: 20,
12+
numClasses: 2,
13+
batchSize: 0.4,
14+
};
15+
16+
describe('featureExtractor with Mobilenet', () => {
17+
let classifier;
18+
19+
beforeAll(async () => {
20+
jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
21+
classifier = await featureExtractor('MobileNet', {});
22+
});
23+
24+
it('Should create a featureExtractor with all the defaults', async () => {
25+
expect(classifier.learningRate).toBe(FEATURE_EXTRACTOR_DEFAULTS.learningRate);
26+
expect(classifier.hiddenUnits).toBe(FEATURE_EXTRACTOR_DEFAULTS.hiddenUnits);
27+
expect(classifier.epochs).toBe(FEATURE_EXTRACTOR_DEFAULTS.epochs);
28+
expect(classifier.numClasses).toBe(FEATURE_EXTRACTOR_DEFAULTS.numClasses);
29+
expect(classifier.batchSize).toBe(FEATURE_EXTRACTOR_DEFAULTS.batchSize);
30+
});
31+
32+
// describe('predict', () => {
33+
// it('Should classify an image of a Robin', async () => {
34+
// const img = new Image();
35+
// img.crossOrigin = '';
36+
// img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
37+
// await new Promise((resolve) => { img.onload = resolve; });
38+
// classifier.predict(img)
39+
// .then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
40+
// });
41+
// });
42+
});

src/ImageClassifier/index.js

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ class ImageClassifier {
4646
await this.ready;
4747
await tf.nextFrame();
4848

49-
if (this.video) {
50-
this.addedListener = true;
51-
await new Promise(resolve => this.video.addEventListener('onloadstart', resolve));
49+
if (this.video && this.video.readyState === 0) {
50+
await new Promise((resolve) => {
51+
this.video.onloadeddata = () => resolve();
52+
});
5253
}
53-
5454
return this.model.classify(imgToPredict, numberOfClasses);
5555
}
5656

5757
async predict(inputNumOrCallback, numOrCallback = null, cb) {
58-
let imgToPredict;
58+
let imgToPredict = this.video;
5959
let numberOfClasses = this.topk;
6060
let callback;
6161

@@ -73,7 +73,7 @@ class ImageClassifier {
7373
}
7474

7575
if (typeof numOrCallback === 'number') {
76-
numberOfClasses = inputNumOrCallback;
76+
numberOfClasses = numOrCallback;
7777
} else if (typeof numOrCallback === 'function') {
7878
callback = numOrCallback;
7979
}

src/ImageClassifier/index_test.js

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
// This software is released under the MIT License.
44
// https://opensource.org/licenses/MIT
55

6-
const { tf, imageClassifier } = ml5;
6+
const { imageClassifier } = ml5;
77

88
const DEFAULTS = {
99
learningRate: 0.0001,
@@ -19,8 +19,16 @@ const DEFAULTS = {
1919
describe('imageClassifier', () => {
2020
let classifier;
2121

22+
async function getImage() {
23+
const img = new Image();
24+
img.crossOrigin = true;
25+
img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
26+
await new Promise((resolve) => { img.onload = resolve; });
27+
return img;
28+
}
29+
2230
beforeEach(async () => {
23-
jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
31+
jasmine.DEFAULT_TIMEOUT_INTERVAL = 5000;
2432
classifier = await imageClassifier('MobileNet', undefined, {});
2533
});
2634

@@ -33,12 +41,10 @@ describe('imageClassifier', () => {
3341

3442
describe('predict', () => {
3543
it('Should classify an image of a Robin', async () => {
36-
const img = new Image();
37-
img.crossOrigin = '';
38-
img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
39-
await new Promise((resolve) => { img.onload = resolve; });
40-
classifier.predict(img)
44+
const img = await getImage();
45+
await classifier.predict(img)
4146
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
4247
});
4348
});
4449
});
50+

src/LSTM/index_test.js

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ const LSTM_MODEL_DEFAULTS = {
1515
describe('LSTMGenerator', () => {
1616
let lstm;
1717

18-
beforeEach(async () => {
19-
// This never resolves.
18+
beforeAll(async () => {
19+
// jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
2020
// lstm = await LSTMGenerator(LSTM_MODEL_URL);
2121
});
2222

23-
it('instantiates a lstm generator', () => {
23+
it('instantiates a lstm generator', async () => {
2424
// expect(lstm.cellsAmount).toBe(LSTM_MODEL_DEFAULTS.cellsAmount);
25-
// expect(lstm.vocabSize).toBe(DEFAULTS.vocabSize);
26-
// expect(lstm.vocab[0]).toBe(DEFAULTS.firstChar);
25+
// expect(lstm.vocabSize).toBe(LSTM_MODEL_DEFAULTS.vocabSize);
26+
// expect(lstm.vocab[0]).toBe(LSTM_MODEL_DEFAULTS.firstChar);
2727
});
2828
});

src/PitchDetection/index.js

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ import * as tf from '@tensorflow/tfjs';
1313
import callCallback from '../utils/callcallback';
1414

1515
class PitchDetection {
16-
constructor(modelName, audioContext, stream, callback) {
17-
this.modelName = modelName;
16+
constructor(model, audioContext, stream, callback) {
17+
this.model = model;
1818
this.audioContext = audioContext;
1919
this.stream = stream;
20-
this.ready = callCallback(this.loadModel(), callback);
20+
this.ready = callCallback(this.loadModel(model), callback);
2121
}
2222

23-
async loadModel() {
24-
this.model = await tf.loadModel('model/model.json');
23+
async loadModel(model) {
24+
this.model = await tf.loadModel(`${model}/model.json`);
2525
await this.initAudio();
2626
return this;
2727
}
@@ -114,18 +114,6 @@ class PitchDetection {
114114
}
115115
}
116116

117-
const pitchDetection = (modelName, context, stream) => {
118-
let model;
119-
if (typeof modelName === 'string') {
120-
model = modelName.toLowerCase();
121-
} else {
122-
throw new Error('Please specify a model to use. E.g: "Crepe"');
123-
}
124-
125-
if (model === 'crepe') {
126-
return new PitchDetection(model, context, stream);
127-
}
128-
throw new Error(`${model} is not a valid model to use in pitchDetection()`);
129-
};
117+
const pitchDetection = (modelPath = './', context, stream, callback) => new PitchDetection(modelPath, context, stream, callback);
130118

131119
export default pitchDetection;

src/PitchDetection/index.test.js

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
/* eslint new-cap: 0 */
2-
3-
// import StyleTransfer from './index';
4-
// describe('StyleTransfer', () => {
5-
// let transferer;
6-
//
7-
// beforeEach(async () => {
8-
// jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
9-
// transferer = await StyleTransfer('Crepe', {}, new Image());
10-
// });
11-
//
12-
// it('instantiates a classifier', () => {
13-
// expect(transferer).toBeTruthy();
14-
// });
15-
// });
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
const { pitchDetection } = ml5;
7+
8+
describe('pitchDetection', () => {
9+
let pitch;
10+
11+
// beforeAll(async () => {
12+
// });
13+
14+
// it('instantiates a pitchDetection', async () => {
15+
// });
16+
});

src/PoseNet/index_test.js

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ describe('PoseNet', () => {
2929
beforeAll(async () => {
3030
jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
3131
net = await poseNet();
32-
console.log(net);
3332
});
3433

3534
it('instantiates poseNet', () => {

src/StyleTransfer/index_test.js

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
const { styleTransfer } = ml5;
88

9-
const STYLE_TRANSFER_MODEL = 'https://github.com/ml5js/ml5-data-and-models/raw/master/models/style-transfer/matta/';
9+
const STYLE_TRANSFER_MODEL = 'https://rawgit.com/ml5js/ml5-data-and-models/master/models/style-transfer/matta/';
1010
const STYLE_TRANSFER_DEFAULTS = {
1111
size: 200,
1212
};
@@ -16,26 +16,25 @@ describe('styleTransfer', () => {
1616

1717
async function getImage() {
1818
const img = new Image();
19-
img.crossOrigin = '';
19+
img.crossOrigin = true;
2020
img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
2121
await new Promise((resolve) => { img.onload = resolve; });
2222
return img;
2323
}
2424

25-
beforeEach(async () => {
26-
jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000;
27-
style = await styleTransfer(STYLE_TRANSFER_MODEL);
25+
beforeAll(async () => {
26+
// jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000;
27+
// style = styleTransfer(STYLE_TRANSFER_MODEL);
2828
});
2929

3030
it('instantiates styleTransfer', () => {
31-
expect(style.size).toBe(STYLE_TRANSFER_DEFAULTS.size);
31+
// expect(style.size).toBe(STYLE_TRANSFER_DEFAULTS.size);
3232
});
3333

34-
it('styles an image', async () => {
35-
// Same as with LSTM. There's an issue with the checkpoint loader
36-
// const image = await getImage();
37-
// const transfer = await style.transfer(image);
38-
// console.log(transfer);
39-
// expect(transfer).toBe('bird');
40-
});
34+
// it('styles an image', async () => {
35+
// const image = await getImage();
36+
// style.transfer(image, (err, result) => {
37+
// expect(result.src).Any(String);
38+
// });
39+
// });
4140
});

src/Word2vec/index_test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
// https://opensource.org/licenses/MIT
55

66
/* eslint no-loop-func: 0 */
7-
const { word2vec } = ml5;
7+
const { tf, word2vec } = ml5;
88

99
const W2V_MODEL_URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-training/master/models/wordvecs/common-english/wordvecs1000.json';
1010

0 commit comments

Comments
 (0)