Skip to content

Commit d94638d

Browse files
committed
clear kvcache
1 parent c26dcbb commit d94638d

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

src/runner/LLM.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,14 @@ class LLM
609609

610610
final_out = tokenizer->Decode(token_ids);
611611

612+
for (size_t i = 0; i < _attr.axmodel_num; i++)
613+
{
614+
memset(llama_layers[i].layer.get_input(prefill_grpid, "K_cache").pVirAddr, 0, llama_layers[i].layer.get_input(prefill_grpid, "K_cache").nSize);
615+
memset(llama_layers[i].layer.get_input(prefill_grpid, "V_cache").pVirAddr, 0, llama_layers[i].layer.get_input(prefill_grpid, "V_cache").nSize);
616+
memset(llama_layers[i].layer.get_input(decode_grpid, "K_cache").pVirAddr, 0, llama_layers[i].layer.get_input(decode_grpid, "K_cache").nSize);
617+
memset(llama_layers[i].layer.get_input(decode_grpid, "V_cache").pVirAddr, 0, llama_layers[i].layer.get_input(decode_grpid, "V_cache").nSize);
618+
}
619+
612620
return final_out;
613621
}
614622
};

src/runner/ax_model_runner/ax_model_runner.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class ax_runner_base
137137
}
138138
if (map_group_output_tensors.find(name) == map_group_output_tensors.end())
139139
{
140-
throw std::runtime_error("input tensor not found: " + name);
140+
throw std::runtime_error("output tensor not found: " + name);
141141
}
142142
return map_group_output_tensors[name][grpid];
143143
}

0 commit comments

Comments
 (0)