@@ -178,7 +178,7 @@ def forward(
178
178
cos_cached : relax .Expr ,
179
179
sin_cached : relax .Expr ,
180
180
all_seq_len_shape : relax .Expr ,
181
- past_key_value : Optional [ Tuple [relax .Expr ]] = None ,
181
+ past_key_value : Tuple [relax .Expr ],
182
182
attention_mask : Optional [relax .Expr ] = None ,
183
183
) -> Tuple [relax .Expr , Optional [relax .Expr ], Optional [Tuple [relax .Expr ]]]:
184
184
from tvm .relax .op import astype , matmul , maximum , permute_dims , reshape , squeeze
@@ -221,43 +221,43 @@ def forward(
221
221
[kv_states_shape [0 ], kv_seq_len , kv_states_shape [2 ], kv_states_shape [3 ]]
222
222
)
223
223
kv_cache_shape = R .shape ([kv_seq_len , kv_states_shape [2 ], kv_states_shape [3 ]])
224
- if past_key_value is not None :
225
- squeezed_key = nn .emit (squeeze (key_states , axis = 0 ))
226
- squeezed_value = nn .emit (squeeze (value_states , axis = 0 ))
227
- k_cache , v_cache = past_key_value
228
- f_kv_cache_append = relax .extern ("vm.builtin.attention_kv_cache_append" )
229
- k_cache = nn .emit (
230
- relax .Call (
231
- f_kv_cache_append ,
232
- args = [k_cache , squeezed_key ],
233
- sinfo_args = [relax .ObjectStructInfo ()],
234
- )
224
+
225
+ squeezed_key = nn .emit (squeeze (key_states , axis = 0 ))
226
+ squeezed_value = nn .emit (squeeze (value_states , axis = 0 ))
227
+ k_cache , v_cache = past_key_value
228
+ f_kv_cache_append = relax .extern ("vm.builtin.attention_kv_cache_append" )
229
+ k_cache = nn .emit (
230
+ relax .Call (
231
+ f_kv_cache_append ,
232
+ args = [k_cache , squeezed_key ],
233
+ sinfo_args = [relax .ObjectStructInfo ()],
235
234
)
236
- v_cache = nn . emit (
237
- relax . Call (
238
- f_kv_cache_append ,
239
- args = [ v_cache , squeezed_value ] ,
240
- sinfo_args = [ relax . ObjectStructInfo () ],
241
- )
235
+ )
236
+ v_cache = nn . emit (
237
+ relax . Call (
238
+ f_kv_cache_append ,
239
+ args = [ v_cache , squeezed_value ],
240
+ sinfo_args = [ relax . ObjectStructInfo ()],
242
241
)
243
- past_key_value = ( k_cache , v_cache )
244
- f_kv_cache_view = relax . extern ( "vm.builtin.attention_kv_cache_view" )
245
- k_cache = nn . emit (
246
- relax . Call (
247
- f_kv_cache_view ,
248
- args = [ k_cache , kv_cache_shape ] ,
249
- sinfo_args = [ R . Tensor ( kv_cache_shape , kv_states_dtype ) ],
250
- )
242
+ )
243
+ past_key_value = ( k_cache , v_cache )
244
+ f_kv_cache_view = relax . extern ( "vm.builtin.attention_kv_cache_view" )
245
+ k_cache = nn . emit (
246
+ relax . Call (
247
+ f_kv_cache_view ,
248
+ args = [ k_cache , kv_cache_shape ],
249
+ sinfo_args = [ R . Tensor ( kv_cache_shape , kv_states_dtype )],
251
250
)
252
- v_cache = nn . emit (
253
- relax . Call (
254
- f_kv_cache_view ,
255
- args = [ v_cache , kv_cache_shape ] ,
256
- sinfo_args = [ R . Tensor ( kv_cache_shape , kv_states_dtype ) ],
257
- )
251
+ )
252
+ v_cache = nn . emit (
253
+ relax . Call (
254
+ f_kv_cache_view ,
255
+ args = [ v_cache , kv_cache_shape ],
256
+ sinfo_args = [ R . Tensor ( kv_cache_shape , kv_states_dtype )],
258
257
)
259
- key_states = nn .emit (reshape (k_cache , kv_states_shape ))
260
- value_states = nn .emit (reshape (v_cache , kv_states_shape ))
258
+ )
259
+ key_states = nn .emit (reshape (k_cache , kv_states_shape ))
260
+ value_states = nn .emit (reshape (v_cache , kv_states_shape ))
261
261
262
262
query_states = nn .emit (permute_dims (query_states , [0 , 2 , 1 , 3 ]))
263
263
key_states = nn .emit (permute_dims (key_states , [0 , 2 , 1 , 3 ]))
@@ -333,8 +333,8 @@ def forward(
333
333
cos_cached : relax .Expr ,
334
334
sin_cached : relax .Expr ,
335
335
all_seq_len_shape : relax .Expr ,
336
+ past_key_value : Tuple [relax .Expr ],
336
337
attention_mask : Optional [relax .Expr ] = None ,
337
- past_key_value : Optional [Tuple [relax .Expr ]] = None ,
338
338
) -> Tuple [relax .Expr , Optional [Tuple [relax .Expr , relax .Expr ]]]:
339
339
residual = hidden_states
340
340
@@ -402,7 +402,7 @@ def forward(
402
402
cos_cached : relax .Expr ,
403
403
sin_cached : relax .Expr ,
404
404
all_seq_len_shape : relax .Expr ,
405
- past_key_values : Optional [ relax .Expr ] = None ,
405
+ past_key_values : relax .Expr ,
406
406
):
407
407
# retrieve input_ids
408
408
batch_size , seq_length = input_ids .struct_info .shape
@@ -421,11 +421,8 @@ def forward(
421
421
next_decoder_cache = ()
422
422
423
423
for idx , decoder_layer in enumerate (self .layers ):
424
- past_key_value = (
425
- (past_key_values [idx * 2 ], past_key_values [idx * 2 + 1 ])
426
- if past_key_values is not None
427
- else None
428
- )
424
+ assert past_key_values is not None
425
+ past_key_value = (past_key_values [idx * 2 ], past_key_values [idx * 2 + 1 ])
429
426
430
427
hidden_states , key_value_cache = decoder_layer (
431
428
hidden_states ,
@@ -459,7 +456,7 @@ def forward(
459
456
self ,
460
457
input_ids : relax .Expr ,
461
458
all_seq_len_shape : relax .Expr ,
462
- past_key_values : Optional [ List [ relax .Expr ]] = None ,
459
+ past_key_values : relax .Expr ,
463
460
):
464
461
hidden_states , key_value_cache = self .model (
465
462
input_ids = input_ids ,
@@ -543,20 +540,24 @@ def create_decoding_func(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
543
540
bb .update_func (gv , mod [gv ].with_attr ("num_input" , 3 ))
544
541
545
542
546
- def create_encoding_func_without_cache (bb : relax .BlockBuilder , config : LlamaConfig ) -> None :
547
- bsz = 1
548
- seq_len = tvm .tir .Var ("n" , "int64" )
549
-
550
- with bb .function ("encoding_without_cache" ):
551
- model = LlamaForCausalLM (config )
552
- input_ids = nn .Placeholder ((bsz , seq_len ), dtype = "int32" , name = "input_ids" )
553
- all_seq_len_shape = relax .Var ("all_seq_len" , relax .ShapeStructInfo ((seq_len ,)))
543
+ def create_kv_cache_func (bb : relax .BlockBuilder , config : LlamaConfig ) -> None :
544
+ init_shape = relax .ShapeExpr (
545
+ (1 , config .num_attention_heads , config .hidden_size // config .num_attention_heads )
546
+ )
547
+ with bb .function ("create_kv_cache" , []):
554
548
with bb .dataflow ():
555
- logits , _ = model (input_ids , all_seq_len_shape )
556
- params = [input_ids , all_seq_len_shape ] + model .parameters ()
557
- gv = bb .emit_output (logits )
558
- bb .emit_func_output (gv , params )
559
-
560
- mod = bb .get ()
561
- gv = mod .get_global_var ("encoding_without_cache" )
562
- bb .update_func (gv , mod [gv ].with_attr ("num_input" , 2 ))
549
+ zeros = bb .emit (relax .op .zeros (init_shape , "float32" ))
550
+ caches = []
551
+ f_kv_cache_create = relax .extern ("vm.builtin.attention_kv_cache_create" )
552
+ for _ in range (config .num_hidden_layers * 2 ):
553
+ caches .append (
554
+ bb .emit (
555
+ relax .Call (
556
+ f_kv_cache_create ,
557
+ args = [zeros , init_shape , relax .PrimValue (0 )],
558
+ sinfo_args = [relax .ObjectStructInfo ()],
559
+ )
560
+ )
561
+ )
562
+ gv = bb .emit_output (caches )
563
+ bb .emit_func_output (gv )
0 commit comments