Skip to content

Commit 2e0b2e1

Browse files
authored
Merge pull request #18 from jmorganca/mxyng/fix-tests
fix unit tests
2 parents 38f68e2 + 5c1df78 commit 2e0b2e1

File tree

1 file changed

+97
-15
lines changed

1 file changed

+97
-15
lines changed

tests/test_client.py

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import io
33
import json
4-
import types
54
import pytest
65
import tempfile
76
from pathlib import Path
@@ -81,9 +80,11 @@ def generate():
8180

8281
client = Client(httpserver.url_for('/'))
8382
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
83+
84+
it = iter(['I ', "don't ", 'know.'])
8485
for part in response:
8586
assert part['message']['role'] in 'assistant'
86-
assert part['message']['content'] in ['I ', "don't ", 'know.']
87+
assert part['message']['content'] == next(it)
8788

8889

8990
def test_client_chat_images(httpserver: HTTPServer):
@@ -187,9 +188,11 @@ def generate():
187188

188189
client = Client(httpserver.url_for('/'))
189190
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
191+
192+
it = iter(['Because ', 'it ', 'is.'])
190193
for part in response:
191194
assert part['model'] == 'dummy'
192-
assert part['response'] in ['Because ', 'it ', 'is.']
195+
assert part['response'] == next(it)
193196

194197

195198
def test_client_generate_images(httpserver: HTTPServer):
@@ -263,11 +266,14 @@ def generate():
263266
'insecure': False,
264267
'stream': True,
265268
},
266-
).respond_with_json({})
269+
).respond_with_handler(stream_handler)
267270

268271
client = Client(httpserver.url_for('/'))
269272
response = client.pull('dummy', stream=True)
270-
assert isinstance(response, types.GeneratorType)
273+
274+
it = iter(['pulling manifest', 'verifying sha256 digest', 'writing manifest', 'removing any unused layers', 'success'])
275+
for part in response:
276+
assert part['status'] == next(it)
271277

272278

273279
def test_client_push(httpserver: HTTPServer):
@@ -287,6 +293,14 @@ def test_client_push(httpserver: HTTPServer):
287293

288294

289295
def test_client_push_stream(httpserver: HTTPServer):
296+
def stream_handler(_: Request):
297+
def generate():
298+
yield json.dumps({'status': 'retrieving manifest'}) + '\n'
299+
yield json.dumps({'status': 'pushing manifest'}) + '\n'
300+
yield json.dumps({'status': 'success'}) + '\n'
301+
302+
return Response(generate())
303+
290304
httpserver.expect_ordered_request(
291305
'/api/push',
292306
method='POST',
@@ -295,11 +309,14 @@ def test_client_push_stream(httpserver: HTTPServer):
295309
'insecure': False,
296310
'stream': True,
297311
},
298-
).respond_with_json({})
312+
).respond_with_handler(stream_handler)
299313

300314
client = Client(httpserver.url_for('/'))
301315
response = client.push('dummy', stream=True)
302-
assert isinstance(response, types.GeneratorType)
316+
317+
it = iter(['retrieving manifest', 'pushing manifest', 'success'])
318+
for part in response:
319+
assert part['status'] == next(it)
303320

304321

305322
def test_client_create_path(httpserver: HTTPServer):
@@ -458,6 +475,24 @@ async def test_async_client_chat(httpserver: HTTPServer):
458475

459476
@pytest.mark.asyncio
460477
async def test_async_client_chat_stream(httpserver: HTTPServer):
478+
def stream_handler(_: Request):
479+
def generate():
480+
for message in ['I ', "don't ", 'know.']:
481+
yield (
482+
json.dumps(
483+
{
484+
'model': 'dummy',
485+
'message': {
486+
'role': 'assistant',
487+
'content': message,
488+
},
489+
}
490+
)
491+
+ '\n'
492+
)
493+
494+
return Response(generate())
495+
461496
httpserver.expect_ordered_request(
462497
'/api/chat',
463498
method='POST',
@@ -468,11 +503,15 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
468503
'format': '',
469504
'options': {},
470505
},
471-
).respond_with_json({})
506+
).respond_with_handler(stream_handler)
472507

