Skip to content

Commit f261d46

Browse files
splitting ctx.linearize() into two more fitting functions:
1) `ctx.render_for_generation()` used in generate calls 2) `ctx.full_event_log()` used for views on the whole context also: fixed and inherent bug in `m.chat` which wrongly added the chat turn to the contexts (as single events, not as turn).
1 parent 36ad134 commit f261d46

File tree

13 files changed

+42
-29
lines changed

13 files changed

+42
-29
lines changed

docs/examples/agents/react.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def react(
103103
react_toolbox: ReactToolbox,
104104
):
105105
assert m.ctx.is_chat_context, "ReACT requires a chat context."
106-
test_ctx_lin = m.ctx.linearize()
106+
test_ctx_lin = m.ctx.render_for_generation()
107107
assert test_ctx_lin is not None and len(test_ctx_lin) == 0, (
108108
"ReACT expects a fresh context."
109109
)

docs/examples/agents/react_instruct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def react(
101101
react_toolbox: ReactToolbox,
102102
):
103103
assert m.ctx.is_chat_context, "ReACT requires a chat context."
104-
test_ctx_lin = m.ctx.linearize()
104+
test_ctx_lin = m.ctx.render_for_generation()
105105
assert test_ctx_lin is not None and len(test_ctx_lin) == 0, (
106106
"ReACT expects a fresh context."
107107
)

