Skip to content

Commit 1bc8802

Browse files
fix(anthropic): execute bash + file tools via tool node (#33960)
* Use `override` instead of directly patching attributes on `ModelRequest`.
* Rely on `ToolNode` to execute the tools related to this middleware: `wrap_model_call` injects the relevant Claude tool specs, and the tool node forwards the calls to the corresponding LangChain tool implementations.
* Make the same change for the native shell tool middleware.
* Allow the shell tool middleware to specify a name for the shell tool (which results in a negative diff for the Claude bash middleware).

Long term, I think the solution might be to attach metadata to a tool that maps the provider spec to a LangChain implementation, an approach we could also take some lessons from on the MCP front.
1 parent d294235 commit 1bc8802

File tree

4 files changed

+307
-407
lines changed

4 files changed

+307
-407
lines changed

libs/langchain_v1/langchain/agents/middleware/shell_tool.py

Lines changed: 38 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,17 @@
1111
import tempfile
1212
import threading
1313
import time
14-
import typing
1514
import uuid
1615
import weakref
1716
from dataclasses import dataclass, field
1817
from pathlib import Path
1918
from typing import TYPE_CHECKING, Annotated, Any, Literal
2019

2120
from langchain_core.messages import ToolMessage
22-
from langchain_core.tools.base import BaseTool, ToolException
21+
from langchain_core.tools.base import ToolException
2322
from langgraph.channels.untracked_value import UntrackedValue
2423
from pydantic import BaseModel, model_validator
24+
from pydantic.json_schema import SkipJsonSchema
2525
from typing_extensions import NotRequired
2626

2727
from langchain.agents.middleware._execution import (
@@ -38,14 +38,13 @@
3838
ResolvedRedactionRule,
3939
)
4040
from langchain.agents.middleware.types import AgentMiddleware, AgentState, PrivateStateAttr
41+
from langchain.tools import ToolRuntime, tool
4142

4243
if TYPE_CHECKING:
4344
from collections.abc import Mapping, Sequence
4445

4546
from langgraph.runtime import Runtime
46-
from langgraph.types import Command
4747

48-
from langchain.agents.middleware.types import ToolCallRequest
4948

5049
LOGGER = logging.getLogger(__name__)
5150
_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
@@ -59,6 +58,7 @@
5958
"session remains stable. Outputs may be truncated when they become very large, and long "
6059
"running commands will be terminated once their configured timeout elapses."
6160
)
61+
SHELL_TOOL_NAME = "shell"
6262

6363

