Skip to content

Commit 22f9cea

Browse files
committed
new: tools can now be directly called from the prompt by interpolation
1 parent 5df321e commit 22f9cea

File tree

10 files changed

+171
-11
lines changed

10 files changed

+171
-11
lines changed

examples/changelog/agent.yml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@ agent: >
33
44
## Guidance
55
6-
- Start by using the get_new_commits tool to get a list of commits since the last release.
7-
- Then use the create_changelog tool to generate a changelog for the new commits.
8-
- Focus on the new features and fixes, group and summarize other changes into "Miscellaneous" or "Other".
6+
- Use the create_changelog tool to generate a changelog for the new commits.
7+
- Focus on the new features and major fixes.
8+
- Group and summarize other minor changes into "Miscellaneous" or "Other".
99
- Add relevant and catchy emojis but ONLY to important changes.
1010
11-
task: Generate a changelog for the current project.
11+
task: >
12+
Generate a changelog for these commits in the current project:
13+
14+
{{ get_new_commits() }}
1215
1316
defaults:
1417
output: CHANGELOG.md

nerve/models.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ class Configuration(BaseModel):
6464
# legacy field used to detect if the user is loading a legacy file
6565
system_prompt: str | None = Field(default=None, exclude=True)
6666

67-
# TODO: add a "precall" to call the tool when the agent is started
68-
6967
# TODO: document these fields.
7068

7169
# used for versioning the agents

nerve/runtime/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from loguru import logger
55

66
from nerve.models import Tool
7+
from nerve.runtime import state
78
from nerve.tools import compiler
89

910

@@ -62,6 +63,6 @@ def build(
6263
logger.debug(f"🧰 importing {len(funcs)} custom tools from functions")
6364
runtime.tools.extend(funcs)
6465

65-
logger.debug(f"tools: {runtime.tools}")
66+
state.set_tools({tool.__name__: tool for tool in runtime.tools})
6667

6768
return runtime

nerve/runtime/state.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import json
23
import os
34
import pathlib
5+
import threading
46
import typing as t
57

68
import click
@@ -28,6 +30,8 @@
2830
_defaults: dict[str, t.Any] = {}
2931
# similar to variables but used by tools
3032
_knowledge: dict[str, t.Any] = {}
33+
# tools
34+
_tools: dict[str, t.Callable[..., t.Any]] = {}
3135
# extra tools defined at runtime
3236
_extra_tools: dict[str, t.Callable[..., t.Any]] = {}
3337
# listeners for events
@@ -222,6 +226,15 @@ def as_dict() -> dict[str, t.Any]:
222226
}
223227

224228

229+
def set_tools(tools: dict[str, t.Callable[..., t.Any]]) -> None:
230+
"""Set all tools."""
231+
232+
global _tools
233+
_tools = tools
234+
235+
logger.debug(f"tools: {_tools}")
236+
237+
225238
def get_extra_tools() -> dict[str, t.Callable[..., t.Any]]:
226239
"""Get any extra tool registered at runtime."""
227240

@@ -329,9 +342,8 @@ def on_user_input_needed(input_name: str, prompt: str) -> str:
329342
)
330343

331344

332-
def interpolate(raw: str, extra: dict[str, t.Any] | None = None) -> str:
333-
"""Interpolate the current state into a string."""
334-
345+
def _create_jinja_env() -> jinja2.Environment:
346+
# we use this to catch undefined variables at runtime
335347
class OnUndefinedVariable(jinja2.Undefined):
336348
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
337349
super().__init__(*args, **kwargs)
@@ -352,5 +364,46 @@ def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
352364
def __str__(self) -> str:
353365
return self.value or "<UNDEFINED>"
354366

367+
env = jinja2.Environment(undefined=OnUndefinedVariable)
368+
369+
# allow prompts to call tools
370+
for name, tool_fn in _tools.items():
371+
# if the tool is async, wrap it in a sync function to make it callable from Jinja
372+
if asyncio.iscoroutinefunction(tool_fn):
373+
374+
def make_sync_wrapper(fn: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]:
375+
def sync_wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
376+
coro = fn(*args, **kwargs)
377+
# NOTE: we use a list to store the result because nonlocal variables (like scalars) inside a nested
378+
# function do not allow assignment unless explicitly declared nonlocal. However, mutable objects like
379+
# lists can be modified within the nested function without extra declarations.
380+
result_container = []
381+
382+
def we_need_an_async_loop_thread() -> None:
383+
loop = asyncio.new_event_loop()
384+
asyncio.set_event_loop(loop)
385+
result_container.append(loop.run_until_complete(coro))
386+
387+
thread = threading.Thread(target=we_need_an_async_loop_thread)
388+
thread.start()
389+
thread.join()
390+
391+
return result_container[0]
392+
393+
return sync_wrapper
394+
395+
env.globals[name] = make_sync_wrapper(tool_fn)
396+
else:
397+
env.globals[name] = tool_fn
398+
399+
return env
400+
401+
402+
def interpolate(raw: str, extra: dict[str, t.Any] | None = None) -> str:
403+
"""Interpolate the current state into a string."""
404+
405+
env = _create_jinja_env()
406+
template = env.from_string(raw)
355407
context = _variables | (extra or {})
356-
return jinja2.Environment(undefined=OnUndefinedVariable).from_string(raw).render(**context)
408+
409+
return template.render(**context)

