Skip to content

Commit 103afd7

Browse files
committed
- Add test for YOLO
- Add test for Imageclassifier - Resize img if necessary in utils
1 parent 5660464 commit 103afd7

File tree

8 files changed

+83
-64
lines changed

8 files changed

+83
-64
lines changed

src/ImageClassifier/index_test.js

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ const DEFAULTS = {
1313
version: 1,
1414
};
1515

16-
describe('Create an image classifier', () => {
16+
describe('imageClassifier', () => {
1717
let classifier;
1818

1919
beforeEach(async () => {
@@ -28,12 +28,14 @@ describe('Create an image classifier', () => {
2828
expect(classifier.ready).toBeTruthy();
2929
});
3030

31-
it('Should classify an robin, American robin, Turdus migratorius', async () => {
32-
const img = new Image();
33-
img.crossOrigin = '';
34-
img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
35-
await new Promise((resolve) => { img.onload = resolve; });
36-
classifier.predict(img)
37-
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
31+
describe('predict', () => {
32+
it('Should classify an image of a Robin', async () => {
33+
const img = new Image();
34+
img.crossOrigin = '';
35+
img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
36+
await new Promise((resolve) => { img.onload = resolve; });
37+
classifier.predict(img)
38+
.then(results => expect(results[0].className).toBe('robin, American robin, Turdus migratorius'));
39+
});
3840
});
3941
});

src/LSTM/index.js

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ class LSTM {
3030
length: 20,
3131
temperature: 0.5,
3232
};
33+
3334
this.ready = callCallback(this.loadCheckpoints(model), callback);
34-
this.then = this.ready.then.bind(this.ready);
35+
// this.then = this.ready.then.bind(this.ready);
3536
}
3637

3738
async loadCheckpoints(path) {
@@ -59,9 +60,10 @@ class LSTM {
5960
return this;
6061
}
6162

62-
async loadVocab(file) {
63-
const json = await fetch(`${file}/vocab.json`)
64-
.then(response => response.json());
63+
async loadVocab(path) {
64+
const json = await fetch(`${path}/vocab.json`)
65+
.then(response => response.json())
66+
.catch(err => console.error(err));
6567
this.vocab = json;
6668
this.vocabSize = Object.keys(json).length;
6769
}

src/LSTM/index_test.js

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,23 @@
1-
/* eslint new-cap: 0 */
1+
const { LSTMGenerator } = ml5;
22

3-
// import LSTMGenerator from './index';
3+
const LSTM_MODEL_URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-models/master/models/lstm/woolf/';
4+
const LSTM_MODEL_DEFAULTS = {
5+
cellsAmount: 2,
6+
vocabSize: 90,
7+
firstChar: 61,
8+
};
49

5-
describe('LSTM', () => {
6-
// let generator;
10+
describe('LSTMGenerator', () => {
11+
let lstm;
712

8-
// beforeEach(async () => {
9-
// jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
10-
// generator = await LSTMGenerator('https://raw.githubusercontent.com/ml5js/ml5-data-and-training/master/models/lstm/dubois/');
11-
// });
13+
beforeEach(async () => {
14+
// This never resolves.
15+
// lstm = await LSTMGenerator(LSTM_MODEL_URL);
16+
});
1217

13-
// it('instantiates a generator', () => {
14-
// expect(generator).toBeTruthy();
15-
// });
16-
17-
// Fails with 'must be a Tensor' error that's particular to this test suite.
18-
// it('generates some text', async () => {
19-
// expect(await generator.generate('Hi there')).toBeTruthy();
20-
// });
18+
it('instantiates a lstm generator', () => {
19+
// expect(lstm.cellsAmount).toBe(LSTM_MODEL_DEFAULTS.cellsAmount);
20+
// expect(lstm.vocabSize).toBe(DEFAULTS.vocabSize);
21+
// expect(lstm.vocab[0]).toBe(DEFAULTS.firstChar);
22+
});
2123
});

src/Word2vec/index_test.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/* eslint no-loop-func: 0 */
22
const { word2vec } = ml5;
33

4-
const URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-training/master/models/wordvecs/common-english/wordvecs1000.json';
4+
const W2V_MODEL_URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-training/master/models/wordvecs/common-english/wordvecs1000.json';
55

66
describe('word2vec', () => {
77
let word2vecInstance;
@@ -10,7 +10,7 @@ describe('word2vec', () => {
1010
beforeAll((done) => {
1111
jasmine.DEFAULT_TIMEOUT_INTERVAL = 5000;
1212
numTensorsBeforeAll = tf.memory().numTensors;
13-
word2vecInstance = word2vec(URL, done);
13+
word2vecInstance = word2vec(W2V_MODEL_URL, done);
1414
});
1515

1616
afterAll(() => {

src/YOLO/index.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class YOLOBase extends Video {
4444
this.modelReady = false;
4545
this.isPredicting = false;
4646
this.ready = callCallback(this.loadModel(), callback);
47-
this.then = this.ready.then;
47+
// this.then = this.ready.then;
4848
}
4949

5050
async loadModel() {
@@ -75,7 +75,7 @@ class YOLOBase extends Video {
7575
await tf.nextFrame();
7676
this.isPredicting = true;
7777
const [allBoxes, boxConfidence, boxClassProbs] = tf.tidy(() => {
78-
const input = imgToTensor(imgToPredict);
78+
const input = imgToTensor(imgToPredict, [imageSize, imageSize]);
7979
const activation = this.model.predict(input);
8080
const [boxXY, boxWH, bConfidence, bClassProbs] = head(activation, ANCHORS, 80);
8181
const aBoxes = boxesToCorners(boxXY, boxWH);

src/YOLO/index.test.js

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/YOLO/index_test.js

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
/* eslint new-cap: 0 */
2+
3+
const { YOLO } = ml5;
4+
5+
const YOLO_DEFAULTS = {
6+
IOUThreshold: 0.4,
7+
classProbThreshold: 0.4,
8+
filterBoxesThreshold: 0.01,
9+
size: 416,
10+
};
11+
12+
describe('YOLO', () => {
13+
let yolo;
14+
15+
async function getRobin() {
16+
const img = new Image();
17+
img.crossOrigin = '';
18+
img.src = 'https://ml5js.org/docs/assets/img/bird.jpg';
19+
await new Promise((resolve) => { img.onload = resolve; });
20+
return img;
21+
}
22+
23+
beforeEach(async () => {
24+
jasmine.DEFAULT_TIMEOUT_INTERVAL = 100000;
25+
yolo = await YOLO();
26+
});
27+
28+
it('instantiates the YOLO classifier with defaults', () => {
29+
expect(yolo.IOUThreshold).toBe(YOLO_DEFAULTS.IOUThreshold);
30+
expect(yolo.classProbThreshold).toBe(YOLO_DEFAULTS.classProbThreshold);
31+
expect(yolo.filterBoxesThreshold).toBe(YOLO_DEFAULTS.filterBoxesThreshold);
32+
expect(yolo.size).toBe(YOLO_DEFAULTS.size);
33+
});
34+
35+
it('detects a robin', async () => {
36+
const robin = await getRobin();
37+
const detection = await yolo.detect(robin);
38+
expect(detection[0].className).toBe('bird');
39+
});
40+
});

src/utils/imageUtilities.js

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ const cropImage = (img) => {
6262
};
6363

6464
// Static Method: image to tf tensor
65-
function imgToTensor(input) {
65+
function imgToTensor(input, size = null) {
6666
return tf.tidy(() => {
67-
const img = tf.fromPixels(input);
67+
let img = tf.fromPixels(input);
68+
if (size) {
69+
img = tf.image.resizeBilinear(img, size);
70+
}
6871
const croppedImage = cropImage(img);
6972
const batchedImage = croppedImage.expandDims(0);
7073
return batchedImage.toFloat().div(tf.scalar(127)).sub(tf.scalar(1));

0 commit comments

Comments
 (0)