Skip to content

Commit f8b83e2

Browse files
committed
Add tests for Word2Vec
- Validates that there are no leaked tensors
- Validates basic functionality of nearest
- Validates that add, subtract, and average return results
- Fixes memory leaks in the add, subtract, average, and addOrSubtract functions
- Adds a general dispose method to the Word2Vec class
1 parent d8bb936 commit f8b83e2

File tree

2 files changed

+140
-50
lines changed

2 files changed

+140
-50
lines changed

src/Word2vec/index.js

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,31 @@ class Word2Vec {
3434
});
3535
}
3636

37+
dispose() {
38+
Object.values(this.model).forEach(x => x.dispose());
39+
}
40+
3741
add(inputs, max = 1) {
38-
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
39-
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
42+
return tf.tidy(() => {
43+
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
44+
console.log(sum);
45+
return Word2Vec.nearest(this.model, sum, inputs.length, inputs.length + max);
46+
});
4047
}
4148

4249
subtract(inputs, max = 1) {
43-
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
44-
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
50+
return tf.tidy(() => {
51+
const subtraction = Word2Vec.addOrSubtract(this.model, inputs, 'SUBTRACT');
52+
return Word2Vec.nearest(this.model, subtraction, inputs.length, inputs.length + max);
53+
});
4554
}
4655

4756
average(inputs, max = 1) {
48-
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
49-
const avg = tf.div(sum, tf.tensor(inputs.length));
50-
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
57+
return tf.tidy(() => {
58+
const sum = Word2Vec.addOrSubtract(this.model, inputs, 'ADD');
59+
const avg = tf.div(sum, tf.tensor(inputs.length));
60+
return Word2Vec.nearest(this.model, avg, inputs.length, inputs.length + max);
61+
});
5162
}
5263

