Skip to content

Commit a85db13

Browse files
committed
Enhance user input handling for llama-run
The main motivation for this change is it was not handing ctrl-c/ctrl-d correctly. Modify `read_user_input` to handle EOF, "/bye" command, and empty input cases. Introduce `get_user_input` function to manage user input loop and handle different return cases. Signed-off-by: Eric Curtin <[email protected]>
1 parent 99a3755 commit a85db13

File tree

1 file changed

+63
-3
lines changed

1 file changed

+63
-3
lines changed

examples/run/run.cpp

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# include <curl/curl.h>
1212
#endif
1313

14+
#include <signal.h>
15+
1416
#include <climits>
1517
#include <cstdarg>
1618
#include <cstdio>
@@ -25,6 +27,16 @@
2527
#include "json.hpp"
2628
#include "llama-cpp.h"
2729

30+
static bool sigint = false;
31+
32+
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32)
33+
static void sigint_handler(int signo) {
34+
if (signo == SIGINT) {
35+
sigint = true;
36+
}
37+
}
38+
#endif
39+
2840
GGML_ATTRIBUTE_FORMAT(1, 2)
2941
static std::string fmt(const char * fmt, ...) {
3042
va_list ap;
@@ -801,7 +813,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
801813

802814
static int read_user_input(std::string & user) {
803815
std::getline(std::cin, user);
804-
return user.empty(); // Should have data in happy path
816+
if (sigint || std::cin.eof()) {
817+
printf("\n");
818+
return 1;
819+
}
820+
821+
if (user == "/bye") {
822+
return 1;
823+
}
824+
825+
if (user.empty()) {
826+
return 2;
827+
}
828+
829+
return 0; // Should have data in happy path
805830
}
806831

807832
// Function to generate a response based on the prompt
@@ -868,15 +893,34 @@ static bool is_stdout_a_terminal() {
868893
#endif
869894
}
870895

871-
// Function to tokenize the prompt
896+
// Function to handle user input
897+
static int get_user_input(std::string & user_input, const std::string & user) {
898+
while (true) {
899+
const int ret = handle_user_input(user_input, user);
900+
if (ret == 1) {
901+
return 1;
902+
}
903+
904+
if (ret == 2) {
905+
continue;
906+
}
907+
908+
break;
909+
}
910+
911+
return 0;
912+
}
913+
914+
// Main chat loop function
872915
static int chat_loop(LlamaData & llama_data, const std::string & user) {
873916
int prev_len = 0;
874917
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
875918
static const bool stdout_a_terminal = is_stdout_a_terminal();
876919
while (true) {
877920
// Get user input
878921
std::string user_input;
879-
while (handle_user_input(user_input, user)) {
922+
if (get_user_input(user_input, user) == 1) {
923+
return 0;
880924
}
881925

882926
add_message("user", user.empty() ? user_input : user, llama_data);
@@ -917,7 +961,23 @@ static std::string read_pipe_data() {
917961
return result.str();
918962
}
919963

964+
static void ctrl_c_handling() {
965+
#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__))
966+
struct sigaction sigint_action;
967+
sigint_action.sa_handler = sigint_handler;
968+
sigemptyset(&sigint_action.sa_mask);
969+
sigint_action.sa_flags = 0;
970+
sigaction(SIGINT, &sigint_action, NULL);
971+
#elif defined(_WIN32)
972+
auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
973+
return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false;
974+
};
975+
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
976+
#endif
977+
}
978+
920979
int main(int argc, const char ** argv) {
980+
ctrl_c_handling();
921981
Opt opt;
922982
const int ret = opt.init(argc, argv);
923983
if (ret == 2) {

0 commit comments

Comments
 (0)