|
#include "error.h"

#include <algorithm>
#include <cstdio>
#include <filesystem>
#include <system_error>
5 | 6 |
|
6 | 7 | namespace agent_cpp { |
7 | 8 |
|
@@ -144,4 +145,65 @@ Agent::run_loop(std::vector<common_chat_msg>& messages, |
144 | 145 | } |
145 | 146 | } |
146 | 147 |
|
| 148 | +std::vector<llama_token> |
| 149 | +Agent::build_prompt_tokens() |
| 150 | +{ |
| 151 | + if (!model) { |
| 152 | + return {}; |
| 153 | + } |
| 154 | + |
| 155 | + std::vector<common_chat_msg> system_messages; |
| 156 | + if (!instructions.empty()) { |
| 157 | + common_chat_msg system_msg; |
| 158 | + system_msg.role = "system"; |
| 159 | + system_msg.content = instructions; |
| 160 | + system_messages.push_back(system_msg); |
| 161 | + } |
| 162 | + |
| 163 | + std::vector<common_chat_tool> tool_definitions = get_tool_definitions(); |
| 164 | + |
| 165 | + common_chat_templates_inputs inputs; |
| 166 | + inputs.messages = system_messages; |
| 167 | + inputs.tools = tool_definitions; |
| 168 | + inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; |
| 169 | + inputs.add_generation_prompt = false; |
| 170 | + inputs.enable_thinking = false; |
| 171 | + |
| 172 | + auto params = common_chat_templates_apply(model->get_templates(), inputs); |
| 173 | + |
| 174 | + return model->tokenize(params.prompt); |
| 175 | +} |
| 176 | + |
| 177 | +bool |
| 178 | +Agent::load_or_create_cache(const std::string& cache_path) |
| 179 | +{ |
| 180 | + if (!model) { |
| 181 | + return false; |
| 182 | + } |
| 183 | + |
| 184 | + if (std::filesystem::exists(cache_path)) { |
| 185 | + auto cached_tokens = model->load_cache(cache_path); |
| 186 | + if (!cached_tokens.empty()) { |
| 187 | + printf("Loaded prompt cache from '%s' (%zu tokens)\n", |
| 188 | + cache_path.c_str(), |
| 189 | + cached_tokens.size()); |
| 190 | + return true; |
| 191 | + } |
| 192 | + } |
| 193 | + |
| 194 | + auto prompt_tokens = build_prompt_tokens(); |
| 195 | + if (prompt_tokens.empty()) { |
| 196 | + return true; |
| 197 | + } |
| 198 | + |
| 199 | + printf("Creating prompt cache at '%s' (%zu tokens)\n", |
| 200 | + cache_path.c_str(), |
| 201 | + prompt_tokens.size()); |
| 202 | + |
| 203 | + // warms the KV cache |
| 204 | + model->generate_from_tokens(prompt_tokens); |
| 205 | + |
| 206 | + return model->save_cache(cache_path); |
| 207 | +} |
| 208 | + |
147 | 209 | } // namespace agent_cpp |
0 commit comments