|
11 | 11 | # include <curl/curl.h> |
12 | 12 | #endif |
13 | 13 |
|
| 14 | +#include <signal.h> |
| 15 | + |
14 | 16 | #include <climits> |
15 | 17 | #include <cstdarg> |
16 | 18 | #include <cstdio> |
|
25 | 27 | #include "json.hpp" |
26 | 28 | #include "llama-cpp.h" |
27 | 29 |
|
| 30 | +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) |
| 31 | +static void sigint_handler(int signo) { |
| 32 | + printf("\n"); |
| 33 | + exit(0); // not ideal, but it's the only way to guarantee exit in all cases |
| 34 | +} |
| 35 | +#endif |
| 36 | + |
28 | 37 | GGML_ATTRIBUTE_FORMAT(1, 2) |
29 | 38 | static std::string fmt(const char * fmt, ...) { |
30 | 39 | va_list ap; |
@@ -801,7 +810,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str |
801 | 810 |
|
802 | 811 | static int read_user_input(std::string & user) { |
803 | 812 | std::getline(std::cin, user); |
804 | | - return user.empty(); // Should have data in happy path |
| 813 | + if (std::cin.eof()) { |
| 814 | + printf("\n"); |
| 815 | + return 1; |
| 816 | + } |
| 817 | + |
| 818 | + if (user == "/bye") { |
| 819 | + return 1; |
| 820 | + } |
| 821 | + |
| 822 | + if (user.empty()) { |
| 823 | + return 2; |
| 824 | + } |
| 825 | + |
| 826 | + return 0; // Should have data in happy path |
805 | 827 | } |
806 | 828 |
|
807 | 829 | // Function to generate a response based on the prompt |
@@ -868,15 +890,34 @@ static bool is_stdout_a_terminal() { |
868 | 890 | #endif |
869 | 891 | } |
870 | 892 |
|
871 | | -// Function to tokenize the prompt |
| 893 | +// Function to handle user input |
| 894 | +static int get_user_input(std::string & user_input, const std::string & user) { |
| 895 | + while (true) { |
| 896 | + const int ret = handle_user_input(user_input, user); |
| 897 | + if (ret == 1) { |
| 898 | + return 1; |
| 899 | + } |
| 900 | + |
| 901 | + if (ret == 2) { |
| 902 | + continue; |
| 903 | + } |
| 904 | + |
| 905 | + break; |
| 906 | + } |
| 907 | + |
| 908 | + return 0; |
| 909 | +} |
| 910 | + |
| 911 | +// Main chat loop function |
872 | 912 | static int chat_loop(LlamaData & llama_data, const std::string & user) { |
873 | 913 | int prev_len = 0; |
874 | 914 | llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); |
875 | 915 | static const bool stdout_a_terminal = is_stdout_a_terminal(); |
876 | 916 | while (true) { |
877 | 917 | // Get user input |
878 | 918 | std::string user_input; |
879 | | - while (handle_user_input(user_input, user)) { |
| 919 | + if (get_user_input(user_input, user) == 1) { |
| 920 | + return 0; |
880 | 921 | } |
881 | 922 |
|
882 | 923 | add_message("user", user.empty() ? user_input : user, llama_data); |
@@ -917,7 +958,23 @@ static std::string read_pipe_data() { |
917 | 958 | return result.str(); |
918 | 959 | } |
919 | 960 |
|
| 961 | +static void ctrl_c_handling() { |
| 962 | +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) |
| 963 | + struct sigaction sigint_action; |
| 964 | + sigint_action.sa_handler = sigint_handler; |
| 965 | + sigemptyset(&sigint_action.sa_mask); |
| 966 | + sigint_action.sa_flags = 0; |
| 967 | + sigaction(SIGINT, &sigint_action, NULL); |
| 968 | +#elif defined(_WIN32) |
| 969 | + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { |
| 970 | + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; |
| 971 | + }; |
| 972 | + SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); |
| 973 | +#endif |
| 974 | +} |
| 975 | + |
920 | 976 | int main(int argc, const char ** argv) { |
| 977 | + ctrl_c_handling(); |
921 | 978 | Opt opt; |
922 | 979 | const int ret = opt.init(argc, argv); |
923 | 980 | if (ret == 2) { |
|
0 commit comments