Skip to content

Commit b956f8d

Browse files
committed
feat: add clone method to session; remove refs to model_opts in session
1 parent 72be826 commit b956f8d

File tree

3 files changed

+112
-15
lines changed

3 files changed

+112
-15
lines changed

docs/tutorial.md

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,23 @@ or the entire last turn (user query + assistant response):
944944
print(m.ctx.last_turn())
945945
```
946946

947+
You can also use `session.clone()` to create a copy of a given session with its context at given point in time. This allows you to make multiple generation requests with the same objects in your context:
948+
```python
949+
m = start_session(ctx=ChatContext())
950+
m.instruct("Multiply 2x2.")
951+
952+
m1 = m.clone()
953+
m2 = m.clone()
954+
955+
# Need to run this code in an async event loop.
956+
co1 = m1.ainstruct("Multiply that by 3")
957+
co2 = m2.ainstruct("Multiply that by 5")
958+
959+
print(await co1) # 12
960+
print(await co2) # 20
961+
```
962+
In the above example, both requests have `Multiply 2x2` and the LLM's response to that (presumably `4`) in their context. By cloning the session, the new requests both operate independently on that context to get the correct answers to 4 x 3 and 4 x 5.
963+
947964
## Chapter 8: Implementing Agents
948965

949966
> **Definition:** An *agent* is a generative program in which an LLM determines the control flow of the program.
@@ -1323,13 +1340,13 @@ Mellea supports asynchronous behavior in several ways: asynchronous functions an
13231340

13241341
### Asynchronous Functions:
13251342
`MelleaSession`s have asynchronous functions that work just like regular async functions in python. These async session functions mirror their synchronous counterparts:
1326-
```
1343+
```python
13271344
m = start_session()
13281345
result = await m.ainstruct("Write your instruction here!")
13291346
```
13301347

13311348
However, if you want to run multiple async functions at the same time, you need to be careful with your context. By default, `MelleaSession`s use a `SimpleContext` that has no history. This will work just fine when running multiple async requests at once:
1332-
```
1349+
```python
13331350
m = start_session()
13341351
coroutines = []
13351352

@@ -1340,7 +1357,7 @@ results = await asyncio.gather(*coroutines)
13401357
```
13411358

13421359
If you try to use a `ChatContext`, you will need to await between each request so that the context can be properly modified:
1343-
```
1360+
```python
13441361
m = start_session(ctx=ChatContext())
13451362

13461363
result = await m.ainstruct("Write a short fairy tale.")
@@ -1351,7 +1368,7 @@ print(main_character)
13511368
```
13521369

13531370
Otherwise, you're requests will use outdated contexts that don't have the messages you expect. For example,
1354-
```
1371+
```python
13551372
m = start_session(ctx=ChatContext())
13561373

13571374
co1 = m.ainstruct("Write a very long math problem.") # Start first request.
@@ -1360,8 +1377,12 @@ co2 = m.ainstruct("Solve the math problem.") # Start second request with an emp
13601377
results = await asyncio.gather(co1, co2)
13611378
for result in results:
13621379
print(result) # Neither request had anything in its context.
1380+
1381+
print(m.ctx) # Only shows the operations from the second request.
13631382
```
13641383

1384+
Additionally, see [Chapter 7: Context Management](#chapter-7-on-context-management) for an example of how to use `session.clone()` to avoid these context issues.
1385+
13651386
### Asynchronicity in Synchronous Functions
13661387
Mellea utilizes asynchronicity internally. When you call `m.instruct`, you are using synchronous code that executes an asynchronous request to an LLM to generate the result. For a single request, this won't cause any differences in execution speed.
13671388

mellea/stdlib/session.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import contextvars
6+
from copy import copy
67
from typing import Any, Literal, overload
78

89
from PIL import Image as PILImage
@@ -176,11 +177,10 @@ def __init__(self, backend: Backend, ctx: Context | None = None):
176177
Args:
177178
backend (Backend): This is always required.
178179
ctx (Context): The way in which the model's context will be managed. By default, each interaction with the model is a stand-alone interaction, so we use SimpleContext as the default.
179-
model_options (Optional[dict]): model options, which will upsert into the model/backend's defaults.
180180
"""
181181
self.backend = backend
182182
self.ctx: Context = ctx if ctx is not None else SimpleContext()
183-
self._backend_stack: list[tuple[Backend, dict | None]] = []
183+
self._backend_stack: list[Backend] = []
184184
self._session_logger = FancyLogger.get_logger()
185185
self._context_token = None
186186

@@ -196,14 +196,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
196196
_context_session.reset(self._context_token)
197197
self._context_token = None
198198

199-
def _push_model_state(self, new_backend: Backend, new_model_opts: dict):
200-
"""The backend and model options used within a `Context` can be temporarily changed. This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`.
201-
202-
Question: should this logic be moved into context? I really want to keep `Session` as simple as possible... see true motivation in the docstring for the class.
203-
"""
204-
self._backend_stack.append((self.backend, self.model_options))
199+
def _push_model_state(self, new_backend: Backend):
200+
"""The backend used within a `Context` can be temporarily changed. This method changes the model's backend, while saving the current settings in the `self._backend_stack`."""
201+
self._backend_stack.append(self.backend)
205202
self.backend = new_backend
206-
self.opts = new_model_opts
207203

