@@ -36,6 +36,57 @@ class Conversation {
36
36
return ret ;
37
37
}
38
38
39
+ /**
40
+ * Get prompt arrays that has not been fed as input
41
+ *
42
+ * @returns The prompt array.
43
+ */
44
+ getPromptArrayUnproccessed ( ) {
45
+ if ( this . seps . length == 0 ) {
46
+ throw Error ( "Need seps to work" )
47
+ }
48
+ if ( this . messages . length < 3 ) {
49
+ throw Error ( "needs to call getLastPromptArray for the first message" ) ;
50
+ }
51
+ let ret = [ this . seps [ this . seps . length - 1 ] ] ;
52
+ for ( let i = this . messages . length - 2 ; i < this . messages . length ; ++ i ) {
53
+ const item = this . messages [ i ] ;
54
+ const role = item [ 0 ] ;
55
+ const message = item [ 1 ] ;
56
+ if ( message !== undefined && message != "" ) {
57
+ ret . push ( role + ": " + message + this . seps [ i % this . seps . length ] ) ;
58
+ } else {
59
+ ret . push ( role + ":" ) ;
60
+ }
61
+ }
62
+ return ret ;
63
+
64
+ }
65
+
66
+ /**
67
+ * Get last prompt array with prefix as system.
68
+ *
69
+ * @returns The prompt array.
70
+ */
71
+ getLastPromptArray ( ) {
72
+ if ( this . seps . length == 0 ) {
73
+ throw Error ( "Need seps to work" )
74
+ }
75
+ let ret = [ this . system + this . seps [ 0 ] ] ;
76
+
77
+ for ( let i = this . messages . length - 2 ; i < this . messages . length ; ++ i ) {
78
+ const item = this . messages [ i ] ;
79
+ const role = item [ 0 ] ;
80
+ const message = item [ 1 ] ;
81
+ if ( message !== undefined && message != "" ) {
82
+ ret . push ( role + ": " + message + this . seps [ i % this . seps . length ] ) ;
83
+ } else {
84
+ ret . push ( role + ":" ) ;
85
+ }
86
+ }
87
+ return ret ;
88
+ }
89
+
39
90
reset ( ) {
40
91
this . messages = [ ] ;
41
92
}
@@ -52,12 +103,12 @@ class Conversation {
52
103
/**
 * Create the default conversation template.
 *
 * @param maxWindowLength Maximum context window length in tokens.
 * @returns {Conversation} A fresh Conversation configured with the default
 *   system prompt, USER/ASSISTANT roles, and separators.
 */
function defaultConversation(maxWindowLength = 512) {
  const system =
    "A chat between a curious user and an artificial intelligence assistant. " +
    "The assistant gives helpful, detailed, and polite answers to the user's questions.";
  return new Conversation({
    system: system,
    roles: ["USER", "ASSISTANT"],
    maxWindowLength: maxWindowLength,
    messages: [],
    offset: 0,
    seps: [" ", "</s>"],
  });
}
63
114
@@ -120,6 +171,9 @@ class LLMChatPipeline {
120
171
this . kvCache = this . tvm . detachFromCurrentScope ( this . tvm . makeTVMArray ( kvList ) ) ;
121
172
// fill with pad token
122
173
this . logitsOnCPU = undefined ;
174
+
175
+ this . kvCacheLength = 0 ;
176
+ this . clearCache = true
123
177
}
124
178
125
179
@@ -167,7 +221,7 @@ class LLMChatPipeline {
167
221
this . tvm . empty ( logits . shape , logits . dtype , this . tvm . cpu ( ) )
168
222
) ;
169
223
} else {
170
- if ( logits . shape [ 0 ] != this . logitsOnCPU . shape [ 0 ] ) {
224
+ if ( logits . shape [ 0 ] != this . logitsOnCPU . shape [ 0 ] ) {
171
225
throw Error ( "We expect the size of logits to remain unchanged" ) ;
172
226
}
173
227
}
@@ -183,35 +237,56 @@ class LLMChatPipeline {
183
237
}
184
238
185
239
async getInputTokens ( ) {
186
- const tokens = [ this . bosTokenId ] ;
187
- const prompts = this . conversation . getPromptArray ( ) ;
240
+ let tokens = [ this . bosTokenId ] ;
241
+ let prompts = ""
242
+ if ( this . conversation . messages . length <= 2 ) {
243
+ prompts = this . conversation . getPromptArray ( ) ;
244
+ } else {
245
+ tokens . pop ( ) ;
246
+ prompts = this . conversation . getPromptArrayUnproccessed ( ) ;
247
+ }
188
248
tokens . push ( ...await this . tokenizer . encodeIds ( prompts [ 0 ] ) ) ;
189
-
190
249
let ctxLength = tokens . length ;
191
- const context = [ ] ;
250
+ let context = [ ] ;
251
+ let need_shift_window = false ;
192
252
for ( let i = prompts . length - 1 ; i > 0 ; -- i ) {
193
253
const encoded = this . tokenizer . encodeIds ( prompts [ i ] ) ;
194
254
ctxLength += encoded . length ;
195
- if ( ctxLength + this . meanGenLength >= this . maxWindowLength && i + 2 < prompts . length ) {
196
- this . logger ( "Shift window at " + i ) ;
255
+ if ( this . kvCacheLength + ctxLength + this . meanGenLength >= this . maxWindowLength ) {
256
+ need_shift_window = true ;
197
257
break ;
198
258
}
199
259
context . unshift ( encoded ) ;
200
260
}
201
- const followMessage = [ ] ;
202
- for ( const ctx of context ) {
203
- followMessage . push ( ...ctx ) ;
261
+ if ( ! need_shift_window ) {
262
+ for ( const ctx of context ) {
263
+ tokens . push ( ...ctx ) ;
264
+ }
265
+ return tokens ;
204
266
}
205
-
206
- if ( followMessage . length + tokens . length + this . meanGenLength >= this . maxWindowLength ) {
207
- const maxMsgLen = this . maxWindowLength - tokens . length - this . meanGenLength ;
208
- if ( maxMsgLen < this . meanGenLength ) {
209
- throw Error ( "Too small window config tokens.length=" + tokens . length ) ;
267
+ // need shift window and re-encode
268
+ this . logger ( "need shift window" )
269
+ this . kvCacheLength = 0 ;
270
+ this . clearCache = true ;
271
+ // abandon all tokens we collected
272
+ tokens = [ this . bosTokenId ]
273
+ let all_prompts = this . conversation . getPromptArray ( ) ;
274
+ tokens . push ( ...await this . tokenizer . encodeIds ( all_prompts [ 0 ] ) ) ;
275
+ context = [ ] ;
276
+ ctxLength = tokens . length ;
277
+ //only keep 10% of the window context
278
+ const fill_factor = 0.1
279
+ for ( let i = all_prompts . length - 1 ; i > 0 ; -- i ) {
280
+ const encoded = this . tokenizer . encodeIds ( all_prompts [ i ] ) ;
281
+ ctxLength += encoded . length ;
282
+ if ( ctxLength >= fill_factor * this . maxWindowLength && i + 2 < all_prompts . length ) {
283
+ break ;
210
284
}
211
- this . logger ( "Slice message " + followMessage . length + " to " + maxMsgLen ) ;
212
- followMessage = followMessage . slice ( followMessage . length - maxMsgLen ) ;
285
+ context . unshift ( encoded ) ;
286
+ }
287
+ for ( const ctx of context ) {
288
+ tokens . push ( ...ctx ) ;
213
289
}
214
- tokens . push ( ...followMessage ) ;
215
290
if ( tokens . length + this . meanGenLength >= this . maxWindowLength ) {
216
291
throw Error ( "Exceed max window length curr=" + tokens . length ) ;
217
292
}
@@ -235,16 +310,18 @@ class LLMChatPipeline {
235
310
const inputTokenLength = tokens . length ;
236
311
237
312
var outputPrompt = "" ;
238
- this . #clearKVCache( ) ;
313
+ if ( this . clearCache ) {
314
+ this . #clearKVCache( ) ;
315
+ this . clearCache = false ;
316
+ }
239
317
const maxGenLen = Math . min ( this . maxGenLength , this . maxWindowLength - tokens . length ) ;
240
318
if ( maxGenLen < this . meanGenLength ) {
241
319
throw Error ( "Too small window size config" ) ;
242
320
}
243
-
244
- for ( let step = 0 ; step < maxGenLen ; ++ step ) {
321
+ let step = 0 ;
322
+ for ( ; step < maxGenLen && this . kvCacheLength + inputTokenLength + step < this . maxWindowLength ; ++ step ) {
245
323
this . tvm . beginScope ( ) ;
246
324
var inputData ;
247
-
248
325
let tstart = performance . now ( ) ;
249
326
if ( step == 0 ) {
250
327
inputData = this . tvm . empty ( [ 1 , tokens . length ] , "int32" , this . device ) ;
@@ -254,7 +331,7 @@ class LLMChatPipeline {
254
331
inputData . copyFrom ( tokens . slice ( tokens . length - 1 ) ) ;
255
332
}
256
333
const logits = this . tvm . detachFromCurrentScope (
257
- this . #forward( inputData , inputTokenLength + step )
334
+ this . #forward( inputData , this . kvCacheLength + inputTokenLength + step )
258
335
) ;
259
336
this . tvm . endScope ( ) ;
260
337
@@ -285,6 +362,7 @@ class LLMChatPipeline {
285
362
callbackUpdateResponse ( step , outputPrompt ) ;
286
363
}
287
364
}
365
+ this . kvCacheLength += tokens . length - 1 ;
288
366
this . conversation . messages [ this . conversation . messages . length - 1 ] [ 1 ] = outputPrompt ;
289
367
return outputPrompt ;
290
368
}
@@ -358,12 +436,12 @@ class LLMChatInstance {
358
436
this . logger = console . log ;
359
437
this . debugTest = false ;
360
438
}
361
- /**
362
- * Initialize TVM
363
- * @param wasmUrl URL to wasm source.
364
- * @param cacheUrl URL to NDArray cache.
365
- * @param logger Custom logger.
366
- */
439
+ /**
440
+ * Initialize TVM
441
+ * @param wasmUrl URL to wasm source.
442
+ * @param cacheUrl URL to NDArray cache.
443
+ * @param logger Custom logger.
444
+ */
367
445
async #asyncInitTVM( wasmUrl , cacheUrl ) {
368
446
if ( this . tvm !== undefined ) {
369
447
return ;
@@ -395,7 +473,7 @@ class LLMChatInstance {
395
473
this . reset ( ) ;
396
474
throw Error ( "This browser env do not support WebGPU" ) ;
397
475
}
398
- } catch ( err ) {
476
+ } catch ( err ) {
399
477
this . appendMessage ( "error" , "Find an error initializing the WebGPU device " + err . toString ( ) ) ;
400
478
console . log ( err . stack ) ;
401
479
this . reset ( ) ;
@@ -444,7 +522,7 @@ class LLMChatInstance {
444
522
// initialize UX and tokenizer
445
523
const tokenizer = await tvmjsGlobalEnv . sentencePieceProcessor ( this . config . tokenizer ) ;
446
524
this . pipeline = this . tvm . withNewScope ( ( ) => {
447
- return new LLMChatPipeline ( this . tvm , tokenizer , this . tvm . cacheMetadata , this . config ) ;
525
+ return new LLMChatPipeline ( this . tvm , tokenizer , this . tvm . cacheMetadata , this . config ) ;
448
526
} ) ;
449
527
await this . pipeline . asyncLoadWebGPUPiplines ( ) ;
450
528
this . updateLastMessage ( "init" , "All initialization finished." ) ;
@@ -521,7 +599,7 @@ class LLMChatInstance {
521
599
522
600
try {
523
601
await this . asyncInit ( ) ;
524
- } catch ( err ) {
602
+ } catch ( err ) {
525
603
this . appendMessage ( "error" , "Init error, " + err . toString ( ) ) ;
526
604
console . log ( err . stack ) ;
527
605
this . reset ( ) ;
0 commit comments