@@ -46,28 +46,32 @@ def test_embedding_multiple():
4646
4747
4848@pytest .mark .parametrize (
49- "content" ,
49+ "content,is_multi_prompt " ,
5050 [
5151 # single prompt
52- "string" ,
53- [12 , 34 , 56 ],
54- [12 , 34 , "string" , 56 , 78 ],
52+ ( "string" , False ) ,
53+ ( [12 , 34 , 56 ], False ) ,
54+ ( [12 , 34 , "string" , 56 , 78 ], False ) ,
5555 # multiple prompts
56- ["string1" , "string2" ],
57- ["string1" , [12 , 34 , 56 ]],
58- [[12 , 34 , 56 ], [12 , 34 , 56 ]],
59- [[12 , 34 , 56 ], [12 , "string" , 34 , 56 ]],
56+ ( ["string1" , "string2" ], True ) ,
57+ ( ["string1" , [12 , 34 , 56 ]], True ) ,
58+ ( [[12 , 34 , 56 ], [12 , 34 , 56 ]], True ) ,
59+ ( [[12 , 34 , 56 ], [12 , "string" , 34 , 56 ]], True ) ,
6060 ]
6161)
62- def test_embedding_mixed_input (content ):
62+ def test_embedding_mixed_input (content , is_multi_prompt : bool ):
6363 global server
6464 server .start ()
6565 res = server .make_request ("POST" , "/embeddings" , data = {"content" : content })
6666 assert res .status_code == 200
67- assert len (res .body ['data' ]) == len (content )
68- for d in res .body ['data' ]:
69- assert 'embedding' in d
70- assert len (d ['embedding' ]) > 1
67+ if is_multi_prompt :
68+ assert len (res .body ) == len (content )
69+ for d in res .body :
70+ assert 'embedding' in d
71+ assert len (d ['embedding' ]) > 1
72+ else :
73+ assert 'embedding' in res .body
74+ assert len (res .body ['embedding' ]) > 1
7175
7276
7377def test_embedding_openai_library_single ():
0 commit comments