Commit 383bc52

Repetition penalty (#100)
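
This commit adds a configurable repetition penalty to token sampling in the web chat pipeline: every generated token id is recorded in a set, and when repetition_penalty is above 1.0 the logits of those tokens are rescaled via tvm.applyRepetitionPenalty before top-p sampling. The generation loop's exit is also deferred by one step so the final token still passes through the model into the KV cache, with the kvCacheLength accounting adjusted to match.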
1 parent fb6d0e8 commit 383bc52

File tree

1 file changed: +21 -5 lines changed

web/llm_chat.js

Lines changed: 21 additions & 5 deletions
@@ -69,7 +69,7 @@ class Conversation {
       throw Error("needs to call getPromptArray for the first message");
     }
     if (this.separator_style == "Two") {
-      let ret = [this.seps[this.seps.length - 1]];
+      let ret = [];
       for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
         const item = this.messages[i];
         const role = item[0];
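
In the "Two" separator style, the incremental prompt array used to be seeded with the trailing separator, so each follow-up round began with a stray separator token; it now starts empty and carries only the last user/assistant exchange. A small illustration with hypothetical Vicuna-style separators (not part of this diff):

// Hypothetical two-separator setup for illustration only:
const seps = [" ", "</s>"];
// before: ret started as ["</s>"], re-emitting the previous round's
//         closing separator at the head of the incremental prompt
// after:  ret starts as [] and holds only the last two rendered messages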
@@ -167,6 +167,8 @@ class LLMChatPipeline {
 
     this.temperature = config.temperature;
     this.top_p = config.top_p;
+    this.repetitionPenalty = config.repetition_penalty
+    this.appeared_tokens = new Set();
 
     this.meanGenLength = config.mean_gen_len;
     this.streamInterval = 1;
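
Two pieces of state back the feature: repetitionPenalty is read straight from the model config alongside temperature and top_p, and appeared_tokens is a Set collecting every token id sampled so far. A sketch of the relevant config fields (values are illustrative, not from this commit):

// Illustrative values only; real values come from the model's chat config.
const config = {
  temperature: 0.7,
  top_p: 0.95,
  repetition_penalty: 1.1, // values at or below 1.0 disable the penalty path
  mean_gen_len: 128,
};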
@@ -268,7 +270,16 @@ class LLMChatPipeline {
     this.#updateLogitsOnCPU(logits);
     this.tvm.endScope();
     await this.device.sync();
-    return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
+    if (this.repetitionPenalty < 1.0 + 1e-6) {
+      return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
+    } else {
+      this.tvm.beginScope();
+      var appeared_tokens_ndarray = this.tvm.empty([1, this.appeared_tokens.size], "int32", this.tvm.cpu());
+      appeared_tokens_ndarray.copyFrom(Array.from(this.appeared_tokens));
+      this.tvm.applyRepetitionPenalty(this.logitsOnCPU, appeared_tokens_ndarray, this.repetitionPenalty);
+      this.tvm.endScope();
+      return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
+    }
   }
 
   async getInputTokens() {
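
The 1.0 + 1e-6 guard treats penalties at or below 1.0 as disabled (with a small float tolerance), skipping the extra scope and the copy of the appeared-token set into a CPU ndarray. For intuition, here is a plain-JavaScript sketch of the standard CTRL-style penalty (Keskar et al. 2019) that a helper like tvm.applyRepetitionPenalty conventionally computes; this is a reference sketch, not the TVM runtime's actual implementation:

// logits: Float32Array of vocabulary logits; appeared: Set of token ids
// seen so far; penalty: number > 1.0.
function applyRepetitionPenaltySketch(logits, appeared, penalty) {
  for (const id of appeared) {
    // Dividing a positive logit, or multiplying a negative one, by the
    // penalty lowers that token's probability regardless of the logit's sign.
    logits[id] = logits[id] > 0 ? logits[id] / penalty : logits[id] * penalty;
  }
}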
@@ -360,6 +371,7 @@ class LLMChatPipeline {
       throw Error("Too small window size config");
     }
     let step = 0;
+    var stop = false;
     for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
       this.tvm.beginScope();
       var inputData;
@@ -375,22 +387,26 @@ class LLMChatPipeline {
         this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
       );
       this.tvm.endScope();
+      if (stop) {
+        break;
+      }
 
       const nextToken = await this.sampleTokenFromLogits(logits, this.temperature, this.top_p);
       logits.dispose();
 
       tokens.push(nextToken);
+      this.appeared_tokens.add(nextToken);
       const outputTokens = tokens.slice(inputTokenLength);
       outputPrompt = this.tokenizer.decode(outputTokens);
 
       if (this.stopTokens.includes(nextToken)) {
-        break;
+        stop = true;
       }
 
       const stopPos = outputPrompt.lastIndexOf(stopStr);
       if (stopPos != -1) {
         outputPrompt = outputPrompt.substring(0, stopPos);
-        break;
+        stop = true;
       }
       let tend = performance.now();
       if (step != 0) {
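
Note the control-flow change: instead of breaking as soon as a stop token or the stop string appears, the loop sets stop = true and exits at the top of the next iteration, after one more #forward call, so the final sampled token still passes through the model into the KV cache. Distilled to a minimal sketch, with hypothetical forward/sample/isStop helpers standing in for the real pipeline methods:

// Deferred-stop pattern: run one extra forward() so the last sampled token
// is written into the KV cache before the loop exits.
let stop = false;
for (let step = 0; step < maxGenLen; ++step) {
  const logits = forward(lastToken);  // lastToken enters the cache here
  if (stop) break;                    // exit only after it has been cached
  lastToken = sample(logits);
  if (isStop(lastToken)) stop = true; // defer the break by one iteration
}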
@@ -405,7 +421,7 @@ class LLMChatPipeline {
         callbackUpdateResponse(step, outputPrompt);
       }
     }
-    this.kvCacheLength += tokens.length - 1;
+    this.kvCacheLength += tokens.length;
     this.conversation.messages[this.conversation.messages.length - 1][1] = outputPrompt;
     return outputPrompt;
   }
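
With the deferred stop above, every entry of tokens, including the final sampled token, has been forwarded into the KV cache by the time the loop ends, so the cache now advances by tokens.length. Under the old early break the last token never reached #forward, which is what the previous tokens.length - 1 accounted for.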
