Skip to content

Commit 1cdcf36

Browse files
committed
DRY: Removed mixed arrays for prompt
1 parent abb33e0 commit 1cdcf36

File tree

1 file changed

+12
-33
lines changed

1 file changed

+12
-33
lines changed

examples/server/server.cpp

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ struct server_slot {
164164
int32_t n_prompt_tokens = 0;
165165
int32_t n_prompt_tokens_processed = 0;
166166

167-
json prompt; // can be either a string, array of strings, array of token ids, or mixed array of strings and token ids
167+
json prompt; // can be either a string, array of strings or array of token ids
168168

169169
json input_prefix;
170170
json input_suffix;
@@ -906,20 +906,10 @@ struct server_context {
906906
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
907907

908908
auto dry_sequence_breakers = data.find("dry_sequence_breakers");
909-
if (dry_sequence_breakers != data.end()) {
910-
try {
911-
if (dry_sequence_breakers->is_array()) {
912-
slot.sparams.dry_sequence_breakers = dry_sequence_breakers->get<std::vector<std::string>>();
913-
if (slot.sparams.dry_sequence_breakers.empty()) {
914-
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
915-
return false;
916-
}
917-
} else {
918-
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
919-
return false;
920-
}
921-
} catch (const std::exception & e) {
922-
send_error(task, std::string("\"dry_sequence_breakers\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
909+
if (data.contains("dry_sequence_breakers")) {
910+
slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
911+
if (slot.sparams.dry_sequence_breakers.empty()) {
912+
send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
923913
return false;
924914
}
925915
}
@@ -979,21 +969,22 @@ struct server_context {
979969
}
980970

981971
if ((prompt->is_string()) ||
982-
(prompt->is_array() && !prompt->empty() && (prompt->at(0).is_string() || prompt->at(0).is_number_integer()))) {
972+
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
973+
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
983974
slot.prompt = *prompt;
984975
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
985976
slot.prompt = prompt->at(0);
986977
} else if (prompt->is_array() && prompt->size() > 1) {
987978
// array of strings
988979
for (const auto & el : *prompt) {
989980
if (!el.is_string()) {
990-
send_error(task, "\"prompt\" must be a string, an array of strings, an array of integers, or a mixed array of strings and integers", ERROR_TYPE_INVALID_REQUEST);
981+
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
991982
return false;
992983
}
993984
}
994985
slot.prompt = *prompt;
995986
} else {
996-
send_error(task, "\"prompt\" must be a string, an array of strings, an array of integers, or a mixed array of strings and integers", ERROR_TYPE_INVALID_REQUEST);
987+
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
997988
return false;
998989
}
999990
}
@@ -1511,27 +1502,15 @@ struct server_context {
15111502
tasks.push_back(std::move(task));
15121503
};
15131504

1514-
static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids, a mixed array of strings and token ids, or an array of prompts";
1505+
static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
15151506
if (!data.contains("prompt")) {
15161507
throw std::runtime_error(error_msg);
15171508
}
15181509

15191510
json prompt = data.at("prompt");
15201511

1521-
auto is_valid_singleton_array = [](const json& arr) {
1522-
bool has_number = false;
1523-
for (const auto& elem : arr) {
1524-
if (elem.is_number()) {
1525-
has_number = true;
1526-
} else if (!elem.is_string()) {
1527-
return false;
1528-
}
1529-
}
1530-
return has_number;
1531-
};
1532-
1533-
// if the prompt is a singleton (i.e. a string, a list of tokens, or a mixed array of strings and tokens), we only need to create a single task
1534-
if (prompt.is_string() || (prompt.is_array() && is_valid_singleton_array(prompt))) {
1512+
// if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
1513+
if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
15351514
data["index"] = 0;
15361515
create_task(data, false, nullptr);
15371516
} else if (prompt.is_array()) {

0 commit comments

Comments
 (0)