|
26 | 26 | #include "common.h" |
27 | 27 | #include "json.hpp" |
28 | 28 | #include "llama-cpp.h" |
| 29 | +#include "linenoise.h" |
29 | 30 |
|
30 | 31 | #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) |
31 | 32 | [[noreturn]] static void sigint_handler(int) { |
@@ -807,21 +808,22 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str |
807 | 808 | batch = llama_batch_get_one(&new_token_id, 1); |
808 | 809 | } |
809 | 810 |
|
| 811 | + printf("\033[0m"); |
810 | 812 | return 0; |
811 | 813 | } |
812 | 814 |
|
813 | | -static int read_user_input(std::string & user) { |
814 | | - std::getline(std::cin, user); |
| 815 | +static int read_user_input(std::string & user_input) { |
| 816 | + std::getline(std::cin, user_input); |
815 | 817 | if (std::cin.eof()) { |
816 | 818 | printf("\n"); |
817 | 819 | return 1; |
818 | 820 | } |
819 | 821 |
|
820 | | - if (user == "/bye") { |
| 822 | + if (user_input == "/bye") { |
821 | 823 | return 1; |
822 | 824 | } |
823 | 825 |
|
824 | | - if (user.empty()) { |
| 826 | + if (user_input.empty()) { |
825 | 827 | return 2; |
826 | 828 | } |
827 | 829 |
|
@@ -858,18 +860,47 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const |
858 | 860 | return 0; |
859 | 861 | } |
860 | 862 |
|
| 863 | +struct c_str { |
| 864 | + const char* data = nullptr; |
| 865 | + |
| 866 | + ~c_str() { |
| 867 | + free(const_cast<char*>(data)); |
| 868 | + } |
| 869 | +}; |
| 870 | + |
861 | 871 | // Helper function to handle user input |
862 | 872 | static int handle_user_input(std::string & user_input, const std::string & user) { |
863 | 873 | if (!user.empty()) { |
864 | 874 | user_input = user; |
865 | 875 | return 0; // No need for interactive input |
866 | 876 | } |
867 | 877 |
|
| 878 | + static const char* prompt_prefix = "> "; |
| 879 | +#ifdef WIN32 |
868 | 880 | printf( |
869 | 881 | "\r%*s" |
870 | | - "\r\033[32m> \033[0m", |
871 | | - get_terminal_width(), " "); |
| 882 | + "\r\033[0m%s", |
| 883 | + get_terminal_width(), " ", prompt_prefix); |
872 | 884 | return read_user_input(user_input); // Returns true if input ends the loop |
| 885 | +#else |
| 886 | + c_str line; |
| 887 | + line.data = linenoise(prompt_prefix); |
| 888 | + if (!line.data) { |
| 889 | + return 1; |
| 890 | + } |
| 891 | + |
| 892 | + user_input = line.data; |
| 893 | + if (user_input == "/bye") { |
| 894 | + return 1; |
| 895 | + } |
| 896 | + |
| 897 | + if (user_input.empty()) { |
| 898 | + return 2; |
| 899 | + } |
| 900 | + |
| 901 | + linenoiseHistoryAdd(line.data); |
| 902 | + return 0; |
| 903 | +#endif |
873 | 904 | } |
874 | 905 |
|
875 | 906 | static bool is_stdin_a_terminal() { |
@@ -961,13 +992,7 @@ static std::string read_pipe_data() { |
961 | 992 | } |
962 | 993 |
|
963 | 994 | static void ctrl_c_handling() { |
964 | | -#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) |
965 | | - struct sigaction sigint_action; |
966 | | - sigint_action.sa_handler = sigint_handler; |
967 | | - sigemptyset(&sigint_action.sa_mask); |
968 | | - sigint_action.sa_flags = 0; |
969 | | - sigaction(SIGINT, &sigint_action, NULL); |
970 | | -#elif defined(_WIN32) |
| 995 | +#if defined(_WIN32) |
971 | 996 | auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { |
972 | 997 | return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; |
973 | 998 | }; |
|
0 commit comments