@@ -71,8 +71,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
7171 })
7272 content = ""
7373 last_cmpl_id = None
74- for data in res :
74+ for i , data in enumerate ( res ) :
7575 choice = data ["choices" ][0 ]
76+ if i == 0 :
77+ # Check first role message for stream=True
78+ assert choice ["delta" ]["content" ] == ""
79+ assert choice ["delta" ]["role" ] == "assistant"
80+ else :
81+ assert "role" not in choice ["delta" ]
7682 assert data ["system_fingerprint" ].startswith ("b" )
7783 assert "gpt-3.5" in data ["model" ] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
7884 if last_cmpl_id is None :
@@ -242,12 +248,18 @@ def test_chat_completion_with_timings_per_token():
242248 "stream" : True ,
243249 "timings_per_token" : True ,
244250 })
245- for data in res :
246- assert "timings" in data
247- assert "prompt_per_second" in data ["timings" ]
248- assert "predicted_per_second" in data ["timings" ]
249- assert "predicted_n" in data ["timings" ]
250- assert data ["timings" ]["predicted_n" ] <= 10
251+ for i , data in enumerate (res ):
252+ if i == 0 :
253+ # Check first role message for stream=True
254+ assert data ["choices" ][0 ]["delta" ]["content" ] == ""
255+ assert data ["choices" ][0 ]["delta" ]["role" ] == "assistant"
256+ else :
257+ assert "role" not in data ["choices" ][0 ]["delta" ]
258+ assert "timings" in data
259+ assert "prompt_per_second" in data ["timings" ]
260+ assert "predicted_per_second" in data ["timings" ]
261+ assert "predicted_n" in data ["timings" ]
262+ assert data ["timings" ]["predicted_n" ] <= 10
251263
252264
253265def test_logprobs ():
@@ -295,17 +307,23 @@ def test_logprobs_stream():
295307 )
296308 output_text = ''
297309 aggregated_text = ''
298- for data in res :
310+ for i , data in enumerate ( res ) :
299311 choice = data .choices [0 ]
300- if choice .finish_reason is None :
301- if choice .delta .content :
302- output_text += choice .delta .content
303- assert choice .logprobs is not None
304- assert choice .logprobs .content is not None
305- for token in choice .logprobs .content :
306- aggregated_text += token .token
307- assert token .logprob <= 0.0
308- assert token .bytes is not None
309- assert token .top_logprobs is not None
310- assert len (token .top_logprobs ) > 0
312+ if i == 0 :
313+ # Check first role message for stream=True
314+ assert choice .delta .content == ""
315+ assert choice .delta .role == "assistant"
316+ else :
317+ assert choice .delta .role is None
318+ if choice .finish_reason is None :
319+ if choice .delta .content :
320+ output_text += choice .delta .content
321+ assert choice .logprobs is not None
322+ assert choice .logprobs .content is not None
323+ for token in choice .logprobs .content :
324+ aggregated_text += token .token
325+ assert token .logprob <= 0.0
326+ assert token .bytes is not None
327+ assert token .top_logprobs is not None
328+ assert len (token .top_logprobs ) > 0
311329 assert aggregated_text == output_text
0 commit comments