/* eslint no-await-in-loop: "off" */
|
/*
A LSTM Generator: Run inference mode for a pre-trained LSTM.
*/
|
12 | 11 |
|
13 | 12 | import * as tf from '@tensorflow/tfjs';
|
| 13 | +import sampleFromDistribution from './../utils/sample'; |
| 14 | +import CheckpointLoader from '../utils/checkpointLoader'; |
14 | 15 |
|
15 |
| -const DEFAULTS = { |
16 |
| - inputLength: 40, |
17 |
| - length: 20, |
18 |
| - temperature: 0.5, |
19 |
| -}; |
| 16 | +const regexCell = /cell_[0-9]|lstm_[0-9]/gi; |
| 17 | +const regexWeights = /weights|weight|kernel|kernels|w/gi; |
| 18 | +const regexFullyConnected = /softmax/gi; |
20 | 19 |
|
class LSTM {
  /**
   * Runs inference on a pre-trained character-level LSTM exported as a
   * tfjs checkpoint (variable manifest + vocab.json).
   * @param {string} model - Path to the checkpoint directory.
   * @param {function} callback - Invoked once weights and vocab have loaded.
   */
  constructor(model, callback) {
    this.ready = false;
    this.model = {};
    this.cellsAmount = 0;
    this.vocab = {};
    this.vocabSize = 0;
    this.defaults = {
      seed: 'a',
      length: 20,
      temperature: 0.5,
    };
    this.loadCheckpoints(model, callback);
  }

  /**
   * Loads every variable from the checkpoint and sorts each into a
   * per-cell kernel/bias slot, the fully-connected layer, or (anything
   * else, e.g. an embedding matrix) a raw entry on this.model.
   * @param {string} path - Checkpoint directory.
   * @param {function} callback - Forwarded to loadVocab.
   */
  loadCheckpoints(path, callback) {
    const reader = new CheckpointLoader(path);
    reader.getAllVariables().then((vars) => {
      Object.keys(vars).forEach((key) => {
        if (key.match(regexCell)) {
          // NOTE(review): /[0-9]/ captures a single digit only, so a model
          // with 10+ cells would collide — confirm against the exporter.
          const index = key.match(/[0-9]/)[0];
          if (key.match(regexWeights)) {
            this.model[`Kernel_${index}`] = vars[key];
            this.cellsAmount += 1; // one kernel per cell
          } else {
            this.model[`Bias_${index}`] = vars[key];
          }
        } else if (key.match(regexFullyConnected)) {
          if (key.match(regexWeights)) {
            this.model.fullyConnectedWeights = vars[key];
          } else {
            this.model.fullyConnectedBiases = vars[key];
          }
        } else {
          this.model[key] = vars[key];
        }
      });
      this.loadVocab(path, callback);
    });
  }

  /**
   * Fetches vocab.json, records the char-to-index mapping, and flips the
   * ready flag before invoking the user callback.
   * @param {string} file - Checkpoint directory (vocab.json lives inside).
   * @param {function} callback - Invoked once the vocab is available.
   */
  loadVocab(file, callback) {
    fetch(`${file}/vocab.json`)
      .then(response => response.json())
      .then((json) => {
        this.vocab = json;
        this.vocabSize = Object.keys(json).length;
        this.ready = true;
        callback();
      })
      .catch((error) => {
        // Best-effort: on failure the instance simply never becomes ready.
        console.error(`There has been a problem loading the vocab: ${error.message}`);
      });
  }