473508
client = AsyncClient(httpserver.url_for('/'))
474509
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
475-
assert isinstance(response, types.AsyncGeneratorType)
510+
511+
it = iter(['I ', "don't ", 'know.'])
512+
async for part in response:
513+
assert part['message']['role'] == 'assistant'
514+
assert part['message']['content'] == next(it)
476515

477516

478517
@pytest.mark.asyncio
@@ -529,6 +568,21 @@ async def test_async_client_generate(httpserver: HTTPServer):
529568

530569
@pytest.mark.asyncio
531570
async def test_async_client_generate_stream(httpserver: HTTPServer):
571+
def stream_handler(_: Request):
572+
def generate():
573+
for message in ['Because ', 'it ', 'is.']:
574+
yield (
575+
json.dumps(
576+
{
577+
'model': 'dummy',
578+
'response': message,
579+
}
580+
)
581+
+ '\n'
582+
)
583+
584+
return Response(generate())
585+
532586
httpserver.expect_ordered_request(
533587
'/api/generate',
534588
method='POST',
@@ -544,11 +598,15 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
544598
'format': '',
545599
'options': {},
546600
},
547-
).respond_with_json({})
601+
).respond_with_handler(stream_handler)
548602

549603
client = AsyncClient(httpserver.url_for('/'))
550604
response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
551-
assert isinstance(response, types.AsyncGeneratorType)
605+
606+
it = iter(['Because ', 'it ', 'is.'])
607+
async for part in response:
608+
assert part['model'] == 'dummy'
609+
assert part['response'] == next(it)
552610

553611

554612
@pytest.mark.asyncio
@@ -597,6 +655,16 @@ async def test_async_client_pull(httpserver: HTTPServer):
597655

598656
@pytest.mark.asyncio
599657
async def test_async_client_pull_stream(httpserver: HTTPServer):
658+
def stream_handler(_: Request):
659+
def generate():
660+
yield json.dumps({'status': 'pulling manifest'}) + '\n'
661+
yield json.dumps({'status': 'verifying sha256 digest'}) + '\n'
662+
yield json.dumps({'status': 'writing manifest'}) + '\n'
663+
yield json.dumps({'status': 'removing any unused layers'}) + '\n'
664+
yield json.dumps({'status': 'success'}) + '\n'
665+
666+
return Response(generate())
667+
600668
httpserver.expect_ordered_request(
601669
'/api/pull',
602670
method='POST',
@@ -605,11 +673,14 @@ async def test_async_client_pull_stream(httpserver: HTTPServer):
605673
'insecure': False,
606674
'stream': True,
607675
},
608-
).respond_with_json({})
676+
).respond_with_handler(stream_handler)
609677

610678
client = AsyncClient(httpserver.url_for('/'))
611679
response = await client.pull('dummy', stream=True)
612-
assert isinstance(response, types.AsyncGeneratorType)
680+
681+
it = iter(['pulling manifest', 'verifying sha256 digest', 'writing manifest', 'removing any unused layers', 'success'])
682+
async for part in response:
683+
assert part['status'] == next(it)
613684

614685

615686
@pytest.mark.asyncio
@@ -631,6 +702,14 @@ async def test_async_client_push(httpserver: HTTPServer):
631702

632703
@pytest.mark.asyncio
633704
async def test_async_client_push_stream(httpserver: HTTPServer):
705+
def stream_handler(_: Request):
706+
def generate():
707+
yield json.dumps({'status': 'retrieving manifest'}) + '\n'
708+
yield json.dumps({'status': 'pushing manifest'}) + '\n'
709+
yield json.dumps({'status': 'success'}) + '\n'
710+
711+
return Response(generate())
712+
634713
httpserver.expect_ordered_request(
635714
'/api/push',
636715
method='POST',
@@ -639,11 +718,14 @@ async def test_async_client_push_stream(httpserver: HTTPServer):
639718
'insecure': False,
640719
'stream': True,
641720
},
642-
).respond_with_json({})
721+
).respond_with_handler(stream_handler)
643722

644723
client = AsyncClient(httpserver.url_for('/'))
645724
response = await client.push('dummy', stream=True)
646-
assert isinstance(response, types.AsyncGeneratorType)
725+
726+
it = iter(['retrieving manifest', 'pushing manifest', 'success'])
727+
async for part in response:
728+
assert part['status'] == next(it)
647729

648730

649731
@pytest.mark.asyncio

0 commit comments

Comments
 (0)