Skip to content

Commit 82843f8

Browse files
authored
Merge pull request #173 from meiamsome/testing/word2vec
Word2Vec tests
2 parents 8431d99 + d921d83 commit 82843f8

File tree

5 files changed

+172
-57
lines changed

5 files changed

+172
-57
lines changed

karma.conf.js

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,42 @@ module.exports = (config) => {
22
config.set({
33
frameworks: ['jasmine'],
44
files: [
5+
'src/index.js',
56
'src/**/*_test.js',
67
],
78
preprocessors: {
8-
'src/**/*_test.js': ['webpack'],
9+
'src/index.js': ['webpack'],
910
},
1011
webpack: {
11-
// karma watches the test entry points
12-
// (you don't need to specify the entry option)
13-
// webpack watches dependencies
14-
15-
// webpack configuration
12+
// TODO: This is duplication of the webpack.common.babel.js file, but they
13+
// use different import syntaxes so it's not easy to just require it here.
14+
// Maybe this could be put into a JSON file, but the include in the module
15+
// rules is dynamic.
16+
entry: ['babel-polyfill', './src/index.js'],
17+
output: {
18+
libraryTarget: 'umd',
19+
filename: 'ml5.js',
20+
library: 'ml5',
21+
},
22+
module: {
23+
rules: [
24+
{
25+
enforce: 'pre',
26+
test: /\.js$/,
27+
exclude: /node_modules/,
28+
loader: 'eslint-loader',
29+
},
30+
{
31+
test: /\.js$/,
32+
loader: 'babel-loader',
33+
include: require('path').resolve(__dirname, 'src'),
34+
},
35+
],
36+
},
37+
// Don't minify the webpack build for better stack traces
38+
optimization: {
39+
minimize: false,
40+
},
1641
},
1742
webpackMiddleware: {
1843
noInfo: true,

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"start": "webpack-dev-server --open --config webpack.dev.babel.js",
1414
"build": "webpack --config webpack.prod.babel.js",
1515
"test": "./node_modules/karma/bin/karma start karma.conf.js",
16+
"test:single": "./node_modules/karma/bin/karma start karma.conf.js --single-run",
1617
"test-travis": "./scripts/test-travis.sh"
1718
},
1819
"repository": {

src/ImageClassifier/index_test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/* eslint new-cap: 0 */
22

3-
import * as ImageClassifier from './index';
3+
const { ImageClassifier } = ml5;
44

55
const DEFAULTS = {
66
learningRate: 0.0001,

src/Word2vec/index.js

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,30 @@ class Word2Vec {
3434
});
3535
}
3636

37+
dispose() {
38+
Object.values(this.model).forEach(x => x.dispose());
39+
}
40+
3741
add(inputs, max = 1) {
38-
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
39-
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
42+
return tf.tidy(() => {
43+
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
44+
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
45+
});
4046
}
4147

4248
subtract(inputs, max = 1) {
43-
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
44-
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
49+
return tf.tidy(() => {
50+
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
51+
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
52+
});
4553
}
4654

4755
average(inputs, max = 1) {
48-
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
49-
const avg = tf.div(sum, tf.tensor(inputs.length));
50-
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
56+
return tf.tidy(() => {
57+
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
58+
const avg = tf.div(sum, tf.tensor(inputs.length));
59+
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
60+
});
5161
}
5262

5363
nearest(input, max = 10) {
@@ -64,34 +74,36 @@ class Word2Vec {
6474
}
6575

6676
static addOrSubtract(model, values, operation) {
67-
const vectors = [];
68-
const notFound = [];
69-
if (values.length < 2) {
70-
throw new Error('Invalid input, must be passed more than 1 value');
71-
}
72-
values.forEach((value) => {
73-
const vector = model[value];
74-
if (!vector) {
75-
notFound.push(value);
76-
} else {
77-
vectors.push(vector);
77+
return tf.tidy(() => {
78+
const vectors = [];
79+
const notFound = [];
80+
if (values.length < 2) {
81+
throw new Error('Invalid input, must be passed more than 1 value');
7882
}
79-
});
83+
values.forEach((value) => {
84+
const vector = model[value];
85+
if (!vector) {
86+
notFound.push(value);
87+
} else {
88+
vectors.push(vector);
89+
}
90+
});
8091

81-
if (notFound.length > 0) {
82-
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
83-
}
84-
let result = vectors[0];
85-
if (operation === 'ADD') {
86-
for (let i = 1; i < vectors.length; i += 1) {
87-
result = tf.add(result, vectors[i]);
92+
if (notFound.length > 0) {
93+
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
8894
}
89-
} else {
90-
for (let i = 1; i < vectors.length; i += 1) {
91-
result = tf.sub(result, vectors[i]);
95+
let result = vectors[0];
96+
if (operation === 'ADD') {
97+
for (let i = 1; i < vectors.length; i += 1) {
98+
result = tf.add(result, vectors[i]);
99+
}
100+
} else {
101+
for (let i = 1; i < vectors.length; i += 1) {
102+
result = tf.sub(result, vectors[i]);
103+
}
92104
}
93-
}
94-
return result;
105+
return result;
106+
});
95107
}
96108

97109
static nearest(model, input, start, max) {

src/Word2vec/index_test.js

Lines changed: 96 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,96 @@
1-
// import Word2Vec from './index';
2-
3-
// const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/07_Word2Vec/data/wordvecs1000.json';
4-
5-
// describe('initialize word2vec', () => {
6-
// let word2vec;
7-
// beforeAll((done) => {
8-
// // word2vec = new Word2Vec(URL);
9-
// done();
10-
// });
11-
12-
// // it('creates a new instance', (done) => {
13-
// // expect(word2vec).toEqual(jasmine.objectContaining({
14-
// // ready: true,
15-
// // modelSize: 1,
16-
// // }));
17-
// // done();
18-
// // });
19-
// });
1+
const { tf, word2vec } = ml5;
2+
3+
const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/Word2Vec/data/wordvecs1000.json';
4+
5+
describe('word2vec', () => {
6+
let word2vecInstance;
7+
let numTensorsBeforeAll;
8+
let numTensorsBeforeEach;
9+
beforeAll((done) => {
10+
numTensorsBeforeAll = tf.memory().numTensors;
11+
word2vecInstance = word2vec(URL, done);
12+
});
13+
14+
afterAll(() => {
15+
word2vecInstance.dispose();
16+
let numTensorsAfterAll = tf.memory().numTensors;
17+
if(numTensorsBeforeAll !== numTensorsAfterAll) {
18+
throw new Error(`Leaking Tensors (${numTensorsAfterAll} vs ${numTensorsBeforeAll})`);
19+
}
20+
});
21+
22+
beforeEach(() => {
23+
numTensorsBeforeEach = tf.memory().numTensors;
24+
});
25+
26+
afterEach(() => {
27+
let numTensorsAfterEach = tf.memory().numTensors;
28+
if(numTensorsBeforeEach !== numTensorsAfterEach) {
29+
throw new Error(`Leaking Tensors (${numTensorsAfterEach} vs ${numTensorsBeforeEach})`);
30+
}
31+
});
32+
33+
it('creates a new instance', () => {
34+
expect(word2vecInstance).toEqual(jasmine.objectContaining({
35+
ready: true,
36+
modelSize: 1,
37+
}));
38+
});
39+
40+
describe('getRandomWord', () => {
41+
it('returns a word', () => {
42+
let word = word2vecInstance.getRandomWord();
43+
expect(typeof word).toEqual('string');
44+
});
45+
});
46+
47+
describe('nearest', () => {
48+
it('returns a sorted array of nearest words', () => {
49+
for(let i = 0; i < 100; i++) {
50+
let word = word2vecInstance.getRandomWord();
51+
let nearest = word2vecInstance.nearest(word);
52+
let currentDistance = 0;
53+
for(let { word, distance: nextDistance } of nearest) {
54+
expect(typeof word).toEqual('string');
55+
expect(nextDistance).toBeGreaterThan(currentDistance);
56+
currentDistance = nextDistance;
57+
}
58+
}
59+
});
60+
61+
it('returns a list of the right length', () => {
62+
for(let i = 0; i < 100; i++) {
63+
let word = word2vecInstance.getRandomWord();
64+
let nearest = word2vecInstance.nearest(word, i);
65+
expect(nearest.length).toEqual(i);
66+
}
67+
});
68+
});
69+
70+
describe('add', () => {
71+
it('returns a value', () => {
72+
let word1 = word2vecInstance.getRandomWord();
73+
let word2 = word2vecInstance.getRandomWord();
74+
let sum = word2vecInstance.subtract([word1, word2]);
75+
expect(sum[0].distance).toBeGreaterThan(0);
76+
})
77+
});
78+
79+
describe('subtract', () => {
80+
it('returns a value', () => {
81+
let word1 = word2vecInstance.getRandomWord();
82+
let word2 = word2vecInstance.getRandomWord();
83+
let sum = word2vecInstance.subtract([word1, word2]);
84+
expect(sum[0].distance).toBeGreaterThan(0);
85+
})
86+
});
87+
88+
describe('average', () => {
89+
it('returns a value', () => {
90+
let word1 = word2vecInstance.getRandomWord();
91+
let word2 = word2vecInstance.getRandomWord();
92+
let average = word2vecInstance.average([word1, word2]);
93+
expect(average[0].distance).toBeGreaterThan(0);
94+
});
95+
});
96+
});

0 commit comments

Comments
 (0)