Skip to content

Commit 249df9e

Browse files
authored
Add tests to CharRNN (#307) (#320)
* add tests to CharRNN * test(CharRNN): add tests to CharRNN added descriptive tests to ensure CharRNN behaves like its example * remove dist
1 parent ba5be79 commit 249df9e

File tree

1 file changed

+45
-13
lines changed

1 file changed

+45
-13
lines changed

src/CharRNN/index_test.js

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,58 @@
33
// This software is released under the MIT License.
44
// https://opensource.org/licenses/MIT
55

6-
const { LSTMGenerator } = ml5;
6+
const { charRNN } = ml5;
77

8-
const LSTM_MODEL_URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-models/master/models/lstm/woolf/';
9-
const LSTM_MODEL_DEFAULTS = {
8+
const RNN_MODEL_URL = 'https://raw.githubusercontent.com/ml5js/ml5-data-and-models/master/models/lstm/woolf';
9+
10+
const RNN_MODEL_DEFAULTS = {
1011
cellsAmount: 2,
11-
vocabSize: 90,
12-
firstChar: 61,
12+
vocabSize: 223
1313
};
1414

15-
describe('LSTMGenerator', () => {
16-
let lstm;
15+
const RNN_DEFAULTS = {
16+
seed: 'a',
17+
length: 20,
18+
temperature: 0.5,
19+
stateful: false
20+
}
21+
22+
const RNN_OPTIONS = {
23+
seed: 'the meaning of pizza is: ',
24+
length: 500,
25+
temperature: 0.7
26+
}
27+
28+
describe('charRnn', () => {
29+
let rnn;
1730

1831
beforeAll(async () => {
19-
// jasmine.DEFAULT_TIMEOUT_INTERVAL = 10000;
20-
// lstm = await LSTMGenerator(LSTM_MODEL_URL);
32+
jasmine.DEFAULT_TIMEOUT_INTERVAL = 30000; //set extra long interval due to issues with CharRNN generation time
33+
rnn = await charRNN(RNN_MODEL_URL, undefined);
2134
});
2235

23-
it('instantiates a lstm generator', async () => {
24-
// expect(lstm.cellsAmount).toBe(LSTM_MODEL_DEFAULTS.cellsAmount);
25-
// expect(lstm.vocabSize).toBe(LSTM_MODEL_DEFAULTS.vocabSize);
26-
// expect(lstm.vocab[0]).toBe(LSTM_MODEL_DEFAULTS.firstChar);
36+
it('instantiates an rnn with all the defaults', async () => {
37+
expect(rnn.ready).toBeTruthy();
38+
expect(rnn.defaults.seed).toBe(RNN_DEFAULTS.seed);
39+
expect(rnn.defaults.length).toBe(RNN_DEFAULTS.length);
40+
expect(rnn.defaults.temperature).toBe(RNN_DEFAULTS.temperature);
41+
expect(rnn.defaults.stateful).toBe(RNN_DEFAULTS.stateful);
42+
});
43+
44+
// it('loads the model with all the defaults', async () => {
45+
// expect(rnn.cellsAmount).toBe(RNN_MODEL_DEFAULTS.cellsAmount);
46+
// expect(rnn.vocabSize).toBe(RNN_MODEL_DEFAULTS.vocabSize);
47+
// });
48+
49+
describe('generate', () => {
50+
it('Should generate content that follows default options if given an empty object', async() => {
51+
const result = await rnn.generate({});
52+
expect(result.sample.length).toBe(20);
53+
});
54+
55+
it('generates content that follows the set options', async() => {
56+
const result = await rnn.generate(RNN_OPTIONS);
57+
expect(result.sample.length).toBe(500);
58+
});
2759
});
2860
});

0 commit comments

Comments
 (0)