11import os
22import io
33import json
4- import types
54import pytest
65import tempfile
76from pathlib import Path
@@ -81,9 +80,11 @@ def generate():
8180
8281 client = Client (httpserver .url_for ('/' ))
8382 response = client .chat ('dummy' , messages = [{'role' : 'user' , 'content' : 'Why is the sky blue?' }], stream = True )
83+
84+ it = iter (['I ' , "don't " , 'know.' ])
8485 for part in response :
8586 assert part ['message' ]['role' ] in 'assistant'
86- assert part ['message' ]['content' ] in [ 'I ' , "don't " , 'know.' ]
87+ assert part ['message' ]['content' ] == next ( it )
8788
8889
8990def test_client_chat_images (httpserver : HTTPServer ):
@@ -187,9 +188,11 @@ def generate():
187188
188189 client = Client (httpserver .url_for ('/' ))
189190 response = client .generate ('dummy' , 'Why is the sky blue?' , stream = True )
191+
192+ it = iter (['Because ' , 'it ' , 'is.' ])
190193 for part in response :
191194 assert part ['model' ] == 'dummy'
192- assert part ['response' ] in [ 'Because ' , 'it ' , 'is.' ]
195+ assert part ['response' ] == next ( it )
193196
194197
195198def test_client_generate_images (httpserver : HTTPServer ):
@@ -263,11 +266,14 @@ def generate():
263266 'insecure' : False ,
264267 'stream' : True ,
265268 },
266- ).respond_with_json ({} )
269+ ).respond_with_handler ( stream_handler )
267270
268271 client = Client (httpserver .url_for ('/' ))
269272 response = client .pull ('dummy' , stream = True )
270- assert isinstance (response , types .GeneratorType )
273+
274+ it = iter (['pulling manifest' , 'verifying sha256 digest' , 'writing manifest' , 'removing any unused layers' , 'success' ])
275+ for part in response :
276+ assert part ['status' ] == next (it )
271277
272278
273279def test_client_push (httpserver : HTTPServer ):
@@ -287,6 +293,14 @@ def test_client_push(httpserver: HTTPServer):
287293
288294
289295def test_client_push_stream (httpserver : HTTPServer ):
296+ def stream_handler (_ : Request ):
297+ def generate ():
298+ yield json .dumps ({'status' : 'retrieving manifest' }) + '\n '
299+ yield json .dumps ({'status' : 'pushing manifest' }) + '\n '
300+ yield json .dumps ({'status' : 'success' }) + '\n '
301+
302+ return Response (generate ())
303+
290304 httpserver .expect_ordered_request (
291305 '/api/push' ,
292306 method = 'POST' ,
@@ -295,11 +309,14 @@ def test_client_push_stream(httpserver: HTTPServer):
295309 'insecure' : False ,
296310 'stream' : True ,
297311 },
298- ).respond_with_json ({} )
312+ ).respond_with_handler ( stream_handler )
299313
300314 client = Client (httpserver .url_for ('/' ))
301315 response = client .push ('dummy' , stream = True )
302- assert isinstance (response , types .GeneratorType )
316+
317+ it = iter (['retrieving manifest' , 'pushing manifest' , 'success' ])
318+ for part in response :
319+ assert part ['status' ] == next (it )
303320
304321
305322def test_client_create_path (httpserver : HTTPServer ):
@@ -458,6 +475,24 @@ async def test_async_client_chat(httpserver: HTTPServer):
458475
459476@pytest .mark .asyncio
460477async def test_async_client_chat_stream (httpserver : HTTPServer ):
478+ def stream_handler (_ : Request ):
479+ def generate ():
480+ for message in ['I ' , "don't " , 'know.' ]:
481+ yield (
482+ json .dumps (
483+ {
484+ 'model' : 'dummy' ,
485+ 'message' : {
486+ 'role' : 'assistant' ,
487+ 'content' : message ,
488+ },
489+ }
490+ )
491+ + '\n '
492+ )
493+
494+ return Response (generate ())
495+
461496 httpserver .expect_ordered_request (
462497 '/api/chat' ,
463498 method = 'POST' ,
@@ -468,11 +503,15 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
468503 'format' : '' ,
469504 'options' : {},
470505 },
471- ).respond_with_json ({} )
506+ ).respond_with_handler ( stream_handler )
472507
473508 client = AsyncClient (httpserver .url_for ('/' ))
474509 response = await client .chat ('dummy' , messages = [{'role' : 'user' , 'content' : 'Why is the sky blue?' }], stream = True )
475- assert isinstance (response , types .AsyncGeneratorType )
510+
511+ it = iter (['I ' , "don't " , 'know.' ])
512+ async for part in response :
513+ assert part ['message' ]['role' ] == 'assistant'
514+ assert part ['message' ]['content' ] == next (it )
476515
477516
478517@pytest .mark .asyncio
@@ -529,6 +568,21 @@ async def test_async_client_generate(httpserver: HTTPServer):
529568
530569@pytest .mark .asyncio
531570async def test_async_client_generate_stream (httpserver : HTTPServer ):
571+ def stream_handler (_ : Request ):
572+ def generate ():
573+ for message in ['Because ' , 'it ' , 'is.' ]:
574+ yield (
575+ json .dumps (
576+ {
577+ 'model' : 'dummy' ,
578+ 'response' : message ,
579+ }
580+ )
581+ + '\n '
582+ )
583+
584+ return Response (generate ())
585+
532586 httpserver .expect_ordered_request (
533587 '/api/generate' ,
534588 method = 'POST' ,
@@ -544,11 +598,15 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
544598 'format' : '' ,
545599 'options' : {},
546600 },
547- ).respond_with_json ({} )
601+ ).respond_with_handler ( stream_handler )
548602
549603 client = AsyncClient (httpserver .url_for ('/' ))
550604 response = await client .generate ('dummy' , 'Why is the sky blue?' , stream = True )
551- assert isinstance (response , types .AsyncGeneratorType )
605+
606+ it = iter (['Because ' , 'it ' , 'is.' ])
607+ async for part in response :
608+ assert part ['model' ] == 'dummy'
609+ assert part ['response' ] == next (it )
552610
553611
554612@pytest .mark .asyncio
@@ -597,6 +655,16 @@ async def test_async_client_pull(httpserver: HTTPServer):
597655
598656@pytest .mark .asyncio
599657async def test_async_client_pull_stream (httpserver : HTTPServer ):
658+ def stream_handler (_ : Request ):
659+ def generate ():
660+ yield json .dumps ({'status' : 'pulling manifest' }) + '\n '
661+ yield json .dumps ({'status' : 'verifying sha256 digest' }) + '\n '
662+ yield json .dumps ({'status' : 'writing manifest' }) + '\n '
663+ yield json .dumps ({'status' : 'removing any unused layers' }) + '\n '
664+ yield json .dumps ({'status' : 'success' }) + '\n '
665+
666+ return Response (generate ())
667+
600668 httpserver .expect_ordered_request (
601669 '/api/pull' ,
602670 method = 'POST' ,
@@ -605,11 +673,14 @@ async def test_async_client_pull_stream(httpserver: HTTPServer):
605673 'insecure' : False ,
606674 'stream' : True ,
607675 },
608- ).respond_with_json ({} )
676+ ).respond_with_handler ( stream_handler )
609677
610678 client = AsyncClient (httpserver .url_for ('/' ))
611679 response = await client .pull ('dummy' , stream = True )
612- assert isinstance (response , types .AsyncGeneratorType )
680+
681+ it = iter (['pulling manifest' , 'verifying sha256 digest' , 'writing manifest' , 'removing any unused layers' , 'success' ])
682+ async for part in response :
683+ assert part ['status' ] == next (it )
613684
614685
615686@pytest .mark .asyncio
@@ -631,6 +702,14 @@ async def test_async_client_push(httpserver: HTTPServer):
631702
632703@pytest .mark .asyncio
633704async def test_async_client_push_stream (httpserver : HTTPServer ):
705+ def stream_handler (_ : Request ):
706+ def generate ():
707+ yield json .dumps ({'status' : 'retrieving manifest' }) + '\n '
708+ yield json .dumps ({'status' : 'pushing manifest' }) + '\n '
709+ yield json .dumps ({'status' : 'success' }) + '\n '
710+
711+ return Response (generate ())
712+
634713 httpserver .expect_ordered_request (
635714 '/api/push' ,
636715 method = 'POST' ,
@@ -639,11 +718,14 @@ async def test_async_client_push_stream(httpserver: HTTPServer):
639718 'insecure' : False ,
640719 'stream' : True ,
641720 },
642- ).respond_with_json ({} )
721+ ).respond_with_handler ( stream_handler )
643722
644723 client = AsyncClient (httpserver .url_for ('/' ))
645724 response = await client .push ('dummy' , stream = True )
646- assert isinstance (response , types .AsyncGeneratorType )
725+
726+ it = iter (['retrieving manifest' , 'pushing manifest' , 'success' ])
727+ async for part in response :
728+ assert part ['status' ] == next (it )
647729
648730
649731@pytest .mark .asyncio
0 commit comments