5364
nearest(input, max = 10) {
@@ -64,34 +75,36 @@ class Word2Vec {
6475
}
6576

6677
static addOrSubtract(model, values, operation) {
67-
const vectors = [];
68-
const notFound = [];
69-
if (values.length < 2) {
70-
throw new Error('Invalid input, must be passed more than 1 value');
71-
}
72-
values.forEach((value) => {
73-
const vector = model[value];
74-
if (!vector) {
75-
notFound.push(value);
76-
} else {
77-
vectors.push(vector);
78+
return tf.tidy(() => {
79+
const vectors = [];
80+
const notFound = [];
81+
if (values.length < 2) {
82+
throw new Error('Invalid input, must be passed more than 1 value');
7883
}
79-
});
84+
values.forEach((value) => {
85+
const vector = model[value];
86+
if (!vector) {
87+
notFound.push(value);
88+
} else {
89+
vectors.push(vector);
90+
}
91+
});
8092

81-
if (notFound.length > 0) {
82-
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
83-
}
84-
let result = vectors[0];
85-
if (operation === 'ADD') {
86-
for (let i = 1; i < vectors.length; i += 1) {
87-
result = tf.add(result, vectors[i]);
93+
if (notFound.length > 0) {
94+
throw new Error(`Invalid input, vector not found for: ${notFound.toString()}`);
8895
}
89-
} else {
90-
for (let i = 1; i < vectors.length; i += 1) {
91-
result = tf.sub(result, vectors[i]);
96+
let result = vectors[0];
97+
if (operation === 'ADD') {
98+
for (let i = 1; i < vectors.length; i += 1) {
99+
result = tf.add(result, vectors[i]);
100+
}
101+
} else {
102+
for (let i = 1; i < vectors.length; i += 1) {
103+
result = tf.sub(result, vectors[i]);
104+
}
92105
}
93-
}
94-
return result;
106+
return result;
107+
});
95108
}
96109

97110
static nearest(model, input, start, max) {

src/Word2vec/index_test.js

Lines changed: 96 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,96 @@
1-
// import Word2Vec from './index';
2-
3-
// const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/07_Word2Vec/data/wordvecs1000.json';
4-
5-
// describe('initialize word2vec', () => {
6-
// let word2vec;
7-
// beforeAll((done) => {
8-
// // word2vec = new Word2Vec(URL);
9-
// done();
10-
// });
11-
12-
// // it('creates a new instance', (done) => {
13-
// // expect(word2vec).toEqual(jasmine.objectContaining({
14-
// // ready: true,
15-
// // modelSize: 1,
16-
// // }));
17-
// // done();
18-
// // });
19-
// });
1+
// Jasmine spec for the Word2Vec wrapper. Besides functional checks, every
// spec (and the suite as a whole) verifies via tf.memory().numTensors that
// the tensor count is unchanged afterwards.
const { tf, word2vec } = ml5;

const URL = 'https://raw.githubusercontent.com/ml5js/ml5-examples/master/p5js/Word2Vec/data/wordvecs1000.json';

describe('word2vec', () => {
  let word2vecInstance;
  let numTensorsBeforeAll;
  let numTensorsBeforeEach;

  beforeAll((done) => {
    // Record the baseline before the model is loaded so afterAll can confirm
    // that dispose() brings the tensor count back to it.
    numTensorsBeforeAll = tf.memory().numTensors;
    word2vecInstance = word2vec(URL, done);
  });

  afterAll(() => {
    word2vecInstance.dispose();
    const numTensorsAfterAll = tf.memory().numTensors;
    if (numTensorsBeforeAll !== numTensorsAfterAll) {
      throw new Error(`Leaking Tensors (${numTensorsAfterAll} vs ${numTensorsBeforeAll})`);
    }
  });

  beforeEach(() => {
    numTensorsBeforeEach = tf.memory().numTensors;
  });

  afterEach(() => {
    const numTensorsAfterEach = tf.memory().numTensors;
    if (numTensorsBeforeEach !== numTensorsAfterEach) {
      throw new Error(`Leaking Tensors (${numTensorsAfterEach} vs ${numTensorsBeforeEach})`);
    }
  });

  it('creates a new instance', () => {
    expect(word2vecInstance).toEqual(jasmine.objectContaining({
      ready: true,
      modelSize: 1,
    }));
  });

  describe('getRandomWord', () => {
    it('returns a word', () => {
      const word = word2vecInstance.getRandomWord();
      expect(typeof word).toEqual('string');
    });
  });

  describe('nearest', () => {
    it('returns a sorted array of nearest words', () => {
      for (let i = 0; i < 100; i += 1) {
        const word = word2vecInstance.getRandomWord();
        const nearest = word2vecInstance.nearest(word);
        let currentDistance = 0;
        // Distances must be strictly increasing across the returned list.
        for (const { word: nearWord, distance: nextDistance } of nearest) {
          expect(typeof nearWord).toEqual('string');
          expect(nextDistance).toBeGreaterThan(currentDistance);
          currentDistance = nextDistance;
        }
      }
    });

    it('returns a list of the right length', () => {
      for (let i = 0; i < 100; i += 1) {
        const word = word2vecInstance.getRandomWord();
        const nearest = word2vecInstance.nearest(word, i);
        expect(nearest.length).toEqual(i);
      }
    });
  });

  describe('add', () => {
    it('returns a value', () => {
      const word1 = word2vecInstance.getRandomWord();
      const word2 = word2vecInstance.getRandomWord();
      // Exercise add() itself; this spec previously called subtract().
      const sum = word2vecInstance.add([word1, word2]);
      expect(sum[0].distance).toBeGreaterThan(0);
    });
  });

  describe('subtract', () => {
    it('returns a value', () => {
      const word1 = word2vecInstance.getRandomWord();
      const word2 = word2vecInstance.getRandomWord();
      const difference = word2vecInstance.subtract([word1, word2]);
      expect(difference[0].distance).toBeGreaterThan(0);
    });
  });

  describe('average', () => {
    it('returns a value', () => {
      const word1 = word2vecInstance.getRandomWord();
      const word2 = word2vecInstance.getRandomWord();
      const average = word2vecInstance.average([word1, word2]);
      expect(average[0].distance).toBeGreaterThan(0);
    });
  });
});

0 commit comments

Comments
 (0)