Skip to content

Commit 21aad84

Browse files
committed
fix: update async stream tests
1 parent 38f68e2 commit 21aad84

File tree

1 file changed

+51
-6
lines changed

1 file changed

+51
-6
lines changed

tests/test_client.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ def generate():
8181

8282
client = Client(httpserver.url_for('/'))
8383
response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
84+
85+
it = iter(['I ', "don't ", 'know.'])
8486
for part in response:
8587
assert part['message']['role'] in 'assistant'
86-
assert part['message']['content'] in ['I ', "don't ", 'know.']
88+
assert part['message']['content'] == next(it)
8789

8890

8991
def test_client_chat_images(httpserver: HTTPServer):
@@ -187,9 +189,11 @@ def generate():
187189

188190
client = Client(httpserver.url_for('/'))
189191
response = client.generate('dummy', 'Why is the sky blue?', stream=True)
192+
193+
it = iter(['Because ', 'it ', 'is.'])
190194
for part in response:
191195
assert part['model'] == 'dummy'
192-
assert part['response'] in ['Because ', 'it ', 'is.']
196+
assert part['response'] == next(it)
193197

194198

195199
def test_client_generate_images(httpserver: HTTPServer):
@@ -458,6 +462,24 @@ async def test_async_client_chat(httpserver: HTTPServer):
458462

459463
@pytest.mark.asyncio
460464
async def test_async_client_chat_stream(httpserver: HTTPServer):
465+
def stream_handler(_: Request):
466+
def generate():
467+
for message in ['I ', "don't ", 'know.']:
468+
yield (
469+
json.dumps(
470+
{
471+
'model': 'dummy',
472+
'message': {
473+
'role': 'assistant',
474+
'content': message,
475+
},
476+
}
477+
)
478+
+ '\n'
479+
)
480+
481+
return Response(generate())
482+
461483
httpserver.expect_ordered_request(
462484
'/api/chat',
463485
method='POST',
@@ -468,11 +490,15 @@ async def test_async_client_chat_stream(httpserver: HTTPServer):
468490
'format': '',
469491
'options': {},
470492
},
471-
).respond_with_json({})
493+
).respond_with_handler(stream_handler)
472494

473495
client = AsyncClient(httpserver.url_for('/'))
474496
response = await client.chat('dummy', messages=[{'role': 'user', 'content': 'Why is the sky blue?'}], stream=True)
475-
assert isinstance(response, types.AsyncGeneratorType)
497+
498+
it = iter(['I ', "don't ", 'know.'])
499+
async for part in response:
500+
assert part['message']['role'] == 'assistant'
501+
assert part['message']['content'] == next(it)
476502

477503

478504
@pytest.mark.asyncio
@@ -529,6 +555,21 @@ async def test_async_client_generate(httpserver: HTTPServer):
529555

530556
@pytest.mark.asyncio
531557
async def test_async_client_generate_stream(httpserver: HTTPServer):
558+
def stream_handler(_: Request):
559+
def generate():
560+
for message in ['Because ', 'it ', 'is.']:
561+
yield (
562+
json.dumps(
563+
{
564+
'model': 'dummy',
565+
'response': message,
566+
}
567+
)
568+
+ '\n'
569+
)
570+
571+
return Response(generate())
572+
532573
httpserver.expect_ordered_request(
533574
'/api/generate',
534575
method='POST',
@@ -544,11 +585,15 @@ async def test_async_client_generate_stream(httpserver: HTTPServer):
544585
'format': '',
545586
'options': {},
546587
},
547-
).respond_with_json({})
588+
).respond_with_handler(stream_handler)
548589

549590
client = AsyncClient(httpserver.url_for('/'))
550591
response = await client.generate('dummy', 'Why is the sky blue?', stream=True)
551-
assert isinstance(response, types.AsyncGeneratorType)
592+
593+
it = iter(['Because ', 'it ', 'is.'])
594+
async for part in response:
595+
assert part['model'] == 'dummy'
596+
assert part['response'] == next(it)
552597

553598

554599
@pytest.mark.asyncio

0 commit comments

Comments
 (0)