Skip to content

Commit 64db527

Browse files
committed
Enhance user input handling for llama-run
The main motivation for this change is it was not handing ctrl-d correctly. Modify `read_user_input` to handle EOF, "/bye" command, and empty input cases. Introduce `get_user_input` function to manage user input loop and handle different return cases. Signed-off-by: Eric Curtin <[email protected]>
1 parent 99a3755 commit 64db527

File tree

1 file changed

+35
-3
lines changed

1 file changed

+35
-3
lines changed

examples/run/run.cpp

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -801,7 +801,20 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
801801

802802
static int read_user_input(std::string & user) {
803803
std::getline(std::cin, user);
804-
return user.empty(); // Should have data in happy path
804+
if (std::cin.eof()) {
805+
printf("\n");
806+
return 1;
807+
}
808+
809+
if (user == "/bye") {
810+
return 1;
811+
}
812+
813+
if (user.empty()) {
814+
return 2;
815+
}
816+
817+
return 0; // Should have data in happy path
805818
}
806819

807820
// Function to generate a response based on the prompt
@@ -868,15 +881,34 @@ static bool is_stdout_a_terminal() {
868881
#endif
869882
}
870883

871-
// Function to tokenize the prompt
884+
// Function to handle user input
885+
static int get_user_input(std::string & user_input, const std::string & user) {
886+
while (true) {
887+
const int ret = handle_user_input(user_input, user);
888+
if (ret == 1) {
889+
return 1;
890+
}
891+
892+
if (ret == 2) {
893+
continue;
894+
}
895+
896+
break;
897+
}
898+
899+
return 0;
900+
}
901+
902+
// Main chat loop function
872903
static int chat_loop(LlamaData & llama_data, const std::string & user) {
873904
int prev_len = 0;
874905
llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get()));
875906
static const bool stdout_a_terminal = is_stdout_a_terminal();
876907
while (true) {
877908
// Get user input
878909
std::string user_input;
879-
while (handle_user_input(user_input, user)) {
910+
if (get_user_input(user_input, user) == 1) {
911+
return 0;
880912
}
881913

882914
add_message("user", user.empty() ? user_input : user, llama_data);

0 commit comments

Comments
 (0)