Skip to content

Commit cb76f41

Browse files
authored
KV cache creation as a separate func (#53)
This PR changes the creation of the KV cache so that it becomes a VM function that is invoked before generation.
1 parent 35c42f2 commit cb76f41

File tree

4 files changed

+65
-90
lines changed

4 files changed

+65
-90
lines changed

build.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def get_models(config, model):
8181
if "vicuna" in model or "llama" in model:
8282
bb = relax.BlockBuilder()
8383
llama.create_encoding_func(bb, config)
84-
llama.create_encoding_func_without_cache(bb, config)
8584
llama.create_decoding_func(bb, config)
85+
llama.create_kv_cache_func(bb, config)
8686
mod = bb.get()
8787

8888
for gv in mod.functions:
@@ -121,7 +121,7 @@ def mod_transform_before_build(
121121
mod: tvm.IRModule, model_params: List[tvm.nd.NDArray], args: Dict
122122
) -> tvm.IRModule:
123123
"""First-stage: Legalize ops and trace"""
124-
model_names = ["encoding", "decoding", "encoding_without_cache"]
124+
model_names = ["encoding", "decoding", "create_kv_cache"]
125125

126126
mod = web_llm.transform.GroupQuantize(group_size=32, sym=False)(mod)
127127
mod = web_llm.transform.FuseTransposeMatmul()(mod)

chat.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,21 +145,10 @@ def get_tvm_model(args):
145145
vm = relax.VirtualMachine(ex, device)
146146

147147
class Model:
148-
def new_cache(self):
149-
fcreate_cache = tvm.get_global_func("vm.builtin.attention_kv_cache_create")
150-
self.kv_cache = []
151-
for i in range(64): # num_layer
152-
kv_cache = fcreate_cache(
153-
tvm.nd.empty((1, 32, 128), device=device, dtype="float32"),
154-
tvm.runtime.ShapeTuple([32, 32, 128]),
155-
0
156-
)
157-
self.kv_cache.append(kv_cache)
158148

159149
def __init__(self) -> None:
160-
self.kv_cache = None
161150
self.tot_seq_len = 0
162-
self.new_cache()
151+
self.kv_cache = vm["create_kv_cache"]()
163152

164153
def forward(
165154
self, inputs: torch.Tensor

web/llm_chat.js

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -145,30 +145,16 @@ class LLMChatPipeline {
145145
this.decoding = this.tvm.detachFromCurrentScope(
146146
this.vm.getFunction("decoding")
147147
);
148-
this.encodingWithoutCache = this.tvm.detachFromCurrentScope(
149-
this.vm.getFunction("encoding_without_cache")
150-
);
151148
this.params = this.tvm.detachFromCurrentScope(
152149
this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
153150
);
154-
const fcreateCache = this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_create");
151+
const fcreateCache = this.vm.getFunction("create_kv_cache");
155152
this.fclearKVCaches = this.tvm.detachFromCurrentScope(
156153
this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
157154
);
158155

159156
// use extern config for now
160-
// move to kv generation vm function
161-
const kvList = [];
162-
const kvConfig = config.kvConfig;
163-
for (let i = 0; i < kvConfig.numLayers; ++i) {
164-
const item = fcreateCache(
165-
this.tvm.empty(kvConfig.shape, kvConfig.dtype, this.device),
166-
this.tvm.makeShapeTuple(kvConfig.shape),
167-
this.tvm.scalar(0, "int")
168-
);
169-
kvList.push(item);
170-
}
171-
this.kvCache = this.tvm.detachFromCurrentScope(this.tvm.makeTVMArray(kvList));
157+
this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
172158
// fill with pad token
173159
this.logitsOnCPU = undefined;
174160

@@ -180,7 +166,6 @@ class LLMChatPipeline {
180166
dispose() {
181167
// note: tvm instance is not owned by this class
182168
this.params.dispose();
183-
this.encodingWithoutCache.dispose();
184169
this.decoding.dispose();
185170
this.encoding.dispose();
186171
this.vm.dispose();
@@ -368,7 +353,7 @@ class LLMChatPipeline {
368353
}
369354

370355
async evaluate() {
371-
// run a canonicla evaluateion fo the flow
356+
// run a canonical evaluation of the flow
372357
this.#clearKVCache();
373358
const testPrompt = "The capital of Canada is";
374359
const ids = await this.tokenizer.encodeIds(testPrompt);

web_llm/relax_model/llama.py

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def forward(
178178
cos_cached: relax.Expr,
179179
sin_cached: relax.Expr,
180180
all_seq_len_shape: relax.Expr,
181-
past_key_value: Optional[Tuple[relax.Expr]] = None,
181+
past_key_value: Tuple[relax.Expr],
182182
attention_mask: Optional[relax.Expr] = None,
183183
) -> Tuple[relax.Expr, Optional[relax.Expr], Optional[Tuple[relax.Expr]]]:
184184
from tvm.relax.op import astype, matmul, maximum, permute_dims, reshape, squeeze
@@ -221,43 +221,43 @@ def forward(
221221
[kv_states_shape[0], kv_seq_len, kv_states_shape[2], kv_states_shape[3]]
222222
)
223223
kv_cache_shape = R.shape([kv_seq_len, kv_states_shape[2], kv_states_shape[3]])
224-
if past_key_value is not None:
225-
squeezed_key = nn.emit(squeeze(key_states, axis=0))
226-
squeezed_value = nn.emit(squeeze(value_states, axis=0))
227-
k_cache, v_cache = past_key_value
228-
f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append")
229-
k_cache = nn.emit(
230-
relax.Call(
231-
f_kv_cache_append,
232-
args=[k_cache, squeezed_key],
233-
sinfo_args=[relax.ObjectStructInfo()],
234-
)
224+
225+
squeezed_key = nn.emit(squeeze(key_states, axis=0))
226+
squeezed_value = nn.emit(squeeze(value_states, axis=0))
227+
k_cache, v_cache = past_key_value
228+
f_kv_cache_append = relax.extern("vm.builtin.attention_kv_cache_append")
229+
k_cache = nn.emit(
230+
relax.Call(
231+
f_kv_cache_append,
232+
args=[k_cache, squeezed_key],
233+
sinfo_args=[relax.ObjectStructInfo()],
235234
)
236-
v_cache = nn.emit(
237-
relax.Call(
238-
f_kv_cache_append,
239-
args=[v_cache, squeezed_value],
240-
sinfo_args=[relax.ObjectStructInfo()],
241-
)
235+
)
236+
v_cache = nn.emit(
237+
relax.Call(
238+
f_kv_cache_append,
239+
args=[v_cache, squeezed_value],
240+
sinfo_args=[relax.ObjectStructInfo()],
242241
)
243-
past_key_value = (k_cache, v_cache)
244-
f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
245-
k_cache = nn.emit(
246-
relax.Call(
247-
f_kv_cache_view,
248-
args=[k_cache, kv_cache_shape],
249-
sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)],
250-
)
242+
)
243+
past_key_value = (k_cache, v_cache)
244+
f_kv_cache_view = relax.extern("vm.builtin.attention_kv_cache_view")
245+
k_cache = nn.emit(
246+
relax.Call(
247+
f_kv_cache_view,
248+
args=[k_cache, kv_cache_shape],
249+
sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)],
251250
)
252-
v_cache = nn.emit(
253-
relax.Call(
254-
f_kv_cache_view,
255-
args=[v_cache, kv_cache_shape],
256-
sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)],
257-
)
251+
)
252+
v_cache = nn.emit(
253+
relax.Call(
254+
f_kv_cache_view,
255+
args=[v_cache, kv_cache_shape],
256+
sinfo_args=[R.Tensor(kv_cache_shape, kv_states_dtype)],
258257
)
259-
key_states = nn.emit(reshape(k_cache, kv_states_shape))
260-
value_states = nn.emit(reshape(v_cache, kv_states_shape))
258+
)
259+
key_states = nn.emit(reshape(k_cache, kv_states_shape))
260+
value_states = nn.emit(reshape(v_cache, kv_states_shape))
261261

262262
query_states = nn.emit(permute_dims(query_states, [0, 2, 1, 3]))
263263
key_states = nn.emit(permute_dims(key_states, [0, 2, 1, 3]))
@@ -333,8 +333,8 @@ def forward(
333333
cos_cached: relax.Expr,
334334
sin_cached: relax.Expr,
335335
all_seq_len_shape: relax.Expr,
336+
past_key_value: Tuple[relax.Expr],
336337
attention_mask: Optional[relax.Expr] = None,
337-
past_key_value: Optional[Tuple[relax.Expr]] = None,
338338
) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]:
339339
residual = hidden_states
340340

@@ -402,7 +402,7 @@ def forward(
402402
cos_cached: relax.Expr,
403403
sin_cached: relax.Expr,
404404
all_seq_len_shape: relax.Expr,
405-
past_key_values: Optional[relax.Expr] = None,
405+
past_key_values: relax.Expr,
406406
):
407407
# retrieve input_ids
408408
batch_size, seq_length = input_ids.struct_info.shape
@@ -421,11 +421,8 @@ def forward(
421421
next_decoder_cache = ()
422422

423423
for idx, decoder_layer in enumerate(self.layers):
424-
past_key_value = (
425-
(past_key_values[idx * 2], past_key_values[idx * 2 + 1])
426-
if past_key_values is not None
427-
else None
428-
)
424+
assert past_key_values is not None
425+
past_key_value = (past_key_values[idx * 2], past_key_values[idx * 2 + 1])
429426

430427
hidden_states, key_value_cache = decoder_layer(
431428
hidden_states,
@@ -459,7 +456,7 @@ def forward(
459456
self,
460457
input_ids: relax.Expr,
461458
all_seq_len_shape: relax.Expr,
462-
past_key_values: Optional[List[relax.Expr]] = None,
459+
past_key_values: relax.Expr,
463460
):
464461
hidden_states, key_value_cache = self.model(
465462
input_ids=input_ids,
@@ -543,20 +540,24 @@ def create_decoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
543540
bb.update_func(gv, mod[gv].with_attr("num_input", 3))
544541

545542

546-
def create_encoding_func_without_cache(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
547-
bsz = 1
548-
seq_len = tvm.tir.Var("n", "int64")
549-
550-
with bb.function("encoding_without_cache"):
551-
model = LlamaForCausalLM(config)
552-
input_ids = nn.Placeholder((bsz, seq_len), dtype="int32", name="input_ids")
553-
all_seq_len_shape = relax.Var("all_seq_len", relax.ShapeStructInfo((seq_len,)))
543+
def create_kv_cache_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
544+
init_shape = relax.ShapeExpr(
545+
(1, config.num_attention_heads, config.hidden_size // config.num_attention_heads)
546+
)
547+
with bb.function("create_kv_cache", []):
554548
with bb.dataflow():
555-
logits, _ = model(input_ids, all_seq_len_shape)
556-
params = [input_ids, all_seq_len_shape] + model.parameters()
557-
gv = bb.emit_output(logits)
558-
bb.emit_func_output(gv, params)
559-
560-
mod = bb.get()
561-
gv = mod.get_global_var("encoding_without_cache")
562-
bb.update_func(gv, mod[gv].with_attr("num_input", 2))
549+
zeros = bb.emit(relax.op.zeros(init_shape, "float32"))
550+
caches = []
551+
f_kv_cache_create = relax.extern("vm.builtin.attention_kv_cache_create")
552+
for _ in range(config.num_hidden_layers * 2):
553+
caches.append(
554+
bb.emit(
555+
relax.Call(
556+
f_kv_cache_create,
557+
args=[zeros, init_shape, relax.PrimValue(0)],
558+
sinfo_args=[relax.ObjectStructInfo()],
559+
)
560+
)
561+
)
562+
gv = bb.emit_output(caches)
563+
bb.emit_func_output(gv)

0 commit comments

Comments
 (0)