@@ -81,9 +81,11 @@ def generate():
8181
8282 client = Client (httpserver .url_for ('/' ))
8383 response = client .chat ('dummy' , messages = [{'role' : 'user' , 'content' : 'Why is the sky blue?' }], stream = True )
84+
85+ it = iter (['I ' , "don't " , 'know.' ])
8486 for part in response :
8587 assert part ['message' ]['role' ] in 'assistant'
86- assert part ['message' ]['content' ] in [ 'I ' , "don't " , 'know.' ]
88+ assert part ['message' ]['content' ] == next ( it )
8789
8890
8991def test_client_chat_images (httpserver : HTTPServer ):
@@ -187,9 +189,11 @@ def generate():
187189
188190 client = Client (httpserver .url_for ('/' ))
189191 response = client .generate ('dummy' , 'Why is the sky blue?' , stream = True )
192+
193+ it = iter (['Because ' , 'it ' , 'is.' ])
190194 for part in response :
191195 assert part ['model' ] == 'dummy'
192- assert part ['response' ] in [ 'Because ' , 'it ' , 'is.' ]
196+ assert part ['response' ] == next ( it )
193197
194198
195199def test_client_generate_images (httpserver : HTTPServer ):
@@ -458,6 +462,24 @@ async def test_async_client_chat(httpserver: HTTPServer):
458462
459463@pytest .mark .asyncio
460464async def test_async_client_chat_stream (httpserver : HTTPServer ):
465+ def stream_handler (_ : Request ):
466+ def generate ():
467+ for message in ['I ' , "don't " , 'know.' ]:
468+ yield (
469+ json .dumps (
470+ {
471+ 'model' : 'dummy' ,
472+ 'message' : {
473+ 'role' : 'assistant' ,
474+ 'content' : message ,
475+ },
476+ }
477+ )
478+ + '\n '
479+ )
480+
481+ return Response (generate ())
482+
461483 httpserver .expect_ordered_request (
462484 '/api/chat' ,
463485 method = 'POST' ,
@@ -468,11 +490,15 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
468490 'format' : '' ,
469491 'options' : {},
470492 },
471- ).respond_with_json ({} )
493+ ).respond_with_handler ( stream_handler )
472494
473495 client = AsyncClient (httpserver .url_for ('/' ))
474496 response = await client .chat ('dummy' , messages = [{'role' : 'user' , 'content' : 'Why is the sky blue?' }], stream = True )
475- assert isinstance (response , types .AsyncGeneratorType )
497+
498+ it = iter (['I ' , "don't " , 'know.' ])
499+ async for part in response :
500+ assert part ['message' ]['role' ] == 'assistant'
501+ assert part ['message' ]['content' ] == next (it )
476502
477503
478504@pytest .mark .asyncio
@@ -529,6 +555,21 @@ async def test_async_client_generate(httpserver: HTTPServer):
529555
530556@pytest .mark .asyncio
531557async def test_async_client_generate_stream (httpserver : HTTPServer ):
558+ def stream_handler (_ : Request ):
559+ def generate ():
560+ for message in ['Because ' , 'it ' , 'is.' ]:
561+ yield (
562+ json .dumps (
563+ {
564+ 'model' : 'dummy' ,
565+ 'response' : message ,
566+ }
567+ )
568+ + '\n '
569+ )
570+
571+ return Response (generate ())
572+
532573 httpserver .expect_ordered_request (
533574 '/api/generate' ,
534575 method = 'POST' ,
@@ -544,11 +585,15 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
544585 'format' : '' ,
545586 'options' : {},
546587 },
547- ).respond_with_json ({} )
588+ ).respond_with_handler ( stream_handler )
548589
549590 client = AsyncClient (httpserver .url_for ('/' ))
550591 response = await client .generate ('dummy' , 'Why is the sky blue?' , stream = True )
551- assert isinstance (response , types .AsyncGeneratorType )
592+
593+ it = iter (['Because ' , 'it ' , 'is.' ])
594+ async for part in response :
595+ assert part ['model' ] == 'dummy'
596+ assert part ['response' ] == next (it )
552597
553598
554599@pytest .mark .asyncio
0 commit comments