|
25 | 25 |
|
26 | 26 | #include "common.h" |
27 | 27 | #include "json.hpp" |
| 28 | +#include "linenoise.h" |
28 | 29 | #include "llama-cpp.h" |
29 | 30 |
|
30 | 31 | #if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) |
@@ -807,21 +808,22 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str |
807 | 808 | batch = llama_batch_get_one(&new_token_id, 1); |
808 | 809 | } |
809 | 810 |
|
| 811 | + printf("\033[0m"); |
810 | 812 | return 0; |
811 | 813 | } |
812 | 814 |
|
813 | | -static int read_user_input(std::string & user) { |
814 | | - std::getline(std::cin, user); |
| 815 | +static int read_user_input(std::string & user_input) { |
| 816 | + std::getline(std::cin, user_input); |
815 | 817 | if (std::cin.eof()) { |
816 | 818 | printf("\n"); |
817 | 819 | return 1; |
818 | 820 | } |
819 | 821 |
|
820 | | - if (user == "/bye") { |
| 822 | + if (user_input == "/bye") { |
821 | 823 | return 1; |
822 | 824 | } |
823 | 825 |
|
824 | | - if (user.empty()) { |
| 826 | + if (user_input.empty()) { |
825 | 827 | return 2; |
826 | 828 | } |
827 | 829 |
|
@@ -858,18 +860,45 @@ static int apply_chat_template_with_error_handling(LlamaData & llama_data, const |
858 | 860 | return 0; |
859 | 861 | } |
860 | 862 |
|
| 863 | +struct c_str { |
| 864 | + const char * data = nullptr; |
| 865 | + |
| 866 | + ~c_str() { free(const_cast<char *>(data)); } |
| 867 | +}; |
| 868 | + |
861 | 869 | // Helper function to handle user input |
862 | 870 | static int handle_user_input(std::string & user_input, const std::string & user) { |
863 | 871 | if (!user.empty()) { |
864 | 872 | user_input = user; |
865 | 873 | return 0; // No need for interactive input |
866 | 874 | } |
867 | 875 |
|
| 876 | + static const char * prompt_prefix = "> "; |
| 877 | +#ifdef WIN32 |
868 | 878 | printf( |
869 | 879 | "\r%*s" |
870 | | - "\r\033[32m> \033[0m", |
871 | | - get_terminal_width(), " "); |
| 880 | + "\r\033[0m%s", |
| 881 | + get_terminal_width(), " ", prompt_prefix); |
872 | 882 | return read_user_input(user_input); // Returns true if input ends the loop |
| 883 | +#else |
| 884 | + c_str line; |
| 885 | + line.data = linenoise(prompt_prefix); |
| 886 | + if (!line.data) { |
| 887 | + return 1; |
| 888 | + } |
| 889 | + |
| 890 | + user_input = line.data; |
| 891 | + if (user_input == "/bye") { |
| 892 | + return 1; |
| 893 | + } |
| 894 | + |
| 895 | + if (user_input.empty()) { |
| 896 | + return 2; |
| 897 | + } |
| 898 | + |
| 899 | + linenoiseHistoryAdd(line.data); |
| 900 | + return 0; |
| 901 | +#endif |
873 | 902 | } |
874 | 903 |
|
875 | 904 | static bool is_stdin_a_terminal() { |
|
0 commit comments