Skip to content

Commit e200223

Browse files
committed
Switch the LSTM implementation back to the previous version: it works faster and better
1 parent a14609a commit e200223

File tree

1 file changed

+128
-74
lines changed

1 file changed

+128
-74
lines changed

src/LSTM/index.js

Lines changed: 128 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -7,95 +7,149 @@
77
/* eslint no-await-in-loop: "off" */
88
/*
99
An LSTM generator: runs inference mode for a pre-trained LSTM.
10-
Heavily derived from https://github.com/reiinakano/tfjs-lstm-text-generation/
1110
*/
1211

1312
import * as tf from '@tensorflow/tfjs';
13+
import sampleFromDistribution from './../utils/sample';
14+
import CheckpointLoader from '../utils/checkpointLoader';
1415

15-
const DEFAULTS = {
16-
inputLength: 40,
17-
length: 20,
18-
temperature: 0.5,
19-
};
16+
// Classify checkpoint variable names: per-cell parameters first, then
// weights vs. biases within a group, and the final softmax
// (fully-connected) layer.
const regexCell = /cell_[0-9]|lstm_[0-9]/gi;
// NOTE(review): the lone `w` alternative makes this match ANY name
// containing a "w" — presumably intentional for terse checkpoint names;
// verify against the checkpoints this loader is used with.
const regexWeights = /weights|weight|kernel|kernels|w/gi;
const regexFullyConnected = /softmax/gi;
2019

2120
class LSTM {
  /**
   * An LSTM text generator: runs inference with a pre-trained char-RNN
   * checkpoint. Checkpoint variables are sorted into per-cell kernels and
   * biases plus a final fully-connected (softmax) layer; `generate` primes
   * the stacked cells with a seed string one character at a time, then
   * samples new characters from the output distribution.
   *
   * @param {string} model - URL/path of the directory holding the checkpoint
   *   (manifest + weight files) and a `vocab.json` file.
   * @param {function} callback - invoked once checkpoint and vocab are loaded.
   */
  constructor(model, callback) {
    this.ready = false;
    this.model = {};        // name -> tensor (Kernel_i, Bias_i, embedding, softmax)
    this.cellsAmount = 0;   // number of stacked LSTM cells found in the checkpoint
    this.vocab = {};        // char -> integer index
    this.vocabSize = 0;
    this.defaults = {
      seed: 'a',
      length: 20,
      temperature: 0.5,
    };
    this.loadCheckpoints(model, callback);
  }

  /**
   * Loads every checkpoint variable and files it into `this.model` by name:
   * per-cell kernels/biases, the fully-connected softmax layer, and anything
   * else (e.g. an embedding matrix) under its original key.
   *
   * @param {string} path - checkpoint directory.
   * @param {function} callback - forwarded to loadVocab.
   */
  loadCheckpoints(path, callback) {
    const reader = new CheckpointLoader(path);
    reader.getAllVariables().then(async (vars) => {
      Object.keys(vars).forEach((key) => {
        if (key.match(regexCell)) {
          if (key.match(regexWeights)) {
            this.model[`Kernel_${key.match(/[0-9]/)[0]}`] = vars[key];
            // One kernel per cell, so this doubles as the cell counter.
            this.cellsAmount += 1;
          } else {
            this.model[`Bias_${key.match(/[0-9]/)[0]}`] = vars[key];
          }
        } else if (key.match(regexFullyConnected)) {
          if (key.match(regexWeights)) {
            this.model.fullyConnectedWeights = vars[key];
          } else {
            this.model.fullyConnectedBiases = vars[key];
          }
        } else {
          this.model[key] = vars[key];
        }
      });
      this.loadVocab(path, callback);
    });
  }

  /**
   * Fetches `vocab.json`, records the char->index table and its size, then
   * flips `this.ready` and fires the callback.
   *
   * @param {string} file - checkpoint directory.
   * @param {function} callback - invoked when the vocab has loaded.
   */
  loadVocab(file, callback) {
    fetch(`${file}/vocab.json`)
      .then(response => response.json())
      .then((json) => {
        this.vocab = json;
        this.vocabSize = Object.keys(json).length;
        this.ready = true;
        callback();
      }).catch((error) => {
        console.error(`There has been a problem loading the vocab: ${error.message}`);
      });
  }