  /**
   * Generates `length` characters of text, primed with `seed`.
   *
   * Fixes over the previous revision:
   *  - `options` and `callback` are defaulted so a missing argument no
   *    longer throws, and the callback now fires (with an empty string)
   *    even when the model is not yet ready, instead of silently doing
   *    nothing;
   *  - intermediate tensors (one-hots, matmuls, logits, probabilities,
   *    and each step's superseded c/h state) are disposed instead of
   *    leaking every iteration;
   *  - the read-out uses the LAST cell's hidden state (h[h.length - 1])
   *    rather than hard-coding h[1], generalizing beyond 2-cell models
   *    (identical behavior for the 2-cell case).
   *
   * @param {Object} [options] - { seed, length, temperature }.
   * @param {function} [callback] - Receives { generated }.
   */
  async generate(options = {}, callback = () => {}) {
    const seed = options.seed || this.defaults.seed;
    const length = +options.length || this.defaults.length;
    const temperature = +options.temperature || this.defaults.temperature;

    if (!this.ready) {
      // Model/vocab not loaded yet — report an empty result rather than
      // dropping the callback on the floor.
      callback({ generated: '' });
      return;
    }

    const results = [];
    const forgetBias = tf.tensor(1.0);
    const LSTMCells = [];
    let c = [];
    let h = [];

    // Each cell function closes over its own kernel/bias pair.
    const lstm = i => (DATA, C, H) =>
      tf.basicLSTMCell(forgetBias, this.model[`Kernel_${i}`], this.model[`Bias_${i}`], DATA, C, H);

    for (let i = 0; i < this.cellsAmount; i += 1) {
      // An LSTM bias packs the 4 gates, so the state size is shape[0] / 4.
      const stateSize = this.model[`Bias_${i}`].shape[0] / 4;
      c.push(tf.zeros([1, stateSize]));
      h.push(tf.zeros([1, stateSize]));
      LSTMCells.push(lstm(i));
    }

    const userInput = Array.from(seed);

    // Encode at most the first 100 seed characters.
    // NOTE(review): characters absent from the vocab encode as undefined,
    // which would corrupt the one-hot below — confirm seeds are in-vocab.
    const encodedInput = [];
    userInput.forEach((char, ind) => {
      if (ind < 100) {
        encodedInput.push(this.vocab[char]);
      }
    });

    let current = 0;
    let input = encodedInput[current];

    // Feed the whole seed through the network, then start sampling.
    // NOTE(review): the priming branch reads encodedInput[current] BEFORE
    // incrementing, so the first seed character is fed twice — preserved
    // as-is since changing it would change generated sequences.
    for (let i = 0; i < userInput.length + length; i += 1) {
      const onehotBuffer = tf.buffer([1, this.vocabSize]);
      onehotBuffer.set(1.0, 0, input);
      const onehot = onehotBuffer.toTensor();

      let cellInput = onehot;
      if (this.model.embedding) {
        cellInput = tf.matMul(onehot, this.model.embedding);
        onehot.dispose();
      }

      const output = tf.multiRNNCell(LSTMCells, cellInput, c, h);
      cellInput.dispose();

      // Free the previous step's state before adopting the new one.
      c.forEach(tensor => tensor.dispose());
      h.forEach(tensor => tensor.dispose());
      c = output[0];
      h = output[1];

      // Read out from the last cell's hidden state.
      const outputH = h[h.length - 1];
      const normalizedTensor = tf.tidy(() => {
        const weightedResult = tf.matMul(outputH, this.model.fullyConnectedWeights);
        const logits = tf.add(weightedResult, this.model.fullyConnectedBiases);
        const divided = tf.div(logits, tf.tensor(temperature));
        const probabilities = tf.exp(divided);
        return tf.div(probabilities, tf.sum(probabilities));
      });
      const normalized = await normalizedTensor.data();
      normalizedTensor.dispose();

      const sampledResult = sampleFromDistribution(normalized);
      if (userInput.length > current) {
        // Still priming with the seed; ignore the sample.
        input = encodedInput[current];
        current += 1;
      } else {
        input = sampledResult;
        results.push(sampledResult);
      }
    }

    forgetBias.dispose();
    c.forEach(tensor => tensor.dispose());
    h.forEach(tensor => tensor.dispose());

    // Map sampled indices back to characters via the (inverted) vocab.
    let generated = '';
    results.forEach((char) => {
      const mapped = Object.keys(this.vocab).find(key => this.vocab[key] === char);
      if (mapped) {
        generated += mapped;
      }
    });
    callback({ generated });
  }
}
|
101 | 155 |
|
|
0 commit comments