Skip to content

Commit 35c42f2

Browse files
authored
Enable KV cache reuse across conversations (#51)
Currently, when a user enters an input, the behavior is to clear the KV cache and collect all previous conversation history as the input prompt to the model. This becomes extremely slow when the history is large. Clearly, there is redundant KV computation for previous conversations: we can reuse the KV cache from the last round of conversation to avoid recomputing it. This PR enables that optimization.
1 parent 3d62bc9 commit 35c42f2

File tree

8 files changed

+501
-56
lines changed

8 files changed

+501
-56
lines changed

build.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def get_models(config, model):
8888
for gv in mod.functions:
8989
func = mod[gv]
9090
if isinstance(func, relax.Function):
91-
mod[gv] = func.with_attr("tir_var_upper_bound", {"n": config.max_sequence_length})
91+
mod[gv] = func.with_attr("tir_var_upper_bound", {"n": config.max_sequence_length, "m": config.max_sequence_length})
9292

9393
return mod
9494
else:

chat.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def _parse_args():
1717
args.add_argument("--debug-dump", action="store_true", default=False)
1818
args.add_argument("--artifact-path", type=str, default="dist")
1919
args.add_argument("--model", type=str, default="vicuna-7b-v1")
20-
args.add_argument("--max-gen-len", type=int, default=128)
20+
args.add_argument("--max-gen-len", type=int, default=2048)
2121
args.add_argument("--run-torch-model", action="store_true", default=False)
2222
parsed = args.parse_args()
2323
parsed.model_path = os.path.join(parsed.artifact_path, "models", parsed.model)
@@ -46,9 +46,11 @@ def generate(
4646
top_p: float = 0.95,
4747
stream_interval: int = 2,
4848
stop_str: str = None,
49+
add_bos = True,
4950
):
5051
prompt_tokens = self.tokenizer.encode(prompt)
51-
52+
if not add_bos:
53+
prompt_tokens = prompt_tokens[1:]
5254
total_len = max_gen_len + len(prompt_tokens)
5355
tokens = torch.full((1, total_len), self.tokenizer.pad_token_id).to(
5456
torch.int32
@@ -57,9 +59,9 @@ def generate(
5759
start_pos = len(prompt_tokens)
5860
for cur_pos in range(start_pos, total_len):
5961
if cur_pos == start_pos:
60-
logits = self.model(tokens[:, :cur_pos], cur_pos, clear_cache=True)
62+
logits = self.model(tokens[:, :cur_pos])
6163
else:
62-
logits = self.model(tokens[:, cur_pos - 1 : cur_pos], cur_pos)
64+
logits = self.model(tokens[:, cur_pos - 1 : cur_pos])
6365
logits = logits[:, -1, :]
6466
if temperature > 0:
6567
probs = torch.softmax(logits / temperature, dim=-1)
@@ -102,6 +104,7 @@ def chat(model_wrapper, args):
102104

103105
# Chat
104106
conv = conv_templates["vicuna_v1.1"].copy()
107+
add_bos = True
105108
while True:
106109
try:
107110
inp = input(f"{conv.roles[0]}: ")
@@ -113,14 +116,14 @@ def chat(model_wrapper, args):
113116

114117
conv.append_message(conv.roles[0], inp)
115118
conv.append_message(conv.roles[1], None)
116-
prompt = conv.get_prompt()
117-
119+
prompt = conv.get_prompt_unprocessed()
118120
print(f"{conv.roles[1]}: ", end="", flush=True)
119121
pre = 0
120122
for outputs in model_wrapper.generate(
121123
prompt,
122124
args.max_gen_len,
123125
stop_str=conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2,
126+
add_bos = add_bos,
124127
):
125128
outputs = outputs[len(prompt) + 1 :].strip()
126129
outputs = outputs.split(" ")
@@ -131,6 +134,7 @@ def chat(model_wrapper, args):
131134
print(" ".join(outputs[pre:]), flush=True)
132135

133136
conv.messages[-1][-1] = " ".join(outputs)
137+
add_bos = False
134138
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
135139

136140

@@ -154,15 +158,15 @@ def new_cache(self):
154158

155159
def __init__(self) -> None:
156160
self.kv_cache = None
161+
self.tot_seq_len = 0
157162
self.new_cache()
158163

159164
def forward(
160-
self, inputs: torch.Tensor, cur_pos: int, clear_cache: bool = False
165+
self, inputs: torch.Tensor
161166
) -> torch.Tensor:
162-
if clear_cache:
163-
self.new_cache()
164167
inputs = tvm.nd.array(inputs.numpy(), device=device)
165-
seq_len_shape = tvm.runtime.ShapeTuple([cur_pos])
168+
self.tot_seq_len+=inputs.shape[1]
169+
seq_len_shape = tvm.runtime.ShapeTuple([self.tot_seq_len])
166170
if inputs.shape[1] > 1:
167171
logits, kv_cache = vm["encoding"](
168172
inputs, seq_len_shape, self.kv_cache, const_params

web/gh-page-config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"dtype": "float32"
66
},
77
"wasmUrl": "dist/vicuna-7b-v1/vicuna-7b-v1_webgpu.wasm",
8-
"cacheUrl": "https://huggingface.co/mlc-ai/web-lm/resolve/main/vicuna-0b/",
8+
"cacheUrl": "https://huggingface.co/mlc-ai/web-lm/resolve/main/vicuna-7b-v1/",
99
"tokenizer": "dist/vicuna-7b-v1/tokenizer.model",
1010
"maxGenLength": 512,
1111
"meanGenLength": 128,

web/llm_chat.js

Lines changed: 112 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,57 @@ class Conversation {
3636
return ret;
3737
}
3838

39+
/**
40+
* Get prompt arrays that has not been fed as input
41+
*
42+
* @returns The prompt array.
43+
*/
44+
getPromptArrayUnproccessed() {
45+
if (this.seps.length == 0) {
46+
throw Error("Need seps to work")
47+
}
48+
if (this.messages.length < 3) {
49+
throw Error("needs to call getLastPromptArray for the first message");
50+
}
51+
let ret = [this.seps[this.seps.length - 1]];
52+
for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
53+
const item = this.messages[i];
54+
const role = item[0];
55+
const message = item[1];
56+
if (message !== undefined && message != "") {
57+
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
58+
} else {
59+
ret.push(role + ":");
60+
}
61+
}
62+
return ret;
63+
64+
}
65+
66+
/**
67+
* Get last prompt array with prefix as system.
68+
*
69+
* @returns The prompt array.
70+
*/
71+
getLastPromptArray() {
72+
if (this.seps.length == 0) {
73+
throw Error("Need seps to work")
74+
}
75+
let ret = [this.system + this.seps[0]];
76+
77+
for (let i = this.messages.length - 2; i < this.messages.length; ++i) {
78+
const item = this.messages[i];
79+
const role = item[0];
80+
const message = item[1];
81+
if (message !== undefined && message != "") {
82+
ret.push(role + ": " + message + this.seps[i % this.seps.length]);
83+
} else {
84+
ret.push(role + ":");
85+
}
86+
}
87+
return ret;
88+
}
89+
3990
reset() {
4091
this.messages = [];
4192
}
@@ -52,12 +103,12 @@ class Conversation {
52103
function defaultConversation(maxWindowLength = 512) {
53104
return new Conversation({
54105
system: "A chat between a curious user and an artificial intelligence assistant. " +
55-
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
106+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
56107
roles: ["USER", "ASSISTANT"],
57108
maxWindowLength: maxWindowLength,
58109
messages: [],
59110
offset: 0,
60-
seps:[" ", "</s>"],
111+
seps: [" ", "</s>"],
61112
});
62113
};
63114

@@ -120,6 +171,9 @@ class LLMChatPipeline {
120171
this.kvCache = this.tvm.detachFromCurrentScope(this.tvm.makeTVMArray(kvList));
121172
// fill with pad token
122173
this.logitsOnCPU = undefined;
174+
175+
this.kvCacheLength = 0;
176+
this.clearCache = true
123177
}
124178

125179

@@ -167,7 +221,7 @@ class LLMChatPipeline {
167221
this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
168222
);
169223
} else {
170-
if(logits.shape[0] != this.logitsOnCPU.shape[0]) {
224+
if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
171225
throw Error("We expect the size of logits to remain unchanged");
172226
}
173227
}
@@ -183,35 +237,56 @@ class LLMChatPipeline {
183237
}
184238

185239
async getInputTokens() {
186-
const tokens = [this.bosTokenId];
187-
const prompts = this.conversation.getPromptArray();
240+
let tokens = [this.bosTokenId];
241+
let prompts = ""
242+
if (this.conversation.messages.length <= 2) {
243+
prompts = this.conversation.getPromptArray();
244+
} else {
245+
tokens.pop();
246+
prompts = this.conversation.getPromptArrayUnproccessed();
247+
}
188248
tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
189-
190249
let ctxLength = tokens.length;
191-
const context = [];
250+
let context = [];
251+
let need_shift_window = false;
192252
for (let i = prompts.length - 1; i > 0; --i) {
193253
const encoded = this.tokenizer.encodeIds(prompts[i]);
194254
ctxLength += encoded.length;
195-
if (ctxLength + this.meanGenLength >= this.maxWindowLength && i + 2 < prompts.length) {
196-
this.logger("Shift window at " + i);
255+
if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
256+
need_shift_window = true;
197257
break;
198258
}
199259
context.unshift(encoded);
200260
}
201-
const followMessage = [];
202-
for (const ctx of context) {
203-
followMessage.push(...ctx);
261+
if (!need_shift_window) {
262+
for (const ctx of context) {
263+
tokens.push(...ctx);
264+
}
265+
return tokens;
204266
}
205-
206-
if (followMessage.length + tokens.length + this.meanGenLength >= this.maxWindowLength) {
207-
const maxMsgLen = this.maxWindowLength - tokens.length - this.meanGenLength;
208-
if (maxMsgLen < this.meanGenLength) {
209-
throw Error("Too small window config tokens.length=" + tokens.length);
267+
// need shift window and re-encode
268+
this.logger("need shift window")
269+
this.kvCacheLength = 0;
270+
this.clearCache = true;
271+
// abandon all tokens we collected
272+
tokens = [this.bosTokenId]
273+
let all_prompts = this.conversation.getPromptArray();
274+
tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
275+
context = [];
276+
ctxLength = tokens.length;
277+
//only keep 10% of the window context
278+
const fill_factor = 0.1
279+
for (let i = all_prompts.length - 1; i > 0; --i) {
280+
const encoded = this.tokenizer.encodeIds(all_prompts[i]);
281+
ctxLength += encoded.length;
282+
if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
283+
break;
210284
}
211-
this.logger("Slice message " + followMessage.length + " to " + maxMsgLen);
212-
followMessage = followMessage.slice(followMessage.length - maxMsgLen);
285+
context.unshift(encoded);
286+
}
287+
for (const ctx of context) {
288+
tokens.push(...ctx);
213289
}
214-
tokens.push(...followMessage);
215290
if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
216291
throw Error("Exceed max window length curr=" + tokens.length);
217292
}
@@ -235,16 +310,18 @@ class LLMChatPipeline {
235310
const inputTokenLength = tokens.length;
236311

237312
var outputPrompt = "";
238-
this.#clearKVCache();
313+
if (this.clearCache) {
314+
this.#clearKVCache();
315+
this.clearCache = false;
316+
}
239317
const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
240318
if (maxGenLen < this.meanGenLength) {
241319
throw Error("Too small window size config");
242320
}
243-
244-
for (let step = 0; step < maxGenLen; ++step) {
321+
let step = 0;
322+
for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
245323
this.tvm.beginScope();
246324
var inputData;
247-
248325
let tstart = performance.now();
249326
if (step == 0) {
250327
inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
@@ -254,7 +331,7 @@ class LLMChatPipeline {
254331
inputData.copyFrom(tokens.slice(tokens.length - 1));
255332
}
256333
const logits = this.tvm.detachFromCurrentScope(
257-
this.#forward(inputData, inputTokenLength + step)
334+
this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
258335
);
259336
this.tvm.endScope();
260337

@@ -285,6 +362,7 @@ class LLMChatPipeline {
285362
callbackUpdateResponse(step, outputPrompt);
286363
}
287364
}
365+
this.kvCacheLength += tokens.length - 1;
288366
this.conversation.messages[this.conversation.messages.length - 1][1] = outputPrompt;
289367
return outputPrompt;
290368
}
@@ -358,12 +436,12 @@ class LLMChatInstance {
358436
this.logger = console.log;
359437
this.debugTest = false;
360438
}
361-
/**
362-
* Initialize TVM
363-
* @param wasmUrl URL to wasm source.
364-
* @param cacheUrl URL to NDArray cache.
365-
* @param logger Custom logger.
366-
*/
439+
/**
440+
* Initialize TVM
441+
* @param wasmUrl URL to wasm source.
442+
* @param cacheUrl URL to NDArray cache.
443+
* @param logger Custom logger.
444+
*/
367445
async #asyncInitTVM(wasmUrl, cacheUrl) {
368446
if (this.tvm !== undefined) {
369447
return;
@@ -395,7 +473,7 @@ class LLMChatInstance {
395473
this.reset();
396474
throw Error("This browser env do not support WebGPU");
397475
}
398-
} catch(err) {
476+
} catch (err) {
399477
this.appendMessage("error", "Find an error initializing the WebGPU device " + err.toString());
400478
console.log(err.stack);
401479
this.reset();
@@ -444,7 +522,7 @@ class LLMChatInstance {
444522
// initialize UX and tokenizer
445523
const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
446524
this.pipeline = this.tvm.withNewScope(() => {
447-
return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
525+
return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
448526
});
449527
await this.pipeline.asyncLoadWebGPUPiplines();
450528
this.updateLastMessage("init", "All initialization finished.");
@@ -521,7 +599,7 @@ class LLMChatInstance {
521599

522600
try {
523601
await this.asyncInit();
524-
} catch(err) {
602+
} catch (err) {
525603
this.appendMessage("error", "Init error, " + err.toString());
526604
console.log(err.stack);
527605
this.reset();

web/local-config.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
"wasmUrl": "dist/vicuna-7b-v1/vicuna-7b-v1_webgpu.wasm",
88
"cacheUrl": "vicuna-7b-v1-params/",
99
"tokenizer": "dist/vicuna-7b-v1/tokenizer.model",
10-
"maxGenLength": 512,
11-
"meanGenLength": 128,
12-
"maxWindowLength": 1024
10+
"maxGenLength": 1024,
11+
"meanGenLength": 256,
12+
"maxWindowLength": 2048
1313
}

0 commit comments

Comments
 (0)