|
| 1 | +import asyncio |
1 | 2 | import os |
2 | 3 | import pydantic |
3 | 4 | import pytest |
|
7 | 8 | from mellea.backends.vllm import LocalVLLMBackend |
8 | 9 | from mellea.backends.types import ModelOption |
9 | 10 | import mellea.backends.model_ids as model_ids |
10 | | -from mellea.stdlib.base import CBlock, LinearContext |
| 11 | +from mellea.stdlib.base import CBlock, LinearContext, SimpleContext |
11 | 12 | from mellea.stdlib.requirement import ( |
12 | 13 | LLMaJRequirement, |
13 | 14 | Requirement, |
@@ -135,6 +136,44 @@ class Answer(pydantic.BaseModel): |
135 | 136 | ) |
136 | 137 |
|
137 | 138 |
|
@pytest.mark.qualitative
def test_async_parallel_requests(session):
    """Issue two streaming generations and compare streamed vs. final values.

    Both requests are started before either one is awaited. For each request,
    the value obtained from ``astream()`` must be a prefix of the value
    obtained from ``avalue()``, and the awaited final value must equal the
    thunk's ``value`` attribute.
    """

    async def _first_streamed(thunk):
        # Only pull from the stream while the thunk is still pending. A thunk
        # that is already computed yields None, which the assertions in
        # _exercise() reject.
        if thunk.is_computed():
            return None
        return await thunk.astream()

    async def _exercise():
        streaming_opts = {ModelOption.STREAM: True}
        hello = session.backend.generate_from_context(
            CBlock("Say Hello."), SimpleContext(), model_options=streaming_opts
        )
        goodbye = session.backend.generate_from_context(
            CBlock("Say Goodbye!"), SimpleContext(), model_options=streaming_opts
        )

        # Streamed sequentially: first request, then second.
        hello_partial = await _first_streamed(hello)
        goodbye_partial = await _first_streamed(goodbye)

        assert hello_partial is not None, "should be a string val after generation"
        assert goodbye_partial is not None, "should be a string val after generation"

        hello_full = await hello.avalue()
        goodbye_full = await goodbye.avalue()

        # A strict inequality between the partial and the full value would be
        # the ideal check, but the first streamed response sometimes already
        # contains the whole answer, so only the prefix relation is asserted.
        assert hello_full.startswith(hello_partial), "final val should contain the first streamed chunk"
        assert goodbye_full.startswith(goodbye_partial), "final val should contain the first streamed chunk"

        assert hello_full == hello.value
        assert goodbye_full == goodbye.value

    asyncio.run(_exercise())
| 167 | + |
@pytest.mark.qualitative
def test_async_avalue(session):
    """Await one generation via ``avalue()`` and check it against ``value``."""

    async def _exercise():
        thunk = session.backend.generate_from_context(
            CBlock("Say Hello."), SimpleContext()
        )
        awaited = await thunk.avalue()
        assert awaited is not None
        # The awaited result and the value stored on the thunk must agree.
        assert awaited == thunk.value

    asyncio.run(_exercise())
| 176 | + |
138 | 177 | if __name__ == "__main__": |
139 | 178 | import pytest |
140 | 179 |
|
|
0 commit comments