208204
def _pop_model_state(self) -> bool:
209205
"""Pops the model state.
@@ -214,13 +210,43 @@ def _pop_model_state(self) -> bool:
214210
Question: should this logic be moved into context? I really want to keep `Session` as simple as possible... see true motivation in the docstring for the class.
215211
"""
216212
try:
217-
b, b_model_opts = self._backend_stack.pop()
213+
b = self._backend_stack.pop()
218214
self.backend = b
219-
self.model_options = b_model_opts
220215
return True
221216
except Exception:
222217
return False
223218

219+
def __copy__(self):
220+
new = MelleaSession(backend=self.backend, ctx=self.ctx)
221+
new._backend_stack = self._backend_stack.copy()
222+
new._session_logger = self._session_logger
223+
# Explicitly don't copy over the _context_token.
224+
225+
return new
226+
227+
def clone(self):
228+
"""Useful for running multiple generation requests while keeping the context at a given point in time.
229+
230+
Returns:
231+
a copy of the current session. Keeps the context, backend, backend stack, and session logger.
232+
233+
Examples:
234+
>>> from mellea import start_session
235+
>>> m = start_session()
236+
>>> m.instruct("What is 2x2?")
237+
>>>
238+
>>> m1 = m.clone()
239+
>>> out = m1.instruct("Multiply that by 2")
240+
>>> print(out)
241+
... 8
242+
>>>
243+
>>> m2 = m.clone()
244+
>>> out = m2.instruct("Multiply that by 3")
245+
>>> print(out)
246+
... 12
247+
"""
248+
return copy(self)
249+
224250
def reset(self):
225251
"""Reset the context state."""
226252
self.ctx = self.ctx.reset_to_new()

test/stdlib_basics/test_session.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import pytest
55

6+
from mellea.backends.ollama import OllamaModelBackend
67
from mellea.backends.types import ModelOption
78
from mellea.stdlib.base import ChatContext, ModelOutputThunk
89
from mellea.stdlib.chat import Message
@@ -99,6 +100,55 @@ async def test_async_without_waiting_with_chat_context(m_session):
99100
ctx = m_session.ctx
100101
assert len(ctx.view_for_generation()) == 2
101102

103+
def test_session_copy_with_context_ops(m_session):
104+
out = m_session.instruct("What is 2x2?")
105+
main_ctx = m_session.ctx
106+
107+
m1 = m_session.clone()
108+
out1 = m1.instruct("Multiply by 3.")
109+
110+
m2 = m_session.clone()
111+
out2 = m2.instruct("Multiply by 4.")
112+
113+
# Assert that each context is the correct one.
114+
assert m_session.ctx is main_ctx
115+
assert m_session.ctx is not m1.ctx
116+
assert m_session.ctx is not m2.ctx
117+
assert m1.ctx is not m2.ctx
118+
119+
# Assert that node data is correct.
120+
assert m_session.ctx.node_data is out
121+
assert m1.ctx.node_data is out1
122+
assert m2.ctx.node_data is out2
123+
124+
# Assert that the new sessions still branch off the original one.
125+
assert m1.ctx.previous_node.previous_node is m_session.ctx
126+
assert m2.ctx.previous_node.previous_node is m_session.ctx
127+
128+
def test_session_copy_with_backend_stack(m_session):
129+
# Assert expected values from cloning.
130+
m1 = m_session.clone()
131+
assert m1.backend is m_session.backend
132+
assert m1._session_logger is m_session._session_logger
133+
assert m1._backend_stack is not m_session._backend_stack
134+
135+
# Assert that pushing to a backend stack doesn't change it for sessions previously cloned from it.
136+
new_backend = OllamaModelBackend()
137+
m_session._push_model_state(new_backend=new_backend)
138+
assert len(m_session._backend_stack) == 1
139+
assert len(m1._backend_stack) == 0
140+
assert m1.backend is not m_session.backend
141+
142+
# Assert that newly cloned sessions don't cause errors with changes to the backend stack.
143+
m2 = m_session.clone()
144+
assert len(m2._backend_stack) == 1
145+
146+
# They should still be different lists.
147+
assert m2._backend_stack is not m_session._backend_stack
148+
assert m2._pop_model_state()
149+
assert len(m2._backend_stack) == 0
150+
assert len(m_session._backend_stack) == 1
151+
assert m2.backend is m1.backend
102152

103153
if __name__ == "__main__":
104154
pytest.main([__file__])

0 commit comments

Comments
 (0)