Skip to content

Commit 423485d

Browse files
committed
new: added support for reasoning models (closes #61)
1 parent a7e5cf9 commit 423485d

File tree

9 files changed

+634
-600
lines changed

9 files changed

+634
-600
lines changed

docs/concepts.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ using:
2323
- shell
2424
```
2525
26+
#### Enable Reasoning
27+
28+
For models supporting reasoning, you can add a `reasoning` field to enable it, with a value that can either be `low`, `medium` or `high`.
29+
2630
### Prompt Interpolation and Variables
2731
Nerve supports [Jinja2](https://jinja.palletsprojects.com/) templating for dynamic prompt construction. You can:
2832
- Inject command line arguments (`{{ target }}`)

docs/index.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ Run it with:
7070
nerve run new-agent --url cnn.com
7171
```
7272

73+
#### Enable Reasoning
74+
75+
For models supporting reasoning, you can add a `reasoning` field to enable it, with a value that can either be `low`, `medium` or `high`.
76+
7377
### Prompting & Variables
7478
Supports [Jinja2](https://jinja.palletsprojects.com/) templating. You can:
7579
- Include files: `{% include 'filename.md' %}`

nerve/generation/__init__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import typing as t
55
from abc import ABC, abstractmethod
66

7+
from litellm import ConfigDict
8+
from pydantic import BaseModel
79
from loguru import logger
810

911
from nerve.models import Usage
1012
from nerve.runtime import state
1113
from nerve.tools.protocol import get_tool_response, get_tool_schema
1214

1315

14-
class WindowStrategy(ABC):
16+
class WindowStrategy(ABC, BaseModel):
1517
@abstractmethod
1618
async def get_window(self, history: list[dict[str, t.Any]]) -> list[dict[str, t.Any]]:
1719
pass
@@ -20,23 +22,28 @@ async def get_window(self, history: list[dict[str, t.Any]]) -> list[dict[str, t.
2022
def __str__(self) -> str:
2123
pass
2224

25+
class GenerationConfig(BaseModel):
26+
generator_id: str
27+
reasoning_effort: str | None = None
28+
window_strategy: WindowStrategy
29+
tools: list[t.Callable[..., t.Any]] | None = None
30+
2331

2432
class Engine(ABC):
2533
def __init__(
2634
self,
27-
generator_id: str,
28-
window_strategy: WindowStrategy,
29-
tools: list[t.Callable[..., t.Any]] | None = None,
35+
config: GenerationConfig,
3036
):
31-
self.generator_id = generator_id
37+
self.config = config
38+
self.generator_id = config.generator_id
3239
self.generator_params: dict[str, t.Any] = {}
3340

3441
self._parse_generator_params()
3542

3643
self.history: list[dict[str, t.Any]] = []
37-
self.window_strategy = window_strategy
44+
self.window_strategy = config.window_strategy
3845

39-
self.tools = {fn.__name__: fn for fn in (tools or [])}
46+
self.tools = {fn.__name__: fn for fn in (config.tools or [])}
4047
self.tools_schemas = []
4148
for tool_name, tool_fn in self.tools.items():
4249
if not tool_fn.__doc__:

nerve/generation/litellm.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import litellm
88
from loguru import logger
99

10-
from nerve.generation import Engine, WindowStrategy
10+
from nerve.generation import Engine, GenerationConfig, WindowStrategy
1111
from nerve.generation.conversation import SlidingWindowStrategy
1212
from nerve.generation.ollama import OllamaGlue
1313
from nerve.models import Usage
@@ -34,11 +34,9 @@ def _convert_to_serializable(obj: t.Any) -> t.Any:
3434
class LiteLLMEngine(Engine):
3535
def __init__(
3636
self,
37-
generator_id: str,
38-
window_strategy: WindowStrategy,
39-
tools: list[t.Callable[..., t.Any]] | None = None,
37+
config: GenerationConfig,
4038
):
41-
super().__init__(generator_id, window_strategy, tools)
39+
super().__init__(config)
4240

4341
# until this is not fixed, ollama needs special treatment: https://github.com/BerriAI/litellm/issues/6353
4442
self.is_ollama = "ollama" in self.generator_id
@@ -60,13 +58,20 @@ async def _litellm_generate(
6058
logger.debug(f"litellm.conversation: {json.dumps(conversation, indent=2)}")
6159

6260
# litellm.set_verbose = True
61+
62+
#if self.config.reasoning_effort:
63+
# # if the model does not support reasoning, avoid raising litellm.UnsupportedParamsError
64+
# # by dropping the unsupported parameter
65+
# litellm.drop_params = True
66+
6367
response = litellm.completion(
6468
model=self.generator_id,
6569
messages=conversation,
6670
tools=tools_schema,
6771
tool_choice="auto" if tools_schema else None,
6872
verbose=False,
6973
api_base=self.api_base,
74+
reasoning_effort=self.config.reasoning_effort,
7075
**self.generator_params,
7176
)
7277

nerve/models.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class Argument(BaseModel):
7474
tool: str | None = None
7575

7676

77-
def _check_required_version(required: str | None) -> str | None:
77+
def _validate_required_version(required: str | None) -> str | None:
7878
if required:
7979
from packaging.requirements import Requirement
8080

@@ -98,6 +98,12 @@ def _check_required_version(required: str | None) -> str | None:
9898
return required
9999

100100

101+
def _validate_reasoning_effort(effort: str | None) -> str | None:
102+
if effort not in [None, "low", "medium", "high"]:
103+
raise ValueError(f"invalid reasoning effort: {effort}")
104+
return effort
105+
106+
101107
class Configuration(BaseModel):
102108
"""
103109
Configuration for an agent determining its "identity", task and capabilities.
@@ -127,13 +133,14 @@ class Limits(BaseModel):
127133

128134
# legacy field used to detect if the user is loading a legacy file
129135
system_prompt: str | None = Field(default=None, exclude=True)
130-
131136
# optional generator
132137
generator: str | None = None
138+
# thinking effort for models supporting reasoning
139+
reasoning: t.Annotated[str | None, AfterValidator(_validate_reasoning_effort)] = None
133140
# optional agent description
134141
description: str = ""
135142
# optional nerve version requirement
136-
requires: t.Annotated[str | None, AfterValidator(_check_required_version)] = None
143+
requires: t.Annotated[str | None, AfterValidator(_validate_required_version)] = None
137144
# used for versioning the agents
138145
version: str = "1.0.0"
139146
# the system prompt, the agent identity

nerve/runtime/agent.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from loguru import logger
55

66
import nerve.runtime.state as state
7-
from nerve.generation import Engine, WindowStrategy
7+
from nerve.generation import Engine, GenerationConfig, WindowStrategy
88
from nerve.generation.conversation import FullHistoryStrategy
99
from nerve.generation.litellm import LiteLLMEngine
1010
from nerve.models import Configuration, Tool, Usage
@@ -77,10 +77,17 @@ async def create(
7777
configuration=configuration,
7878
)
7979

80+
engine_config = GenerationConfig(
81+
generator_id=configuration.generator,
82+
window_strategy=window_strategy,
83+
tools=runtime.tools,
84+
reasoning_effort=configuration.reasoning,
85+
)
86+
8087
return cls(
8188
runtime=runtime,
8289
configuration=configuration,
83-
generation_engine=LiteLLMEngine(configuration.generator, window_strategy, runtime.tools),
90+
generation_engine=LiteLLMEngine(engine_config),
8491
conv_window_strategy=window_strategy,
8592
)
8693

nerve/runtime/logging.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,14 @@ def log_event_to_terminal(event: Event) -> None:
9595
data["agent"] = DictWrapper(data["agent"])
9696

9797
generator = data["agent"].runtime.generator
98+
reasoning = data["agent"].configuration.reasoning
9899
name = data["agent"].runtime.name
99100
version = data["agent"].configuration.version
100101
tools = len(data["agent"].runtime.tools)
102+
103+
if reasoning:
104+
generator = f"{generator} (reasoning={reasoning})"
105+
101106
logger.info(f"🤖 {generator} | {name} v{version} with {tools} tools")
102107

103108
elif event.name == "before_tool_called":

0 commit comments

Comments
 (0)