Skip to content

Commit 52530f2

Browse files
committed
DRY Sampling - Better merge of latest changes to master in server.cpp
1 parent 2331c79 commit 52530f2

File tree

1 file changed

+33
-56
lines changed

1 file changed

+33
-56
lines changed

examples/server/server.cpp

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1513,70 +1513,48 @@ struct server_context {
15131513
}
15141514
tasks.push_back(std::move(task));
15151515
};
1516-
1517-
static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
1516+
static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids, a mixed array of strings and token ids, or an array of prompts";
15181517
if (!data.contains("prompt")) {
15191518
throw std::runtime_error(error_msg);
15201519
}
1521-
15221520
json prompt = data.at("prompt");
1523-
1524-
// The commented out code removed the previous ability to submit a mixed array of strings and token IDs
1525-
1526-
// // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
1527-
// if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
1528-
// data["index"] = 0;
1529-
// create_task(data, false, nullptr);
1530-
// }
1531-
// // otherwise, it's a multiple-prompt task, we break it into smaller tasks
1532-
// else if (prompt.is_array()) {
1533-
// std::vector<json> prompts = prompt;
1534-
// for (size_t i = 0; i < prompts.size(); i++) {
1535-
// const auto & e = prompts[i];
1536-
// if (e.is_string() || json_is_array_of_numbers(e)) {
1537-
// data["index"] = i;
1538-
// create_task(data, true, e);
1539-
// } else {
1540-
// throw std::runtime_error(error_msg);
1541-
// }
1542-
// }
1543-
// }
1544-
// // invalid case
1545-
// else {
1546-
// throw std::runtime_error(error_msg);
1547-
// }
1548-
1549-
// Single string prompt
1550-
if (prompt.is_string()) {
1521+
// if the prompt is a singleton (i.e. a string, a list of tokens, or a mixed array of strings and tokens), we only need to create a single task
1522+
if (prompt.is_string() || (prompt.is_array() && !prompt.empty() && !prompt[0].is_array())) {
1523+
bool is_mixed = false;
1524+
bool has_string = prompt.is_string();
1525+
bool has_number = false;
1526+
if (prompt.is_array()) {
1527+
for (const auto& elem : prompt) {
1528+
if (elem.is_string()) has_string = true;
1529+
else if (elem.is_number()) has_number = true;
1530+
if (has_string && has_number) {
1531+
is_mixed = true;
1532+
break;
1533+
}
1534+
}
1535+
}
15511536
data["index"] = 0;
15521537
create_task(data, false, nullptr);
1538+
SRV_DBG("creating single%s prompt task\n", is_mixed ? " mixed" : "");
15531539
}
1554-
// Single array prompt (could be all tokens, all strings, or mixed)
1540+
// otherwise, it's a multiple-prompt task or a rerank task, we break it into smaller tasks
15551541
else if (prompt.is_array()) {
1556-
bool is_mixed = false;
1557-
bool has_string = false;
1558-
bool has_number = false;
1559-
for (const auto& elem : prompt) {
1560-
if (elem.is_string()) has_string = true;
1561-
else if (elem.is_number()) has_number = true;
1562-
if (has_string && has_number) {
1563-
is_mixed = true;
1564-
break;
1542+
std::vector<json> prompts = prompt;
1543+
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
1544+
// prompts[0] is the question
1545+
// the rest are the answers/documents
1546+
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
1547+
for (size_t i = 1; i < prompts.size(); i++) {
1548+
json qd;
1549+
qd.push_back(prompts[0]);
1550+
qd.push_back(prompts[i]);
1551+
data["index"] = i - 1;
1552+
create_task(data, true, qd);
15651553
}
1566-
}
1567-
1568-
if (is_mixed || (has_string && !has_number)) {
1569-
// Mixed array or array of strings, treat as single prompt
1570-
data["index"] = 0;
1571-
create_task(data, false, nullptr);
1572-
} else if (!has_string && has_number) {
1573-
// Array of token IDs
1574-
data["index"] = 0;
1575-
create_task(data, false, nullptr);
15761554
} else {
1577-
// Array of prompts
1578-
for (size_t i = 0; i < prompt.size(); i++) {
1579-
const auto & e = prompt[i];
1555+
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
1556+
for (size_t i = 0; i < prompts.size(); i++) {
1557+
const auto & e = prompts[i];
15801558
if (e.is_string() || json_is_array_of_numbers(e)) {
15811559
data["index"] = i;
15821560
create_task(data, true, e);
@@ -1586,11 +1564,10 @@ struct server_context {
15861564
}
15871565
}
15881566
}
1589-
// Invalid case
1567+
// invalid case
15901568
else {
15911569
throw std::runtime_error(error_msg);
15921570
}
1593-
15941571
return tasks;
15951572
}
15961573

0 commit comments

Comments
 (0)