Skip to content

Commit d766a3e

Browse files
committed
add fx to get actions for tool use for contexts; add test
1 parent c8ea5ce commit d766a3e

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

mellea/stdlib/base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,14 @@ def render_for_generation(self) -> list[Component | CBlock] | None:
161161
"""Provides a linear list of context components to use for generation, or None if that is not possible to construct."""
162162
...
163163

164+
@abc.abstractmethod
165+
def actions_for_available_tools(self) -> list[Component | CBlock] | None:
166+
"""Provides a list of actions to extract tools from for use with during generation, or None if that's not possible.
167+
168+
Can be used to make the available tools differ from the tools of all the actions in the context.
169+
"""
170+
...
171+
164172
@abc.abstractmethod
165173
def full_event_log(self) -> list[Component | CBlock]:
166174
"""Provides a list of all events stored in the context."""
@@ -210,6 +218,14 @@ def __init__(self):
210218
self._ctx = []
211219
self._log_ctx = []
212220

221+
def actions_for_available_tools(self) -> list[Component | CBlock] | None:
222+
"""Provides a list of actions to extract tools from for use with during generation, or None if that's not possible.
223+
224+
Can be used to make the available tools differ from the tools of all the actions in the context.
225+
In most cases, this will just be the same context as `render_for_generation`.
226+
"""
227+
return self.render_for_generation()
228+
213229
def last_output(self):
214230
"""The last output thunk of the context."""
215231
for c in self._ctx[::-1]:

test/stdlib_basics/test_base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pytest
12
from mellea.stdlib.base import Component, CBlock
23
from mellea.stdlib.base import LinearContext
34

@@ -33,3 +34,21 @@ def test_context():
3334
ctx.insert(CBlock("b"))
3435
ctx.insert(CBlock("c"))
3536
ctx.insert(CBlock("d"))
37+
38+
39+
def test_actions_for_available_tools():
40+
ctx = LinearContext(window_size=3)
41+
ctx.insert(CBlock("a"))
42+
ctx.insert(CBlock("b"))
43+
for_generation = ctx.render_for_generation()
44+
assert for_generation is not None
45+
46+
actions = ctx.actions_for_available_tools()
47+
assert actions is not None
48+
49+
assert len(for_generation) == len(actions)
50+
for i in range(len(actions)):
51+
assert actions[i] == for_generation[i]
52+
53+
if __name__ == "__main__":
54+
pytest.main([__file__])

0 commit comments

Comments
 (0)