Skip to content

Commit cccccbe

Browse files
authored
Stateful LSTM (#234)
* allow passing in of lstm state c and h. return more lstm info (c, h, probabilities) * fix LSTM statefulness and other tweaks * LSTM probabilities class property * remove weird ind<100 check * remove unnecessary initCells * add LSTM methods state getter, setter and reset * rename to charRNN and update stateful to match discussed api * add callback to feed * fix bug in predict * adding new probabilities object * new build
1 parent 3800120 commit cccccbe

File tree

7 files changed

+326
-196
lines changed

7 files changed

+326
-196
lines changed

dist/ml5.min.js

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

dist/ml5.min.js.map

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package-lock.json

Lines changed: 67 additions & 26 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/CharRNN/index.js

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
// Copyright (c) 2018 ml5
2+
//
3+
// This software is released under the MIT License.
4+
// https://opensource.org/licenses/MIT
5+
6+
/* eslint prefer-destructuring: ["error", {AssignmentExpression: {array: false}}] */
7+
/* eslint no-await-in-loop: "off" */
8+
/*
9+
A LSTM Generator: Run inference mode for a pre-trained LSTM.
10+
*/
11+
12+
import * as tf from '@tensorflow/tfjs';
13+
import sampleFromDistribution from './../utils/sample';
14+
import CheckpointLoader from '../utils/checkpointLoader';
15+
import callCallback from '../utils/callcallback';
16+
17+
// Checkpoint variable names are matched against these patterns to decide
// where each tensor belongs when the model is assembled in loadCheckpoints().
const regexCell = /cell_[0-9]|lstm_[0-9]/gi; // per-layer LSTM cell variables
const regexWeights = /weights|weight|kernel|kernels|w/gi; // kernel vs. bias within a group
const regexFullyConnected = /softmax/gi; // final fully-connected (softmax) layer
20+
21+
class CharRNN {
  /**
   * A character-level LSTM text generator that runs inference on a
   * pre-trained TensorFlow checkpoint (per-cell kernel/bias tensors plus a
   * fully-connected softmax layer, see loadCheckpoints()).
   * @param {string} modelPath - Folder containing the checkpoint files and vocab.json.
   * @param {function} [callback] - Node-style callback invoked once loading completes.
   */
  constructor(modelPath, callback) {
    this.ready = false;
    this.model = {}; // name -> tf tensor, filled by loadCheckpoints()
    this.cellsAmount = 0; // number of stacked LSTM cells found in the checkpoint
    this.cells = []; // per-cell closures consumable by tf.multiRNNCell
    this.zeroState = { c: [], h: [] }; // pristine all-zero state; never mutated
    this.state = { c: [], h: [] }; // current (possibly advanced) state
    this.vocab = {}; // char -> index
    this.vocabSize = 0;
    this.probabilities = []; // normalized probabilities from the last sampled step
    this.defaults = {
      seed: 'a', // TODO: use no seed by default
      length: 20,
      temperature: 0.5,
      stateful: false,
    };

    // this.ready resolves with `this` once checkpoints and vocab are loaded.
    this.ready = callCallback(this.loadCheckpoints(modelPath), callback);
  }

  /**
   * Shallow-copy an LSTM state object. The per-cell tensors are shared, which
   * is safe because stepping replaces state.c / state.h wholesale rather than
   * mutating the tensors in place.
   * @param {{c: Array, h: Array}} state
   * @returns {{c: Array, h: Array}} independent copy of the two arrays
   */
  copyState(state) {
    return { c: state.c.slice(), h: state.h.slice() };
  }

  /** Reset the current state back to the all-zero state. */
  resetState() {
    // Bug fix: `this.state = this.zeroState` aliased the two objects, so the
    // next inference step overwrote zeroState's arrays. Hand out a copy.
    this.state = this.copyState(this.zeroState);
  }

  /**
   * Replace the current LSTM state (e.g. one captured from getState()).
   * @param {{c: Array, h: Array}} state
   */
  setState(state) {
    this.state = state;
  }

  /** @returns {{c: Array, h: Array}} the current LSTM state. */
  getState() {
    return this.state;
  }

  /**
   * Load all checkpoint variables and sort them into the model by name:
   * cell_N/lstm_N tensors become Kernel_N / Bias_N, softmax tensors become
   * the fully-connected layer, anything else (e.g. `embedding`) is kept as-is.
   * @param {string} path - Checkpoint folder.
   * @returns {Promise<CharRNN>} this, once cells and vocab are initialized.
   */
  async loadCheckpoints(path) {
    const reader = new CheckpointLoader(path);
    const vars = await reader.getAllVariables();
    Object.keys(vars).forEach((key) => {
      if (key.match(regexCell)) {
        if (key.match(regexWeights)) {
          this.model[`Kernel_${key.match(/[0-9]/)[0]}`] = vars[key];
          // Each cell contributes exactly one kernel, so this counts cells.
          this.cellsAmount += 1;
        } else {
          this.model[`Bias_${key.match(/[0-9]/)[0]}`] = vars[key];
        }
      } else if (key.match(regexFullyConnected)) {
        if (key.match(regexWeights)) {
          this.model.fullyConnectedWeights = vars[key];
        } else {
          this.model.fullyConnectedBiases = vars[key];
        }
      } else {
        this.model[key] = vars[key];
      }
    });
    await this.loadVocab(path);
    await this.initCells();
    return this;
  }

  /**
   * Fetch and store the char->index vocabulary.
   * @param {string} path - Folder containing vocab.json.
   * @throws {Error} with a descriptive message if the vocab cannot be loaded.
   */
  async loadVocab(path) {
    // Bug fix: the old `.catch(console.error)` swallowed fetch failures and
    // then crashed on Object.keys(undefined) with a misleading TypeError.
    let json;
    try {
      const response = await fetch(`${path}/vocab.json`);
      json = await response.json();
    } catch (err) {
      throw new Error(`CharRNN: could not load vocab from ${path}/vocab.json: ${err}`);
    }
    this.vocab = json;
    this.vocabSize = Object.keys(json).length;
  }

  /**
   * Build the per-cell LSTM closures and the all-zero initial state from the
   * loaded kernel/bias tensors.
   */
  async initCells() {
    this.cells = [];
    this.zeroState = { c: [], h: [] };
    const forgetBias = tf.tensor(1.0);

    // Each cell closes over its own kernel/bias tensors from the checkpoint.
    const lstm = (i) => {
      const cell = (DATA, C, H) =>
        tf.basicLSTMCell(forgetBias, this.model[`Kernel_${i}`], this.model[`Bias_${i}`], DATA, C, H);
      return cell;
    };

    for (let i = 0; i < this.cellsAmount; i += 1) {
      // LSTM bias shape is [4 * units] (one slab per gate), hence the / 4.
      const units = this.model[`Bias_${i}`].shape[0] / 4;
      this.zeroState.c.push(tf.zeros([1, units]));
      this.zeroState.h.push(tf.zeros([1, units]));
      this.cells.push(lstm(i));
    }

    // Copy, so advancing this.state never corrupts zeroState.
    this.state = this.copyState(this.zeroState);
  }

  /**
   * Run the network over the seed, then keep sampling `length` characters.
   * @param {{seed?: string, length?: number, temperature?: number, stateful?: boolean}} options
   * @returns {Promise<{sample: string, state: {c: Array, h: Array}}>}
   */
  async generateInternal(options) {
    await this.ready;
    const seed = options.seed || this.defaults.seed;
    const length = +options.length || this.defaults.length;
    const temperature = +options.temperature || this.defaults.temperature;
    const stateful = options.stateful || this.defaults.stateful;
    if (!stateful) {
      // Bug fix: assigning zeroState directly aliased it with this.state.
      this.state = this.copyState(this.zeroState);
    }

    const results = [];
    const userInput = Array.from(seed);
    const encodedInput = userInput.map(char => this.vocab[char]);

    let input = encodedInput[0];
    let probabilitiesNormalized = []; // final normalized probabilities

    // Feed the remaining seed chars, then sample `length` new ones.
    for (let i = 0; i < userInput.length + length - 1; i += 1) {
      const onehotBuffer = tf.buffer([1, this.vocabSize]);
      onehotBuffer.set(1.0, 0, input);
      const onehot = onehotBuffer.toTensor();
      let output;
      if (this.model.embedding) {
        const embedded = tf.matMul(onehot, this.model.embedding);
        output = tf.multiRNNCell(this.cells, embedded, this.state.c, this.state.h);
      } else {
        output = tf.multiRNNCell(this.cells, onehot, this.state.c, this.state.h);
      }

      this.state.c = output[0];
      this.state.h = output[1];

      // Bug fix: h[1] hard-coded a two-cell model; use the last cell's output.
      const outputH = this.state.h[this.state.h.length - 1];
      const weightedResult = tf.matMul(outputH, this.model.fullyConnectedWeights);
      const logits = tf.add(weightedResult, this.model.fullyConnectedBiases);
      const divided = tf.div(logits, tf.tensor(temperature));
      const probabilities = tf.exp(divided);
      probabilitiesNormalized = await tf.div(
        probabilities,
        tf.sum(probabilities),
      ).data();

      if (i < userInput.length - 1) {
        input = encodedInput[i + 1]; // still consuming the seed
      } else {
        input = sampleFromDistribution(probabilitiesNormalized);
        results.push(input);
      }
    }

    // Decode sampled indices back to characters (indices without a vocab
    // entry are silently skipped, matching the encoder's behavior).
    let generated = '';
    results.forEach((char) => {
      const mapped = Object.keys(this.vocab).find(key => this.vocab[key] === char);
      if (mapped) {
        generated += mapped;
      }
    });
    this.probabilities = probabilitiesNormalized;
    return {
      sample: generated,
      state: this.state,
    };
  }

  /** Reset the LSTM state to zeros (alias kept for the public API). */
  reset() {
    this.state = this.copyState(this.zeroState);
  }

  /**
   * Stateless generation: resets the state, then generates from the seed.
   * @param {object} options - See generateInternal().
   * @param {function} [callback] - Node-style callback with the result.
   * @returns {Promise<{sample: string, state: object}>}
   */
  async generate(options, callback) {
    this.reset();
    return callCallback(this.generateInternal(options), callback);
  }

  /**
   * Stateful single-step prediction from the current state (use feed() first).
   * @param {number} temp - Sampling temperature; values <= 0 are clamped to 0.1.
   * @param {function} [callback] - Called with the sampled character.
   * @returns {Promise<{sample: string, probabilities: Array<{char: string, probability: number}>}>}
   */
  async predict(temp, callback) {
    await this.ready; // robustness: predicting before load would crash on an empty state
    const temperature = temp > 0 ? temp : 0.1;
    // Bug fix: h[1] hard-coded a two-cell model; use the last cell's output.
    const outputH = this.state.h[this.state.h.length - 1];
    const weightedResult = tf.matMul(outputH, this.model.fullyConnectedWeights);
    const logits = tf.add(weightedResult, this.model.fullyConnectedBiases);
    const divided = tf.div(logits, tf.tensor(temperature));
    const probabilities = tf.exp(divided);
    const probabilitiesNormalized = await tf.div(
      probabilities,
      tf.sum(probabilities),
    ).data();

    const sample = sampleFromDistribution(probabilitiesNormalized);
    const result = Object.keys(this.vocab).find(key => this.vocab[key] === sample);
    this.probabilities = probabilitiesNormalized;
    if (callback) {
      callback(result);
    }
    /* eslint max-len: ["error", { "code": 180 }] */
    const pm = Object.keys(this.vocab).map(c => ({ char: c, probability: this.probabilities[this.vocab[c]] }));
    return {
      sample: result,
      probabilities: pm,
    };
  }

  /**
   * Advance the internal state by feeding a string through the network
   * without sampling anything.
   * @param {string} inputSeed - Characters to feed, in order.
   * @param {function} [callback] - Invoked when feeding is done.
   */
  async feed(inputSeed, callback) {
    await this.ready;
    const encodedInput = Array.from(inputSeed).map(char => this.vocab[char]);

    // Bug fix: the original advanced with encodedInput[i] (not i + 1), which
    // fed the first character twice and never fed the last one.
    for (const input of encodedInput) {
      const onehotBuffer = tf.buffer([1, this.vocabSize]);
      onehotBuffer.set(1.0, 0, input);
      const onehot = onehotBuffer.toTensor();
      let output;
      if (this.model.embedding) {
        const embedded = tf.matMul(onehot, this.model.embedding);
        output = tf.multiRNNCell(this.cells, embedded, this.state.c, this.state.h);
      } else {
        output = tf.multiRNNCell(this.cells, onehot, this.state.c, this.state.h);
      }
      this.state.c = output[0];
      this.state.h = output[1];
    }
    if (callback) {
      callback();
    }
  }
}
245+
246+
/**
 * Factory wrapper so users can write `ml5.charRNN(path, cb)` instead of
 * constructing the class directly.
 * @param {string} [modelPath='./'] - Folder holding the checkpoint and vocab.json.
 * @param {function} [callback] - Node-style callback fired when the model is ready.
 * @returns {CharRNN}
 */
function charRNN(modelPath = './', callback) {
  return new CharRNN(modelPath, callback);
}

export default charRNN;
File renamed without changes.

0 commit comments

Comments
 (0)