@@ -69,7 +69,7 @@ class Conversation {
       throw Error("needs to call getPromptArray for the first message");
     }
     if (this.separator_style == "Two") {
-      let ret = [this.seps[this.seps.length - 1]];
+      let ret = [];
       for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
         const item = this.messages[i];
         const role = item[0];
@@ -167,6 +167,8 @@ class LLMChatPipeline {
 
     this.temperature = config.temperature;
     this.top_p = config.top_p;
+    this.repetitionPenalty = config.repetition_penalty
+    this.appeared_tokens = new Set();
 
     this.meanGenLength = config.mean_gen_len;
     this.streamInterval = 1;
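The two new fields sit next to the sampling parameters already read from the model config. A minimal, hypothetical config fragment for orientation; only temperature, top_p, repetition_penalty, and mean_gen_len are confirmed by the lines above, and the values are purely illustrative:

const config = {
  temperature: 0.7,          // illustrative value
  top_p: 0.95,               // illustrative value
  repetition_penalty: 1.0,   // values <= 1.0 keep the plain top-p path (see the sampling hunk below)
  mean_gen_len: 128,         // illustrative value
};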
@@ -268,7 +270,16 @@ class LLMChatPipeline {
     this.#updateLogitsOnCPU(logits);
     this.tvm.endScope();
     await this.device.sync();
-    return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
+    if (this.repetitionPenalty < 1.0 + 1e-6) {
+      return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
+    } else {
+      this.tvm.beginScope();
+      var appeared_tokens_ndarray = this.tvm.empty([1, this.appeared_tokens.size], "int32", this.tvm.cpu());
+      appeared_tokens_ndarray.copyFrom(Array.from(this.appeared_tokens));
+      this.tvm.applyRepetitionPenalty(this.logitsOnCPU, appeared_tokens_ndarray, this.repetitionPenalty);
+      this.tvm.endScope();
+      return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
+    }
   }
 
   async getInputTokens() {
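The penalty itself is applied on the runtime side by tvm.applyRepetitionPenalty. As a reference point, here is a plain JavaScript sketch of the usual CTRL-style rule, assuming that is what the runtime call implements; the helper below is hypothetical and is not part of the commit:

// Sketch of a CTRL-style repetition penalty over a plain Float32Array of logits.
function applyRepetitionPenaltySketch(logits, appearedTokens, penalty) {
  for (const token of appearedTokens) {
    // Make already-generated tokens less likely: shrink a positive logit,
    // push a negative logit further down.
    if (logits[token] > 0) {
      logits[token] /= penalty;
    } else {
      logits[token] *= penalty;
    }
  }
}

// Example: with penalty 1.2, previously seen tokens 0 and 1 are dampened.
const exampleLogits = new Float32Array([2.0, -1.0, 0.5]);
applyRepetitionPenaltySketch(exampleLogits, new Set([0, 1]), 1.2);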
@@ -360,6 +371,7 @@ class LLMChatPipeline {
       throw Error("Too small window size config");
     }
     let step = 0;
+    var stop = false;
     for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
       this.tvm.beginScope();
       var inputData;
@@ -375,22 +387,26 @@ class LLMChatPipeline {
         this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
       );
       this.tvm.endScope();
+      if (stop) {
+        break;
+      }
 
       const nextToken = await this.sampleTokenFromLogits(logits, this.temperature, this.top_p);
       logits.dispose();
 
       tokens.push(nextToken);
+      this.appeared_tokens.add(nextToken);
       const outputTokens = tokens.slice(inputTokenLength);
       outputPrompt = this.tokenizer.decode(outputTokens);
 
       if (this.stopTokens.includes(nextToken)) {
-        break;
+        stop = true;
       }
 
       const stopPos = outputPrompt.lastIndexOf(stopStr);
       if (stopPos != -1) {
         outputPrompt = outputPrompt.substring(0, stopPos);
-        break;
+        stop = true;
       }
 
       let tend = performance.now();
       if (step != 0) {
@@ -405,7 +421,7 @@ class LLMChatPipeline {
         callbackUpdateResponse(step, outputPrompt);
       }
     }
-    this.kvCacheLength += tokens.length - 1;
+    this.kvCacheLength += tokens.length;
    this.conversation.messages[this.conversation.messages.length - 1][1] = outputPrompt;
     return outputPrompt;
   }
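The last three hunks work together: instead of breaking out as soon as a stop condition is hit, the loop records it in stop and only exits at the top of the next iteration, after the freshly sampled token has been fed through #forward. Every generated token therefore lands in the KV cache, which is why kvCacheLength now advances by tokens.length rather than tokens.length - 1. A stripped-down sketch of that deferred-stop shape, with hypothetical helper names rather than the pipeline's own:

function generateSketch(promptTokens, maxGenLen, forward, sample, isStopToken) {
  const tokens = [...promptTokens];
  let stop = false;
  for (let step = 0; step < maxGenLen; ++step) {
    // First step forwards the whole prompt; later steps forward the last sampled token.
    forward(step == 0 ? promptTokens : [tokens[tokens.length - 1]]);
    if (stop) break;               // exit only after the stop token has been forwarded
    const next = sample();
    tokens.push(next);
    if (isStopToken(next)) stop = true;
  }
  return tokens;                   // every sampled token, including the stop token, was forwarded
}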