|
11 | 11 | # include <curl/curl.h> |
12 | 12 | #endif |
13 | 13 |
|
| 14 | +#include <signal.h> |
| 15 | + |
14 | 16 | #include <climits> |
15 | 17 | #include <cstdarg> |
16 | 18 | #include <cstdio> |
|
25 | 27 | #include "json.hpp" |
26 | 28 | #include "llama-cpp.h" |
27 | 29 |
|
| 30 | +static bool sigint = false; |
| 31 | + |
| 32 | +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) |
| 33 | +static void sigint_handler(int signo) { |
| 34 | + if (signo == SIGINT) { |
| 35 | + sigint = true; |
| 36 | + } |
| 37 | +} |
| 38 | +#endif |
| 39 | + |
28 | 40 | GGML_ATTRIBUTE_FORMAT(1, 2) |
29 | 41 | static std::string fmt(const char * fmt, ...) { |
30 | 42 | va_list ap; |
@@ -801,7 +813,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str |
801 | 813 |
|
802 | 814 | static int read_user_input(std::string & user) { |
803 | 815 | std::getline(std::cin, user); |
804 | | - return user.empty(); // Should have data in happy path |
| 816 | + if (sigint || std::cin.eof()) { |
| 817 | + printf("\n"); |
| 818 | + return 1; |
| 819 | + } |
| 820 | + |
| 821 | + if (user == "/bye") { |
| 822 | + return 1; |
| 823 | + } |
| 824 | + |
| 825 | + if (user.empty()) { |
| 826 | + return 2; |
| 827 | + } |
| 828 | + |
| 829 | + return 0; // Should have data in happy path |
805 | 830 | } |
806 | 831 |
|
807 | 832 | // Function to generate a response based on the prompt |
@@ -868,15 +893,34 @@ static bool is_stdout_a_terminal() { |
868 | 893 | #endif |
869 | 894 | } |
870 | 895 |
|
871 | | -// Function to tokenize the prompt |
| 896 | +// Function to handle user input |
| 897 | +static int get_user_input(std::string & user_input, const std::string & user) { |
| 898 | + while (true) { |
| 899 | + const int ret = handle_user_input(user_input, user); |
| 900 | + if (ret == 1) { |
| 901 | + return 1; |
| 902 | + } |
| 903 | + |
| 904 | + if (ret == 2) { |
| 905 | + continue; |
| 906 | + } |
| 907 | + |
| 908 | + break; |
| 909 | + } |
| 910 | + |
| 911 | + return 0; |
| 912 | +} |
| 913 | + |
| 914 | +// Main chat loop function |
872 | 915 | static int chat_loop(LlamaData & llama_data, const std::string & user) { |
873 | 916 | int prev_len = 0; |
874 | 917 | llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); |
875 | 918 | static const bool stdout_a_terminal = is_stdout_a_terminal(); |
876 | 919 | while (true) { |
877 | 920 | // Get user input |
878 | 921 | std::string user_input; |
879 | | - while (handle_user_input(user_input, user)) { |
| 922 | + if (get_user_input(user_input, user) == 1) { |
| 923 | + return 0; |
880 | 924 | } |
881 | 925 |
|
882 | 926 | add_message("user", user.empty() ? user_input : user, llama_data); |
@@ -917,7 +961,23 @@ static std::string read_pipe_data() { |
917 | 961 | return result.str(); |
918 | 962 | } |
919 | 963 |
|
| 964 | +static void ctrl_c_handling() { |
| 965 | +#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) |
| 966 | + struct sigaction sigint_action; |
| 967 | + sigint_action.sa_handler = sigint_handler; |
| 968 | + sigemptyset(&sigint_action.sa_mask); |
| 969 | + sigint_action.sa_flags = 0; |
| 970 | + sigaction(SIGINT, &sigint_action, NULL); |
| 971 | +#elif defined(_WIN32) |
| 972 | + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { |
| 973 | + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; |
| 974 | + }; |
| 975 | + SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); |
| 976 | +#endif |
| 977 | +} |
| 978 | + |
920 | 979 | int main(int argc, const char ** argv) { |
| 980 | + ctrl_c_handling(); |
921 | 981 | Opt opt; |
922 | 982 | const int ret = opt.init(argc, argv); |
923 | 983 | if (ret == 2) { |
|
0 commit comments