Skip to content

Commit d4e0bad

Browse files
committed
server : (embeddings) using same format for "input" and "content"
1 parent 081b29b commit d4e0bad

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

examples/server/server.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3651,14 +3651,14 @@ int main(int argc, char ** argv) {
36513651
const json body = json::parse(req.body);
36523652
bool oaicompat = false;
36533653

3654-
// an input prompt can be a string or a list of tokens (integer)
3654+
// for the shape of input/content, see tokenize_input_prompts()
36553655
json prompt;
3656-
if (body.count("input") != 0) {
3656+
if (body.contains("input")) {
36573657
oaicompat = true;
36583658
prompt = body.at("input");
3659-
} else if (body.count("content") != 0) {
3660-
// with "content", we only support single prompt
3661-
prompt = std::vector<std::string>{body.at("content")};
3659+
} else if (body.contains("content")) {
3660+
oaicompat = false;
3661+
prompt = body.at("content");
36623662
} else {
36633663
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
36643664
return;

examples/server/tests/unit/test_embedding.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,31 @@ def test_embedding_multiple():
4545
assert len(d['embedding']) > 1
4646

4747

48+
@pytest.mark.parametrize(
49+
"content",
50+
[
51+
# single prompt
52+
"string",
53+
[12, 34, 56],
54+
[12, 34, "string", 56, 78],
55+
# multiple prompts
56+
["string1", "string2"],
57+
["string1", [12, 34, 56]],
58+
[[12, 34, 56], [12, 34, 56]],
59+
[[12, 34, 56], [12, "string", 34, 56]],
60+
]
61+
)
62+
def test_embedding_mixed_input(content):
63+
global server
64+
server.start()
65+
res = server.make_request("POST", "/embeddings", data={"content": content})
66+
assert res.status_code == 200
67+
assert len(res.body['data']) == len(content)
68+
for d in res.body['data']:
69+
assert 'embedding' in d
70+
assert len(d['embedding']) > 1
71+
72+
4873
def test_embedding_openai_library_single():
4974
global server
5075
server.start()

examples/server/utils.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
138138
* and multiple prompts (multi-tasks):
139139
* - "prompt": ["string1", "string2"]
140140
* - "prompt": ["string1", [12, 34, 56]]
141+
* - "prompt": [[12, 34, 56], [78, 90, 12]]
141142
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
142143
*/
143144
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {

0 commit comments

Comments
 (0)