  /**
   * Generates text: primes the stacked LSTM with the (encoded) seed, then
   * samples `length` characters at the given temperature.
   *
   * @param {Object} [options] - { seed, length, temperature }.
   * @param {function} [callback] - receives { generated }.
   */
  async generate(options = {}, callback = () => {}) {
    // `||` (not a nullish check) is deliberate: a temperature of 0 would
    // divide the logits by zero below, so falsy values fall back to defaults.
    const seed = options.seed || this.defaults.seed;
    const length = +options.length || this.defaults.length;
    const temperature = +options.temperature || this.defaults.temperature;
    const results = [];

    if (!this.ready) {
      // Surface the misuse instead of silently doing nothing.
      console.error('generate() was called before the model finished loading.');
      return;
    }

    const forgetBias = tf.tensor(1.0);
    const LSTMCells = [];
    let c = [];
    let h = [];

    // Build one cell function per layer, each closing over its own
    // kernel/bias tensors from the checkpoint.
    const lstm = (i) => {
      const cell = (DATA, C, H) =>
        tf.basicLSTMCell(forgetBias, this.model[`Kernel_${i}`], this.model[`Bias_${i}`], DATA, C, H);
      return cell;
    };

    for (let i = 0; i < this.cellsAmount; i += 1) {
      // An LSTM bias concatenates the four gate biases, so the hidden
      // state size is a quarter of the bias length.
      c.push(tf.zeros([1, this.model[`Bias_${i}`].shape[0] / 4]));
      h.push(tf.zeros([1, this.model[`Bias_${i}`].shape[0] / 4]));
      LSTMCells.push(lstm(i));
    }

    // Encode at most 100 seed characters, dropping any character that is
    // not in the vocab (feeding an `undefined` index into the one-hot
    // buffer below would throw).
    const encodedInput = [];
    Array.from(seed).forEach((char, ind) => {
      if (ind < 100 && Object.prototype.hasOwnProperty.call(this.vocab, char)) {
        encodedInput.push(this.vocab[char]);
      }
    });

    if (encodedInput.length === 0) {
      console.error('None of the seed characters exist in the model vocabulary.');
      callback({ generated: '' });
      return;
    }

    // Loop-invariant temperature scalar hoisted out of the sampling loop.
    const divisor = tf.tensor(temperature);

    // Prime the network with the seed, then sample `length` characters.
    // (The previous version fed encodedInput[0] twice and iterated over the
    // raw seed length, which crashed on seeds longer than 100 characters or
    // containing out-of-vocab characters.)
    let input = encodedInput[0];
    for (let i = 0; i < (encodedInput.length - 1) + length; i += 1) {
      const onehotBuffer = tf.buffer([1, this.vocabSize]);
      onehotBuffer.set(1.0, 0, input);
      const onehot = onehotBuffer.toTensor();
      let output;
      if (this.model.embedding) {
        // Checkpoints with an embedding matrix project the one-hot first.
        const embedded = tf.matMul(onehot, this.model.embedding);
        output = tf.multiRNNCell(LSTMCells, embedded, c, h);
      } else {
        output = tf.multiRNNCell(LSTMCells, onehot, c, h);
      }

      c = output[0];
      h = output[1];

      // Hidden state of the LAST cell feeds the softmax layer.
      // (Was hard-coded to h[1], which assumed exactly two cells.)
      const outputH = h[h.length - 1];
      const weightedResult = tf.matMul(outputH, this.model.fullyConnectedWeights);
      const logits = tf.add(weightedResult, this.model.fullyConnectedBiases);
      const divided = tf.div(logits, divisor);
      const probabilities = tf.exp(divided);
      // NOTE(review): the intermediate tensors created in this loop are
      // never disposed, so long generations leak tensor memory; wrapping
      // the math in tf.tidy() (while keeping c and h) would fix that.
      const normalized = await tf.div(
        probabilities,
        tf.sum(probabilities),
      ).data();

      const sampledResult = sampleFromDistribution(normalized);
      if (i + 1 < encodedInput.length) {
        // Still priming: feed the next seed character.
        input = encodedInput[i + 1];
      } else {
        // Generating: feed the sample back in and record it.
        input = sampledResult;
        results.push(sampledResult);
      }
    }

    // Map the sampled indices back to characters (reverse vocab lookup;
    // assumes the vocab mapping is one-to-one).
    let generated = '';
    results.forEach((char) => {
      const mapped = Object.keys(this.vocab).find(key => this.vocab[key] === char);
      if (mapped) {
        generated += mapped;
      }
    });
    callback({ generated });
  }
}
101155

0 commit comments

Comments
 (0)