Skip to content

Commit d7a4f3e

Browse files
committed
main : add special commands
1 parent c02e5ab commit d7a4f3e

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

examples/main/main.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
#pragma warning(disable: 4244 4267) // possible loss of data
3232
#endif
3333

34+
static const std::string CMD_READFILE = "/readfile";
35+
static const std::string CMD_SAVE_SESS = "/savesess";
36+
static const std::string CMD_LOAD_SESS = "/loadsess";
37+
3438
static llama_context ** g_ctx;
3539
static llama_model ** g_model;
3640
static common_sampler ** g_smpl;
@@ -851,6 +855,43 @@ int main(int argc, char ** argv) {
851855

852856
LOG_DBG("buffer: '%s'\n", buffer.c_str());
853857

858+
// check for special commands
859+
if (buffer.rfind(CMD_READFILE, 0) == 0) {
860+
const std::string filename = string_strip(buffer.substr(CMD_READFILE.length()));
861+
LOG_DBG("reading file: '%s'\n", filename.c_str());
862+
std::ifstream text_file(filename);
863+
if (!text_file) {
864+
LOG("failed to open file '%s'\n", filename.c_str());
865+
continue;
866+
}
867+
std::stringstream tmp;
868+
tmp << text_file.rdbuf();
869+
buffer = tmp.str();
870+
LOG("%s\n", buffer.c_str());
871+
} else if (buffer.rfind(CMD_SAVE_SESS, 0) == 0) {
872+
const std::string filename = string_strip(buffer.substr(CMD_SAVE_SESS.length()));
873+
LOG("save session file: '%s'\n", filename.c_str());
874+
size_t res = llama_state_save_file(ctx, filename.c_str(), embd_inp.data(), n_past);
875+
if (res == 0) {
876+
LOG("failed to save session file '%s'\n", filename.c_str());
877+
}
878+
continue;
879+
} else if (buffer.rfind(CMD_LOAD_SESS, 0) == 0) {
880+
const std::string filename = string_strip(buffer.substr(CMD_LOAD_SESS.length()));
881+
LOG("load session file: '%s'\n", filename.c_str());
882+
std::vector<llama_token> sess_tokens;
883+
sess_tokens.resize(n_ctx);
884+
size_t n_loaded_tokens;
885+
size_t res = llama_state_load_file(ctx, filename.c_str(), sess_tokens.data(), sess_tokens.size(), &n_loaded_tokens);
886+
if (res == 0) {
887+
LOG("failed to load session file '%s'\n", filename.c_str());
888+
} else {
889+
n_past = n_loaded_tokens;
890+
LOG("loaded %zu tokens from session file '%s'\n", n_loaded_tokens, filename.c_str());
891+
}
892+
continue;
893+
}
894+
854895
const size_t original_size = embd_inp.size();
855896

856897
if (params.escape) {

0 commit comments

Comments
 (0)