Skip to content

Commit 9a56680

Browse files
committed
fix test case
1 parent d4e0bad commit 9a56680

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

examples/server/tests/unit/test_embedding.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,28 +46,32 @@ def test_embedding_multiple():
4646

4747

4848
@pytest.mark.parametrize(
49-
"content",
49+
"content,is_multi_prompt",
5050
[
5151
# single prompt
52-
"string",
53-
[12, 34, 56],
54-
[12, 34, "string", 56, 78],
52+
("string", False),
53+
([12, 34, 56], False),
54+
([12, 34, "string", 56, 78], False),
5555
# multiple prompts
56-
["string1", "string2"],
57-
["string1", [12, 34, 56]],
58-
[[12, 34, 56], [12, 34, 56]],
59-
[[12, 34, 56], [12, "string", 34, 56]],
56+
(["string1", "string2"], True),
57+
(["string1", [12, 34, 56]], True),
58+
([[12, 34, 56], [12, 34, 56]], True),
59+
([[12, 34, 56], [12, "string", 34, 56]], True),
6060
]
6161
)
62-
def test_embedding_mixed_input(content):
62+
def test_embedding_mixed_input(content, is_multi_prompt: bool):
6363
global server
6464
server.start()
6565
res = server.make_request("POST", "/embeddings", data={"content": content})
6666
assert res.status_code == 200
67-
assert len(res.body['data']) == len(content)
68-
for d in res.body['data']:
69-
assert 'embedding' in d
70-
assert len(d['embedding']) > 1
67+
if is_multi_prompt:
68+
assert len(res.body) == len(content)
69+
for d in res.body:
70+
assert 'embedding' in d
71+
assert len(d['embedding']) > 1
72+
else:
73+
assert 'embedding' in res.body
74+
assert len(res.body['embedding']) > 1
7175

7276

7377
def test_embedding_openai_library_single():

0 commit comments

Comments
 (0)