Skip to content

Commit 824b24c

Browse files
nrfultonGitHub Enterprise
authored andcommitted
Adds as_chat_history
We should move this helper into `Context` in a future sprint.
1 parent defe6eb commit 824b24c

File tree

3 files changed

+60
-1
lines changed

3 files changed

+60
-1
lines changed

mellea/stdlib/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,9 @@ def linearize(self) -> list[Component | CBlock] | None:
340340
return self._ctx
341341

342342
def is_chat_history(self):
343+
FancyLogger.get_logger().warning(
344+
"is_chat_history doesn't work properly, because ModelOutputThunks are not Messages."
345+
)
343346
"""Returns true if everything in the LinearContext is a chat `Message`."""
344347
return all(
345348
str(type(x)) == "Message" for x in self._ctx

mellea/stdlib/chat.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,15 @@
33
from collections.abc import Mapping
44
from typing import Any, Literal
55

6-
from mellea.stdlib.base import Component, ModelToolCall, TemplateRepresentation
6+
from mellea.helpers.fancy_logger import FancyLogger
7+
from mellea.stdlib.base import (
8+
CBlock,
9+
Component,
10+
Context,
11+
ModelOutputThunk,
12+
ModelToolCall,
13+
TemplateRepresentation,
14+
)
715

816

917
class Message(Component):
@@ -85,3 +93,28 @@ def format_for_llm(self) -> TemplateRepresentation:
8593
def __str__(self):
8694
"""Pretty representation of messages, because they are a special case."""
8795
return f'mellea.Message(role="{self.role}", content="{self.content}", name="{self.name}")'
96+
97+
98+
def as_chat_history(ctx: Context) -> list[Message]:
99+
"""Returns a list of Messages corresponding to a Context."""
100+
101+
def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message | None:
102+
match c:
103+
case Message():
104+
return c
105+
case ModelOutputThunk():
106+
match c.parsed_repr:
107+
case Message():
108+
return c.parsed_repr
109+
case _:
110+
return None
111+
case _:
112+
return None
113+
114+
linearized_ctx = ctx.linearize()
115+
if linearized_ctx is None:
116+
raise Exception("Trying to cast a non-linear history into a chat history.")
117+
else:
118+
history = [_to_msg(c) for c in linearized_ctx]
119+
assert None not in history, "Could not render this context as a chat history."
120+
return history # type: ignore
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
from mellea.stdlib.base import ModelOutputThunk, LinearContext
3+
from mellea.stdlib.chat import as_chat_history, Message
4+
from mellea.stdlib.session import start_session
5+
6+
7+
def test_chat_view_linear_ctx():
8+
m = start_session(ctx=LinearContext())
9+
m.chat("What is 1+1?")
10+
m.chat("What is 2+2?")
11+
assert len(as_chat_history(m.ctx)) == 4
12+
assert all([type(x) == Message for x in as_chat_history(m.ctx)])
13+
14+
@pytest.mark.skip("linearize() returns [] for a SimpleContext... that's going to be annoying.")
15+
def test_chat_view_simple_ctx():
16+
m = start_session()
17+
m.chat("What is 1+1?")
18+
m.chat("What is 2+2?")
19+
assert len(as_chat_history(m.ctx)) == 2
20+
assert all([type(x) == Message for x in as_chat_history(m.ctx)])
21+
22+
if __name__ == "__main__":
23+
pytest.main([__file__])

0 commit comments

Comments
 (0)