diff --git a/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_decode.py b/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_decode.py index a79082da..559a375b 100644 --- a/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_decode.py +++ b/tico/quantization/wrapq/examples/llama/quantize_decoder_layer_decode.py @@ -183,13 +183,11 @@ def make_random_decode_batch(): ( x_ex, # hidden_states mask_ex, # attention_mask - None, # position_ids (unused) - (past_k_ex, past_v_ex), # past_key_value - False, # output_attentions - True, # use_cache - None, # cache_position (unused) - (cos_ex, sin_ex), # position_embeddings ), + { + "past_key_value": (past_k_ex, past_v_ex), + "position_embeddings": (cos_ex, sin_ex), + }, ) cm.save(save_path)