6464
def _cleanup_resources(
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
334334
"""Input schema for the persistent shell tool."""
335335

336336
command: str | None = None
337+
"""The shell command to execute."""
338+
337339
restart: bool | None = None
340+
"""Whether to restart the shell session."""
341+
342+
runtime: Annotated[Any, SkipJsonSchema] = None
343+
"""The runtime for the shell tool.
344+
345+
Included as a workaround at the moment bc args_schema doesn't work with
346+
injected ToolRuntime.
347+
"""
338348

339349
@model_validator(mode="after")
340350
def validate_payload(self) -> _ShellToolInput:
@@ -347,24 +357,6 @@ def validate_payload(self) -> _ShellToolInput:
347357
return self
348358

349359

350-
class _PersistentShellTool(BaseTool):
351-
"""Tool wrapper that relies on middleware interception for execution."""
352-
353-
name: str = "shell"
354-
description: str = DEFAULT_TOOL_DESCRIPTION
355-
args_schema: type[BaseModel] = _ShellToolInput
356-
357-
def __init__(self, middleware: ShellToolMiddleware, description: str | None = None) -> None:
358-
super().__init__()
359-
self._middleware = middleware
360-
if description is not None:
361-
self.description = description
362-
363-
def _run(self, **_: Any) -> Any: # pragma: no cover - executed via middleware wrapper
364-
msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
365-
raise RuntimeError(msg)
366-
367-
368360
class ShellToolMiddleware(AgentMiddleware[ShellToolState, Any]):
369361
"""Middleware that registers a persistent shell tool for agents.
370362
@@ -393,6 +385,7 @@ def __init__(
393385
execution_policy: BaseExecutionPolicy | None = None,
394386
redaction_rules: tuple[RedactionRule, ...] | list[RedactionRule] | None = None,
395387
tool_description: str | None = None,
388+
tool_name: str = SHELL_TOOL_NAME,
396389
shell_command: Sequence[str] | str | None = None,
397390
env: Mapping[str, Any] | None = None,
398391
) -> None:
@@ -414,6 +407,9 @@ def __init__(
414407
returning it to the model.
415408
tool_description: Optional override for the registered shell tool
416409
description.
410+
tool_name: Name for the registered shell tool.
411+
412+
Defaults to `"shell"`.
417413
shell_command: Optional shell executable (string) or argument sequence used
418414
to launch the persistent session.
419415
@@ -425,6 +421,7 @@ def __init__(
425421
"""
426422
super().__init__()
427423
self._workspace_root = Path(workspace_root) if workspace_root else None
424+
self._tool_name = tool_name
428425
self._shell_command = self._normalize_shell_command(shell_command)
429426
self._environment = self._normalize_env(env)
430427
if execution_policy is not None:
@@ -438,9 +435,25 @@ def __init__(
438435
self._startup_commands = self._normalize_commands(startup_commands)
439436
self._shutdown_commands = self._normalize_commands(shutdown_commands)
440437

438+
# Create a proper tool that executes directly (no interception needed)
441439
description = tool_description or DEFAULT_TOOL_DESCRIPTION
442-
self._tool = _PersistentShellTool(self, description=description)
443-
self.tools = [self._tool]
440+
441+
@tool(self._tool_name, args_schema=_ShellToolInput, description=description)
442+
def shell_tool(
443+
*,
444+
runtime: ToolRuntime[None, ShellToolState],
445+
command: str | None = None,
446+
restart: bool = False,
447+
) -> ToolMessage | str:
448+
resources = self._ensure_resources(runtime.state)
449+
return self._run_shell_tool(
450+
resources,
451+
{"command": command, "restart": restart},
452+
tool_call_id=runtime.tool_call_id,
453+
)
454+
455+
self._shell_tool = shell_tool
456+
self.tools = [self._shell_tool]
444457

445458
@staticmethod
446459
def _normalize_commands(
@@ -669,37 +682,6 @@ def _run_shell_tool(
669682
artifact=artifact,
670683
)
671684

672-
def wrap_tool_call(
673-
self,
674-
request: ToolCallRequest,
675-
handler: typing.Callable[[ToolCallRequest], ToolMessage | Command],
676-
) -> ToolMessage | Command:
677-
"""Intercept local shell tool calls and execute them via the managed session."""
678-
if isinstance(request.tool, _PersistentShellTool):
679-
resources = self._ensure_resources(request.state)
680-
return self._run_shell_tool(
681-
resources,
682-
request.tool_call["args"],
683-
tool_call_id=request.tool_call.get("id"),
684-
)
685-
return handler(request)
686-
687-
async def awrap_tool_call(
688-
self,
689-
request: ToolCallRequest,
690-
handler: typing.Callable[[ToolCallRequest], typing.Awaitable[ToolMessage | Command]],
691-
) -> ToolMessage | Command:
692-
"""Async intercept local shell tool calls and execute them via the managed session."""
693-
# The sync version already handles all the work, no need for async-specific logic
694-
if isinstance(request.tool, _PersistentShellTool):
695-
resources = self._ensure_resources(request.state)
696-
return self._run_shell_tool(
697-
resources,
698-
request.tool_call["args"],
699-
tool_call_id=request.tool_call.get("id"),
700-
)
701-
return await handler(request)
702-
703685
def _format_tool_message(
704686
self,
705687
content: str,
@@ -714,7 +696,7 @@ def _format_tool_message(
714696
return ToolMessage(
715697
content=content,
716698
tool_call_id=tool_call_id,
717-
name=self._tool.name,
699+
name=self._tool_name,
718700
status=status,
719701
artifact=artifact,
720702
)

0 commit comments

Comments
 (0)