-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbrain.js
More file actions
113 lines (91 loc) · 3.48 KB
/
brain.js
File metadata and controls
113 lines (91 loc) · 3.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
const tf = require('@tensorflow/tfjs')
const tfn = require('@tensorflow/tfjs-node')
const token = require('./token')
const inputWord = require('./mappings/input-word2idx')
const wordContext = require('./mappings/word-context')
const targetWord = require('./mappings/target-word2idx')
const targetId = require('./mappings/target-idx2word')
const decoderModel = tfn.io.fileSystem('./decoder-model/model.json')
const encoderModel = tfn.io.fileSystem('./encoder-model/model.json')
module.exports = class Brain {
constructor() {
Promise.all([
tf.loadLayersModel(decoderModel),
tf.loadLayersModel(encoderModel)
]).then(([decoder, encoder]) => {
this.decoder = decoder
this.encoder = encoder
this.enableGeneration()
})
}
enableGeneration() {
return 'I`m aliveeeeee!!!'
}
async sendChat(inputText) {
console.log(inputText)
const states = tf.tidy(() => {
const input = this.convertSentenceToTensor(inputText)
return this.encoder.predict(input)
})
this.decoder.layers[1].resetStates(states)
let responseTokens = []
let terminate = false
let nextTokenID = targetWord['<SOS>']
let numPredicted = 0
while (!terminate) {
const outputTokenTensor = tf.tidy(() => {
const input = this.generateDecoderInputFromTokenID(nextTokenID)
const prediction = this.decoder.predict(input)
return this.sample(prediction.squeeze())
})
const outputToken = await outputTokenTensor.data()
outputTokenTensor.dispose()
nextTokenID = Math.round(outputToken[0])
const word = targetId[nextTokenID]
numPredicted++
console.log(outputToken, nextTokenID, word)
if (word !== '<EOS>' && word !== '<SOS>') {
responseTokens.push(word)
}
if (word === '<EOS>'
|| numPredicted >= wordContext.decoder_max_seq_length) {
terminate = true
}
await tf.nextFrame()
}
const answer = this.convertTokensToSentence(responseTokens)
states[0].dispose()
states[1].dispose()
return await answer
}
generateDecoderInputFromTokenID(tokenID) {
const buffer = tf.buffer([1, 1, wordContext.num_decoder_tokens])
buffer.set(1, 0, 0, tokenID)
return buffer.toTensor()
}
sample(prediction) {
return tf.tidy(() => prediction.argMax())
}
convertSentenceToTensor(sentence) {
let inputWordIds = []
token.tokenizer(sentence).map((word) => {
word = word.toLowerCase()
let idx = '1'
if (word in inputWord) {
idx = inputWord[word]
}
inputWordIds.push(Number(idx))
})
if (inputWordIds.length < wordContext.encoder_max_seq_length) {
let sequence = new Array(wordContext.encoder_max_seq_length-inputWordIds.length+1)
.join('0').split('').map(Number)
inputWordIds = [...sequence, ...inputWordIds]
} else {
inputWordIds = inputWordIds.slice(0, wordContext.encoder_max_seq_length)
}
return tf.tensor2d(inputWordIds, [1, wordContext.encoder_max_seq_length])
}
convertTokensToSentence(tokens) {
return tokens.join(' ')
}
}