Skip to content

Commit 048e90d

Browse files
committed
test(vllm): asynchronous call support
1 parent c1ebd6d commit 048e90d

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

test/backends/test_vllm.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import os
23
import pydantic
34
import pytest
@@ -7,7 +8,7 @@
78
from mellea.backends.vllm import LocalVLLMBackend
89
from mellea.backends.types import ModelOption
910
import mellea.backends.model_ids as model_ids
10-
from mellea.stdlib.base import CBlock, LinearContext
11+
from mellea.stdlib.base import CBlock, LinearContext, SimpleContext
1112
from mellea.stdlib.requirement import (
1213
LLMaJRequirement,
1314
Requirement,
@@ -135,6 +136,44 @@ class Answer(pydantic.BaseModel):
135136
)
136137

137138

139+
@pytest.mark.qualitative
140+
def test_async_parallel_requests(session):
141+
async def parallel_requests():
142+
model_opts = {ModelOption.STREAM: True}
143+
mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts)
144+
mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts)
145+
146+
m1_val = None
147+
m2_val = None
148+
if not mot1.is_computed():
149+
m1_val = await mot1.astream()
150+
if not mot2.is_computed():
151+
m2_val = await mot2.astream()
152+
153+
assert m1_val is not None, "should be a string val after generation"
154+
assert m2_val is not None, "should be a string val after generation"
155+
156+
m1_final_val = await mot1.avalue()
157+
m2_final_val = await mot2.avalue()
158+
159+
# Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response
160+
# contains the full response.
161+
assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk"
162+
assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk"
163+
164+
assert m1_final_val == mot1.value
165+
assert m2_final_val == mot2.value
166+
asyncio.run(parallel_requests())
167+
168+
@pytest.mark.qualitative
169+
def test_async_avalue(session):
170+
async def avalue():
171+
mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext())
172+
m1_final_val = await mot1.avalue()
173+
assert m1_final_val is not None
174+
assert m1_final_val == mot1.value
175+
asyncio.run(avalue())
176+
138177
if __name__ == "__main__":
139178
import pytest
140179

0 commit comments

Comments
 (0)