|
| 1 | +import asyncio |
1 | 2 | import os |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 |
|
5 | 6 | from mellea.backends.types import ModelOption |
6 | | -from mellea.stdlib.base import ModelOutputThunk |
| 7 | +from mellea.stdlib.base import ChatContext, ModelOutputThunk |
7 | 8 | from mellea.stdlib.chat import Message |
8 | 9 | from mellea.stdlib.session import start_session |
9 | 10 |
|
10 | | -@pytest.fixture(scope="module") |
| 11 | +# We edit the context type in the async tests below. Don't change the scope here. |
| 12 | +@pytest.fixture(scope="function") |
11 | 13 | def m_session(gh_run): |
12 | 14 | if gh_run == 1: |
13 | 15 | m = start_session( |
@@ -66,5 +68,37 @@ async def test_ainstruct(m_session): |
66 | 68 | assert m_session.ctx is not initial_ctx |
67 | 69 | assert out.value is not None |
68 | 70 |
|
| 71 | +async def test_async_await_with_chat_context(m_session): |
| 72 | + m_session.ctx = ChatContext() |
| 73 | + |
| 74 | + m1 = Message(role="user", content="1") |
| 75 | + m2 = Message(role="user", content="2") |
| 76 | + r1 = await m_session.aact(m1) |
| 77 | + r2 = await m_session.aact(m2) |
| 78 | + |
| 79 | + # This should be the order of these items in the session's context. |
| 80 | + history = [r2, m2, r1, m1] |
| 81 | + |
| 82 | + ctx = m_session.ctx |
| 83 | + for i in range(len(history)): |
| 84 | + assert ctx.node_data is history[i] |
| 85 | + ctx = ctx.previous_node |
| 86 | + |
| 87 | + # Ensure we made it back to the root. |
| 88 | + assert ctx.is_root_node == True |
| 89 | + |
| 90 | +async def test_async_without_waiting_with_chat_context(m_session): |
| 91 | + m_session.ctx = ChatContext() |
| 92 | + |
| 93 | + m1 = Message(role="user", content="1") |
| 94 | + m2 = Message(role="user", content="2") |
| 95 | + co1 = m_session.aact(m1) |
| 96 | + co2 = m_session.aact(m2) |
| 97 | + _, _ = await asyncio.gather(co2, co1) |
| 98 | + |
| 99 | + ctx = m_session.ctx |
| 100 | + assert len(ctx.view_for_generation()) == 2 |
| 101 | + |
| 102 | + |
69 | 103 | if __name__ == "__main__": |
70 | 104 | pytest.main([__file__]) |
0 commit comments