Skip to content

Commit 0754a5d

Browse files
Renaming:
ctx.as_list(last_n_components..) ctx.render_for_generation() --> ctx.view_for_generation()
1 parent c3854f4 commit 0754a5d

File tree

7 files changed

+30
-25
lines changed

7 files changed

+30
-25
lines changed

mellea/backends/huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ def _generate_from_context_alora(
239239
"This code block should not execute unless there is a 'constraint' alora loaded."
240240
)
241241
# Construct the linearized context. This is very similar to normal generation.
242-
linearized_ctx = ctx.render_for_generation()
242+
linearized_ctx = ctx.view_for_generation()
243243
assert linearized_ctx is not None and len(linearized_ctx) > 1
244244
msgs = self.formatter.to_chat_messages(linearized_ctx)
245245
user_message, assistant_message = msgs[-2].content, msgs[-1].content
@@ -278,7 +278,7 @@ def _generate_from_context_standard(
278278
# If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template.
279279
# Otherwise, we will linearize the context and treat it as a raw input.
280280
if ctx.is_chat_context:
281-
linearized_ctx = ctx.render_for_generation()
281+
linearized_ctx = ctx.view_for_generation()
282282
assert linearized_ctx is not None, (
283283
"If ctx.is_chat_context, then the context should be linearizable."
284284
)

mellea/backends/litellm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def _generate_from_chat_context_standard(
217217
tool_calls: bool = False,
218218
) -> ModelOutputThunk:
219219
model_opts = self._simplify_and_merge(model_options)
220-
linearized_context = ctx.render_for_generation()
220+
linearized_context = ctx.view_for_generation()
221221
assert linearized_context is not None, (
222222
"Cannot generate from a non-linear context in a FormatterBackend."
223223
)

mellea/backends/ollama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def generate_from_chat_context(
273273
"""
274274
model_opts = self._simplify_and_merge(model_options)
275275

276-
linearized_context = ctx.render_for_generation()
276+
linearized_context = ctx.view_for_generation()
277277
assert linearized_context is not None, (
278278
"Cannot generate from a non-linear context in a FormatterBackend."
279279
)

mellea/backends/openai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ def _generate_from_chat_context_alora(
342342
)
343343

344344
# Construct the linearized context. This is very similar to normal generation.
345-
linearized_ctx = ctx.render_for_generation()
345+
linearized_ctx = ctx.view_for_generation()
346346
assert linearized_ctx is not None and len(linearized_ctx) > 1
347347
msgs = self.formatter.to_chat_messages(linearized_ctx)
348348
user_message, assistant_message = msgs[-2].content, msgs[-1].content
@@ -417,7 +417,7 @@ def _generate_from_chat_context_standard(
417417
model_opts = self._simplify_and_merge(
418418
model_options, is_chat_context=ctx.is_chat_context
419419
)
420-
linearized_context = ctx.render_for_generation()
420+
linearized_context = ctx.view_for_generation()
421421
assert linearized_context is not None, (
422422
"Cannot generate from a non-linear context in a FormatterBackend."
423423
)

mellea/backends/watsonx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def generate_from_chat_context(
242242
model_options, is_chat_context=ctx.is_chat_context
243243
)
244244

245-
linearized_context = ctx.render_for_generation()
245+
linearized_context = ctx.view_for_generation()
246246
assert linearized_context is not None, (
247247
"Cannot generate from a non-linear context in a FormatterBackend."
248248
)

mellea/stdlib/base.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -348,15 +348,18 @@ class ContextTurn:
348348

349349

350350
class Context(abc.ABC):
351-
"""A `Context` is used to track the state of a `MelleaSession`."""
351+
"""A `Context` is used to track the state of a `MelleaSession`.
352+
353+
A context is immutable. Every alteration leads to a new context.
354+
"""
352355

353356
_previous: Context | None
354357
_data: Component | CBlock | None
355358
_is_root: bool
356359
_is_chat_context: bool = True
357360

358361
def __init__(self):
359-
"""Constructs a new context."""
362+
"""Constructs a new root context with no content."""
360363
self._previous = None
361364
self._data = None
362365
self._is_root = True
@@ -403,18 +406,24 @@ def data(self) -> Component | CBlock | None:
403406
"""Returns the data associated with this context."""
404407
return self._data
405408

406-
def full_data_as_list(self) -> list[Component | CBlock]:
407-
"""Returns a list of all components in the context from root to current context."""
409+
def as_list(self, last_n_components: int | None = None) -> list[Component | CBlock]:
410+
"""Returns a list of the last n components in the context sorted from FIRST TO LAST.
411+
412+
If `last_n_elements` is `None`, then all components are returned."""
408413
context_list: list[Component | CBlock] = []
409414
current_context: Context = self
410415

411-
while not current_context.is_root:
416+
last_n_count = 0
417+
while not current_context.is_root and (
418+
last_n_components is None or last_n_count < last_n_components
419+
):
412420
data = current_context.data
413421
assert data is not None, "Data cannot be None (except for root context)."
414422
assert data not in context_list, (
415423
"There might be a cycle in the context tree. That is not allowed."
416424
)
417425
context_list.append(data)
426+
last_n_count += 1
418427

419428
current_context = current_context.previous # type: ignore
420429
assert current_context is not None, (
@@ -429,20 +438,20 @@ def actions_for_available_tools(self) -> list[Component | CBlock] | None:
429438
430439
Can be used to make the available tools differ from the tools of all the actions in the context. Can be overwritten by subclasses.
431440
"""
432-
return self.render_for_generation()
441+
return self.view_for_generation()
433442

434443
def last_output(self) -> ModelOutputThunk | None:
435444
"""The last output thunk of the context."""
436445

437-
for c in self.full_data_as_list()[::-1]:
446+
for c in self.as_list()[::-1]:
438447
if isinstance(c, ModelOutputThunk):
439448
return c
440449
return None
441450

442451
def last_turn(self):
443452
"""The last input/output turn of the context."""
444453

445-
history = self.full_data_as_list()
454+
history = self.as_list()
446455

447456
if len(history) == 0:
448457
return None
@@ -467,7 +476,7 @@ def add(self, c: Component | CBlock) -> Context:
467476
...
468477

469478
@abc.abstractmethod
470-
def render_for_generation(self) -> list[Component | CBlock] | None:
479+
def view_for_generation(self) -> list[Component | CBlock] | None:
471480
"""Provides a linear list of context components to use for generation, or None if that is not possible to construct."""
472481
...
473482

@@ -480,26 +489,22 @@ def __init__(self, *, window_size: int | None = None):
480489
super().__init__()
481490
self._window_size = window_size
482491

483-
def render_for_generation(self) -> list[Component | CBlock] | None:
484-
all_events = self.full_data_as_list()
485-
ws = self._window_size
486-
ws = ws if ws is not None else len(all_events)
487-
488-
return all_events[-ws:]
489-
490492
def add(self, c: Component | CBlock) -> ChatContext:
491493
new = ChatContext.from_previous(self, c)
492494
new._window_size = self._window_size
493495
return new
494496

497+
def view_for_generation(self) -> list[Component | CBlock] | None:
498+
return self.as_list(self._window_size)
499+
495500

496501
class SimpleContext(Context):
497502
"""A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved.."""
498503

499504
def add(self, c: Component | CBlock) -> SimpleContext:
500505
return SimpleContext.from_previous(self, c)
501506

502-
def render_for_generation(self) -> list[Component | CBlock] | None:
507+
def view_for_generation(self) -> list[Component | CBlock] | None:
503508
return []
504509

505510

mellea/stdlib/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message | None:
129129
case _:
130130
return None
131131

132-
all_ctx_events = ctx.full_data_as_list()
132+
all_ctx_events = ctx.as_list()
133133
if all_ctx_events is None:
134134
raise Exception("Trying to cast a non-linear history into a chat history.")
135135
else:

0 commit comments

Comments
 (0)