Skip to content

Commit 22822a3

Browse files
committed
LSTM as function
1 parent 9a57672 commit 22822a3

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/LSTM/index.js

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ const DEFAULTS = {
1818
temperature: 0.5,
1919
};
2020

21-
class LSTMGenerator {
22-
constructor(modelPath = './', callback = () => {}) {
21+
class LSTM {
22+
constructor(modelPath, callback) {
2323
this.modelPath = modelPath;
2424
this.ready = false;
2525
this.indices_char = {};
@@ -58,7 +58,7 @@ class LSTMGenerator {
5858
const indexTensor = tf.tidy(() => {
5959
const input = this.convert(seed);
6060
const prediction = this.model.predict(input).squeeze();
61-
return LSTMGenerator.sample(prediction, this.temperature);
61+
return LSTM.sample(prediction, this.temperature);
6262
});
6363
const index = await indexTensor.data();
6464
indexTensor.dispose();
@@ -99,4 +99,6 @@ class LSTMGenerator {
9999
}
100100
}
101101

102+
const LSTMGenerator = (modelPath = './', callback = () => {}) => new LSTM(modelPath, callback);
103+
102104
export default LSTMGenerator;

0 commit comments

Comments
 (0)