docs/tutorial.md

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -968,15 +968,15 @@ Let's look at how this agent is implemented in Mellea:
968968
```python
969969
# file: https://github.com/generative-computing/mellea/blob/main/docs/examples/agents/react.py#L99
970970
def react(
971-
m: mellea.MelleaSession,
972-
goal: str,
973-
react_toolbox: ReactToolbox,
974-
budget: int=5,
971+
m: mellea.MelleaSession,
972+
goal: str,
973+
react_toolbox: ReactToolbox,
974+
budget: int = 5,
975975
):
976976
assert m.ctx.is_chat_context, "ReACT requires a chat context."
977-
test_ctx_lin = m.ctx.linearize()
977+
test_ctx_lin = m.ctx.render_for_generation()
978978
assert (
979-
test_ctx_lin is not None and len(test_ctx_lin) == 0
979+
test_ctx_lin is not None and len(test_ctx_lin) == 0
980980
), "ReACT expects a fresh context."
981981

982982
# Construct the system prompt for ReACT.
@@ -1006,15 +1006,17 @@ def react(
10061006
# model_options={mellea.backends.types.ModelOption.TOOLS: react_toolbox.tools_dict()},
10071007
format=react_toolbox.tool_name_schema(),
10081008
)
1009-
selected_tool: ReactTool = react_toolbox.get_tool_from_schema(act.content)
1009+
selected_tool: ReactTool = react_toolbox.get_tool_from_schema(
1010+
act.content)
10101011
print(selected_tool.get_name())
10111012

10121013
print(f"### Arguments for action")
10131014
act_args = m.chat(
10141015
"Choose arguments for the tool. Respond using JSON and include only the tool arguments in your response.",
10151016
format=selected_tool.args_schema(),
10161017
)
1017-
print(f"```json\n{json.dumps(json.loads(act_args.content), indent=2)}\n```")
1018+
print(
1019+
f"```json\n{json.dumps(json.loads(act_args.content), indent=2)}\n```")
10181020

10191021
# TODO: handle exceptions.
10201022
print("### Observation")

mellea/backends/formatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def print_context(self, ctx: Context) -> str:
166166
)
167167
match ctx:
168168
case LinearContext():
169-
linearized_ctx = ctx.linearize()
169+
linearized_ctx = ctx.render_for_generation()
170170
assert linearized_ctx is not None
171171
return "".join([self.print(x) for x in linearized_ctx])
172172
case SimpleContext():

mellea/backends/huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def _generate_from_context_alora(
239239
"This code block should not execute unless there is a 'constraint' alora loaded."
240240
)
241241
# Construct the linearized context. This is very similar to normal generation.
242-
linearized_ctx = ctx.linearize()
242+
linearized_ctx = ctx.render_for_generation()
243243
assert linearized_ctx is not None and len(linearized_ctx) > 1
244244
msgs = self.formatter.to_chat_messages(linearized_ctx)
245245
user_message, assistant_message = msgs[-2].content, msgs[-1].content
@@ -286,7 +286,7 @@ def _generate_from_context_standard(
286286
# Otherwise, we will linearize the context and treat it as a raw input.
287287
decoded_result: str | None = None
288288
if ctx.is_chat_context:
289-
linearized_ctx = ctx.linearize()
289+
linearized_ctx = ctx.render_for_generation()
290290
assert linearized_ctx is not None, (
291291
"If ctx.is_chat_context, then the context should be linearizable."
292292
)

mellea/backends/ollama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def generate_from_chat_context(
263263
"""
264264
model_opts = self._simplify_and_merge(model_options)
265265

266-
linearized_context = ctx.linearize()
266+
linearized_context = ctx.render_for_generation()
267267
assert linearized_context is not None, (
268268
"Cannot generate from a non-linear context in a FormatterBackend."
269269
)

mellea/backends/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _generate_from_chat_context_alora(
327327
)
328328

329329
# Construct the linearized context. This is very similar to normal generation.
330-
linearized_ctx = ctx.linearize()
330+
linearized_ctx = ctx.render_for_generation()
331331
assert linearized_ctx is not None and len(linearized_ctx) > 1
332332
msgs = self.formatter.to_chat_messages(linearized_ctx)
333333
user_message, assistant_message = msgs[-2].content, msgs[-1].content
@@ -362,7 +362,7 @@ def _generate_from_chat_context_standard(
362362
model_opts = self._simplify_and_merge(
363363
model_options, is_chat_context=ctx.is_chat_context
364364
)
365-
linearized_context = ctx.linearize()
365+
linearized_context = ctx.render_for_generation()
366366
assert linearized_context is not None, (
367367
"Cannot generate from a non-linear context in a FormatterBackend."
368368
)

mellea/backends/watsonx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def generate_from_chat_context(
220220
model_options, is_chat_context=ctx.is_chat_context
221221
)
222222

223-
linearized_context = ctx.linearize()
223+
linearized_context = ctx.render_for_generation()
224224
assert linearized_context is not None, (
225225
"Cannot generate from a non-linear context in a FormatterBackend."
226226
)

mellea/stdlib/base.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,13 @@ def _hash_for_kv_cache(self):
157157
...
158158

159159
@abc.abstractmethod
160-
def linearize(self) -> list[Component | CBlock] | None:
161-
"""Provides a linear list of context components. This is not always possible, or None if that is not possible to construct."""
160+
def render_for_generation(self) -> list[Component | CBlock] | None:
161+
"""Provides a linear list of context components to use for generation, or None if that is not possible to construct."""
162+
...
163+
164+
@abc.abstractmethod
165+
def full_event_log(self) -> list[Component | CBlock]:
166+
"""Provides a list of all events stored in the context."""
162167
...
163168

164169
@abc.abstractmethod
@@ -262,6 +267,10 @@ def last_output_and_logs(
262267
)
263268
return last, log[0]
264269

270+
def full_event_log(self) -> list[Component | CBlock]:
271+
"""Returns the underlying _ctx."""
272+
return self._ctx
273+
265274
def last_turn(self):
266275
"""The last input/output turn of the context."""
267276
if len(self._ctx) == 0:
@@ -335,8 +344,8 @@ def insert_turn(
335344
if turn.output:
336345
self.insert(turn.output, generate_logs=generate_logs)
337346

338-
def linearize(self) -> list[Component | CBlock] | None:
339-
"""Returns the underlying _ctx list."""
347+
def render_for_generation(self) -> list[Component | CBlock] | None:
348+
"""Returns the underlying _ctx list for generation."""
340349
return self._ctx
341350

342351
def is_chat_history(self):
@@ -372,7 +381,7 @@ def __init__(self):
372381
super().__init__()
373382
self.is_chat_context = True
374383

375-
def linearize(self) -> list[Component | CBlock] | None:
384+
def render_for_generation(self) -> list[Component | CBlock] | None:
376385
"""Uses _ctx ordering."""
377386
return []
378387

mellea/stdlib/chat.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,10 @@ def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message | None:
111111
case _:
112112
return None
113113

114-
linearized_ctx = ctx.linearize()
115-
if linearized_ctx is None:
114+
all_ctx_events = ctx.full_event_log()
115+
if all_ctx_events is None:
116116
raise Exception("Trying to cast a non-linear history into a chat history.")
117117
else:
118-
history = [_to_msg(c) for c in linearized_ctx]
118+
history = [_to_msg(c) for c in all_ctx_events]
119119
assert None not in history, "Could not render this context as a chat history."
120120
return history # type: ignore

0 commit comments

Comments
 (0)