1111import tempfile
1212import threading
1313import time
14- import typing
1514import uuid
1615import weakref
1716from dataclasses import dataclass , field
1817from pathlib import Path
1918from typing import TYPE_CHECKING , Annotated , Any , Literal
2019
2120from langchain_core .messages import ToolMessage
22- from langchain_core .tools .base import BaseTool , ToolException
21+ from langchain_core .tools .base import ToolException
2322from langgraph .channels .untracked_value import UntrackedValue
2423from pydantic import BaseModel , model_validator
24+ from pydantic .json_schema import SkipJsonSchema
2525from typing_extensions import NotRequired
2626
2727from langchain .agents .middleware ._execution import (
3838 ResolvedRedactionRule ,
3939)
4040from langchain .agents .middleware .types import AgentMiddleware , AgentState , PrivateStateAttr
41+ from langchain .tools import ToolRuntime , tool
4142
4243if TYPE_CHECKING :
4344 from collections .abc import Mapping , Sequence
4445
4546 from langgraph .runtime import Runtime
46- from langgraph .types import Command
4747
48- from langchain .agents .middleware .types import ToolCallRequest
4948
5049LOGGER = logging .getLogger (__name__ )
5150_DONE_MARKER_PREFIX = "__LC_SHELL_DONE__"
5958 "session remains stable. Outputs may be truncated when they become very large, and long "
6059 "running commands will be terminated once their configured timeout elapses."
6160)
61+ SHELL_TOOL_NAME = "shell"
6262
6363
6464def _cleanup_resources (
@@ -334,7 +334,17 @@ class _ShellToolInput(BaseModel):
334334 """Input schema for the persistent shell tool."""
335335
336336 command : str | None = None
337+ """The shell command to execute."""
338+
337339 restart : bool | None = None
340+ """Whether to restart the shell session."""
341+
342+ runtime : Annotated [Any , SkipJsonSchema ] = None
343+ """The runtime for the shell tool.
344+
345+ Included as a workaround at the moment bc args_schema doesn't work with
346+ injected ToolRuntime.
347+ """
338348
339349 @model_validator (mode = "after" )
340350 def validate_payload (self ) -> _ShellToolInput :
@@ -347,24 +357,6 @@ def validate_payload(self) -> _ShellToolInput:
347357 return self
348358
349359
350- class _PersistentShellTool (BaseTool ):
351- """Tool wrapper that relies on middleware interception for execution."""
352-
353- name : str = "shell"
354- description : str = DEFAULT_TOOL_DESCRIPTION
355- args_schema : type [BaseModel ] = _ShellToolInput
356-
357- def __init__ (self , middleware : ShellToolMiddleware , description : str | None = None ) -> None :
358- super ().__init__ ()
359- self ._middleware = middleware
360- if description is not None :
361- self .description = description
362-
363- def _run (self , ** _ : Any ) -> Any : # pragma: no cover - executed via middleware wrapper
364- msg = "Persistent shell tool execution should be intercepted via middleware wrappers."
365- raise RuntimeError (msg )
366-
367-
368360class ShellToolMiddleware (AgentMiddleware [ShellToolState , Any ]):
369361 """Middleware that registers a persistent shell tool for agents.
370362
@@ -393,6 +385,7 @@ def __init__(
393385 execution_policy : BaseExecutionPolicy | None = None ,
394386 redaction_rules : tuple [RedactionRule , ...] | list [RedactionRule ] | None = None ,
395387 tool_description : str | None = None ,
388+ tool_name : str = SHELL_TOOL_NAME ,
396389 shell_command : Sequence [str ] | str | None = None ,
397390 env : Mapping [str , Any ] | None = None ,
398391 ) -> None :
@@ -414,6 +407,9 @@ def __init__(
414407 returning it to the model.
415408 tool_description: Optional override for the registered shell tool
416409 description.
410+ tool_name: Name for the registered shell tool.
411+
412+ Defaults to `"shell"`.
417413 shell_command: Optional shell executable (string) or argument sequence used
418414 to launch the persistent session.
419415
@@ -425,6 +421,7 @@ def __init__(
425421 """
426422 super ().__init__ ()
427423 self ._workspace_root = Path (workspace_root ) if workspace_root else None
424+ self ._tool_name = tool_name
428425 self ._shell_command = self ._normalize_shell_command (shell_command )
429426 self ._environment = self ._normalize_env (env )
430427 if execution_policy is not None :
@@ -438,9 +435,25 @@ def __init__(
438435 self ._startup_commands = self ._normalize_commands (startup_commands )
439436 self ._shutdown_commands = self ._normalize_commands (shutdown_commands )
440437
438+ # Create a proper tool that executes directly (no interception needed)
441439 description = tool_description or DEFAULT_TOOL_DESCRIPTION
442- self ._tool = _PersistentShellTool (self , description = description )
443- self .tools = [self ._tool ]
440+
441+ @tool (self ._tool_name , args_schema = _ShellToolInput , description = description )
442+ def shell_tool (
443+ * ,
444+ runtime : ToolRuntime [None , ShellToolState ],
445+ command : str | None = None ,
446+ restart : bool = False ,
447+ ) -> ToolMessage | str :
448+ resources = self ._ensure_resources (runtime .state )
449+ return self ._run_shell_tool (
450+ resources ,
451+ {"command" : command , "restart" : restart },
452+ tool_call_id = runtime .tool_call_id ,
453+ )
454+
455+ self ._shell_tool = shell_tool
456+ self .tools = [self ._shell_tool ]
444457
445458 @staticmethod
446459 def _normalize_commands (
@@ -669,37 +682,6 @@ def _run_shell_tool(
669682 artifact = artifact ,
670683 )
671684
672- def wrap_tool_call (
673- self ,
674- request : ToolCallRequest ,
675- handler : typing .Callable [[ToolCallRequest ], ToolMessage | Command ],
676- ) -> ToolMessage | Command :
677- """Intercept local shell tool calls and execute them via the managed session."""
678- if isinstance (request .tool , _PersistentShellTool ):
679- resources = self ._ensure_resources (request .state )
680- return self ._run_shell_tool (
681- resources ,
682- request .tool_call ["args" ],
683- tool_call_id = request .tool_call .get ("id" ),
684- )
685- return handler (request )
686-
687- async def awrap_tool_call (
688- self ,
689- request : ToolCallRequest ,
690- handler : typing .Callable [[ToolCallRequest ], typing .Awaitable [ToolMessage | Command ]],
691- ) -> ToolMessage | Command :
692- """Async intercept local shell tool calls and execute them via the managed session."""
693- # The sync version already handles all the work, no need for async-specific logic
694- if isinstance (request .tool , _PersistentShellTool ):
695- resources = self ._ensure_resources (request .state )
696- return self ._run_shell_tool (
697- resources ,
698- request .tool_call ["args" ],
699- tool_call_id = request .tool_call .get ("id" ),
700- )
701- return await handler (request )
702-
703685 def _format_tool_message (
704686 self ,
705687 content : str ,
@@ -714,7 +696,7 @@ def _format_tool_message(
714696 return ToolMessage (
715697 content = content ,
716698 tool_call_id = tool_call_id ,
717- name = self ._tool . name ,
699+ name = self ._tool_name ,
718700 status = status ,
719701 artifact = artifact ,
720702 )
0 commit comments