Skip to content

Commit 72be826

Browse files
committed
fix: add async session tests with chat context
1 parent 2ab8b0d commit 72be826

File tree

1 file changed

+36
-2
lines changed

1 file changed

+36
-2
lines changed

test/stdlib_basics/test_session.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1+
import asyncio
12
import os
23

34
import pytest
45

56
from mellea.backends.types import ModelOption
6-
from mellea.stdlib.base import ModelOutputThunk
7+
from mellea.stdlib.base import ChatContext, ModelOutputThunk
78
from mellea.stdlib.chat import Message
89
from mellea.stdlib.session import start_session
910

10-
@pytest.fixture(scope="module")
11+
# We edit the context type in the async tests below. Don't change the scope here.
12+
@pytest.fixture(scope="function")
1113
def m_session(gh_run):
1214
if gh_run == 1:
1315
m = start_session(
@@ -66,5 +68,37 @@ async def test_ainstruct(m_session):
6668
assert m_session.ctx is not initial_ctx
6769
assert out.value is not None
6870

71+
async def test_async_await_with_chat_context(m_session):
72+
m_session.ctx = ChatContext()
73+
74+
m1 = Message(role="user", content="1")
75+
m2 = Message(role="user", content="2")
76+
r1 = await m_session.aact(m1)
77+
r2 = await m_session.aact(m2)
78+
79+
# This should be the order of these items in the session's context.
80+
history = [r2, m2, r1, m1]
81+
82+
ctx = m_session.ctx
83+
for i in range(len(history)):
84+
assert ctx.node_data is history[i]
85+
ctx = ctx.previous_node
86+
87+
# Ensure we made it back to the root.
88+
assert ctx.is_root_node == True
89+
90+
async def test_async_without_waiting_with_chat_context(m_session):
91+
m_session.ctx = ChatContext()
92+
93+
m1 = Message(role="user", content="1")
94+
m2 = Message(role="user", content="2")
95+
co1 = m_session.aact(m1)
96+
co2 = m_session.aact(m2)
97+
_, _ = await asyncio.gather(co2, co1)
98+
99+
ctx = m_session.ctx
100+
assert len(ctx.view_for_generation()) == 2
101+
102+
69103
if __name__ == "__main__":
70104
pytest.main([__file__])

0 commit comments

Comments
 (0)