Skip to content

Commit 38725ef

Browse files
committed
server : add bad input handling in embeddings
1 parent 4f51968 commit 38725ef

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

examples/server/server.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3649,20 +3649,30 @@ int main(int argc, char ** argv) {
36493649
oaicompat = true;
36503650
prompt = body.at("input");
36513651
} else if (body.count("content") != 0) {
3652-
// with "content", we only support single prompt
3653-
prompt = std::vector<std::string>{body.at("content")};
3652+
prompt = body.at("content");
36543653
} else {
36553654
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
36563655
return;
36573656
}
36583657

3658+
// with "content", we only support single prompt
3659+
if (!oaicompat && prompt.type() != json::value_t::string) {
3660+
res_error(res, format_error_response("\"content\" must be a string", ERROR_TYPE_INVALID_REQUEST));
3661+
return;
3662+
}
3663+
36593664
// create and queue the task
36603665
json responses = json::array();
36613666
bool error = false;
36623667
{
36633668
std::vector<server_task> tasks;
36643669
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
36653670
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
3671+
if (tokenized_prompts[i].size() == 0) {
3672+
res_error(res, format_error_response("input cannot be an empty string", ERROR_TYPE_INVALID_REQUEST));
3673+
return;
3674+
}
3675+
36663676
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
36673677
task.id = ctx_server.queue_tasks.get_new_id();
36683678
task.index = i;

examples/server/tests/unit/test_embedding.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,39 @@ def test_same_prompt_give_same_result():
9797
vi = res.body['data'][i]['embedding']
9898
for x, y in zip(v0, vi):
9999
assert abs(x - y) < EPSILON
100+
101+
102+
@pytest.mark.parametrize(
    "text",
    [
        # scalars that are not text
        None,
        True,
        # empty string
        "",
        # numbers
        42,
        4.2,
        # empty containers
        {},
        [],
        # lists that contain an empty string
        [""],
        ["This is a test", ""],
    ],
)
def test_embedding_bad_input(text):
    """POST each parametrized value as "input" to /embeddings and expect an error.

    The test passes when the response status code is 400 or higher.
    """
    global server
    server.start()
    payload = {"input": text}
    response = server.make_request("POST", "/embeddings", data=payload)
    assert response.status_code >= 400
118+
119+
120+
@pytest.mark.parametrize(
    "text",
    [
        # scalars that are not text
        None,
        True,
        # empty string
        "",
        # numbers
        42,
        4.2,
        # empty containers
        {},
        [],
        # lists (one holding an empty string, one holding a non-empty string)
        [""],
        ["This is a test"],
    ],
)
def test_embedding_content_bad_input(text):
    """POST each parametrized value as "content" to /embeddings and expect an error.

    The test passes when the response status code is 400 or higher.
    """
    global server
    server.start()
    payload = {"content": text}
    response = server.make_request("POST", "/embeddings", data=payload)
    assert response.status_code >= 400

0 commit comments

Comments
 (0)