nerve/runtime/state_test.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import os
2+
from unittest.mock import MagicMock, patch
3+
4+
import pytest
5+
6+
from nerve.models import Mode
7+
from nerve.runtime import state
8+
9+
10+
class TestInterpolate:
11+
def setup_method(self) -> None:
12+
# Reset state before each test
13+
state.reset()
14+
state._variables = {}
15+
state._defaults = {}
16+
state._mode = Mode.AUTOMATIC
17+
state._tools = {}
18+
19+
def test_basic_interpolation(self) -> None:
20+
# Setup
21+
state.update_variables({"name": "John", "age": 30})
22+
23+
# Test
24+
result = state.interpolate("Hello, {{ name }}! You are {{ age }} years old.")
25+
26+
# Verify
27+
assert result == "Hello, John! You are 30 years old."
28+
29+
def test_interpolation_with_extra_context(self) -> None:
30+
# Setup
31+
state.update_variables({"name": "John"})
32+
extra = {"location": "New York"}
33+
34+
# Test
35+
result = state.interpolate("{{ name }} is from {{ location }}.", extra)
36+
37+
# Verify
38+
assert result == "John is from New York."
39+
40+
def test_undefined_variable_in_automatic_mode(self) -> None:
41+
# Setup
42+
state._mode = Mode.AUTOMATIC
43+
state._defaults = {"city": "San Francisco"}
44+
45+
# Test
46+
result = state.interpolate("Welcome to {{ city }}!")
47+
48+
# Verify
49+
assert result == "Welcome to San Francisco!"
50+
51+
def test_undefined_variable_from_environment(self) -> None:
52+
# Setup
53+
with patch.dict(os.environ, {"TEST_VAR": "test_value"}):
54+
# Test
55+
result = state.interpolate("Environment value: {{ TEST_VAR }}")
56+
57+
# Verify
58+
assert result == "Environment value: test_value"
59+
# The variable should also be saved to state
60+
assert state._variables.get("TEST_VAR") == "test_value"
61+
62+
@patch("builtins.input", return_value="Paris")
63+
def test_undefined_variable_in_interactive_mode(self, mock_input: MagicMock) -> None:
64+
# Setup
65+
state._mode = Mode.INTERACTIVE
66+
67+
# Test
68+
result = state.interpolate("Welcome to {{ city }}!")
69+
70+
# Verify
71+
assert result == "Welcome to Paris!"
72+
assert mock_input.called
73+
assert state._variables.get("city") == "Paris"
74+
75+
def test_tool_call_in_template(self) -> None:
76+
# Setup
77+
mock_tool = MagicMock(return_value="tool result")
78+
state._tools = {"test_tool": mock_tool}
79+
80+
# Test
81+
result = state.interpolate("Tool output: {{ test_tool() }}")
82+
83+
# Verify
84+
assert result == "Tool output: tool result"
85+
assert mock_tool.called
86+
87+
@patch("nerve.runtime.state.on_user_input_needed")
88+
def test_missing_variable_raises_exception(self, mock_input_needed: MagicMock) -> None:
89+
# Setup
90+
state._mode = Mode.AUTOMATIC
91+
mock_input_needed.side_effect = Exception("Missing parameter")
92+
93+
# Test & Verify
94+
with pytest.raises(Exception): # noqa: B017
95+
state.interpolate("Missing {{ variable }}")
96+
97+
def test_complex_expression(self) -> None:
98+
# Setup
99+
state.update_variables({"numbers": [1, 2, 3, 4, 5]})
100+
101+
# Test
102+
result = state.interpolate("Sum: {{ numbers|sum }}, Average: {{ (numbers|sum) / numbers|length }}")
103+
104+
# Verify
105+
assert result == "Sum: 15, Average: 3.0"

0 commit comments

Comments
 (0)