diff --git a/libs/langchain_v1/langchain/agents/factory.py b/libs/langchain_v1/langchain/agents/factory.py index e22f851bd8438..98f0cb8474603 100644 --- a/libs/langchain_v1/langchain/agents/factory.py +++ b/libs/langchain_v1/langchain/agents/factory.py @@ -17,6 +17,7 @@ from langgraph.typing import ContextT # noqa: TC002 from typing_extensions import NotRequired, Required, TypedDict, TypeVar +from langchain.agents.middleware.shell_tool import ShellToolMiddleware from langchain.agents.middleware.types import ( AgentMiddleware, AgentState, @@ -1309,6 +1310,20 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str end_destination=exit_node, can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"), ) + last_mw = middleware_w_before_agent[-1] + if isinstance(last_mw, ShellToolMiddleware): + mw: ShellToolMiddleware = last_mw + graph.add_node("restore_shell", mw.restore_from_metadata) + + graph.add_conditional_edges( + f"{mw.name}.before_agent", + lambda state: state.get("resume_from") == "interrupt", + { + True: "restore_shell", + False: loop_entry_node, + }, + ) + graph.add_edge("restore_shell", loop_entry_node) # Add before_model middleware edges if middleware_w_before_model: diff --git a/libs/langchain_v1/langchain/agents/middleware/_execution.py b/libs/langchain_v1/langchain/agents/middleware/_execution.py index f14235bf62785..8d97743ca4205 100644 --- a/libs/langchain_v1/langchain/agents/middleware/_execution.py +++ b/libs/langchain_v1/langchain/agents/middleware/_execution.py @@ -12,6 +12,7 @@ from collections.abc import Mapping, Sequence from dataclasses import dataclass, field from pathlib import Path +from typing import Any try: # pragma: no cover - optional dependency on POSIX platforms import resource @@ -84,6 +85,38 @@ def spawn( ) -> subprocess.Popen[str]: """Launch the persistent shell process.""" + def to_dict(self) -> dict[str, Any]: + """Convert the policy to a JSON-serialisable dict.""" + return { + "type": self.__class__.__name__, + "command_timeout": self.command_timeout, + "startup_timeout": self.startup_timeout, + "termination_timeout": self.termination_timeout, + "max_output_lines": self.max_output_lines, + "max_output_bytes": self.max_output_bytes, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> BaseExecutionPolicy: + """Re-create a policy from the dict stored in checkpoint metadata.""" + typ = data["type"] + kwargs = { + "command_timeout": data["command_timeout"], + "startup_timeout": data["startup_timeout"], + "termination_timeout": data["termination_timeout"], + "max_output_lines": data.get("max_output_lines"), + "max_output_bytes": data.get("max_output_bytes"), + } + + if typ == "HostExecutionPolicy": + return HostExecutionPolicy(**kwargs) + if typ == "DockerExecutionPolicy": + return DockerExecutionPolicy(**kwargs) + if typ == "CodexSandboxExecutionPolicy": + return CodexSandboxExecutionPolicy(**kwargs) + + return HostExecutionPolicy(**kwargs) + @dataclass class HostExecutionPolicy(BaseExecutionPolicy): diff --git a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py index 563ef2a2c39f7..d796a6e17de8e 100644 --- a/libs/langchain_v1/langchain/agents/middleware/shell_tool.py +++ b/libs/langchain_v1/langchain/agents/middleware/shell_tool.py @@ -16,7 +16,7 @@ import weakref from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Annotated, Any, Literal +from typing import TYPE_CHECKING, Annotated, Any, Literal, cast from langchain_core.messages import ToolMessage from langchain_core.tools.base import BaseTool, ToolException @@ -394,6 +394,7 @@ def __init__( tool_description: str | None = None, shell_command: Sequence[str] | str | None = None, env: Mapping[str, Any] | None = None, + name: str = "shell", ) -> None: """Initialize the middleware. @@ -412,8 +413,11 @@ def __init__( env: Optional environment variables to supply to the shell session. Values are coerced to strings before command execution. If omitted, the session inherits the parent process environment. + name: Unique name for this middleware instance (default: "shell"). """ super().__init__() + self._session: ShellSession | None = None + self._tempdir: tempfile.TemporaryDirectory | None = None self._workspace_root = Path(workspace_root) if workspace_root else None self._shell_command = self._normalize_shell_command(shell_command) self._environment = self._normalize_env(env) @@ -425,6 +429,7 @@ def __init__( self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple( rule.resolve() for rule in rules ) + self._name = name self._startup_commands = self._normalize_commands(startup_commands) self._shutdown_commands = self._normalize_commands(shutdown_commands) @@ -432,6 +437,18 @@ def __init__( self._tool = _PersistentShellTool(self, description=description) self.tools = [self._tool] + @property + def name(self) -> str: + """Unique name of the middleware instance. + + Used by LangGraph to identify nodes in the execution graph + (e.g., `shell.before_agent`). Defaults to `"shell"`. + + Returns: + The configured name of this middleware. + """ + return self._name + @staticmethod def _normalize_commands( commands: tuple[str, ...] | list[str] | str | None, @@ -467,9 +484,26 @@ def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None: return normalized def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: # noqa: ARG002 - """Start the shell session and run startup commands.""" - resources = self._create_resources() - return {"shell_session_resources": resources} + """Prepare the shell session before agent execution. + + Starts the persistent shell session if not already running and saves + its configuration to checkpoint metadata for reliable restoration + after human-in-the-loop (HIL) interrupts. + + This method is called automatically by LangGraph before the agent + runs. The session is cached in the middleware instance — not returned + in state — to avoid serialization issues. + + Args: + state: Current agent state, used to access checkpoint metadata. + runtime: LangGraph runtime (unused, but required by interface). + + Returns: + None: Session is managed internally. + """ + self._ensure_session() + self._save_session_metadata(state) + return None async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None: """Async counterpart to `before_agent`.""" @@ -487,17 +521,85 @@ async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None: """Async counterpart to `after_agent`.""" return self.after_agent(state, runtime) - def _ensure_resources(self, state: ShellToolState) -> _SessionResources: - resources = state.get("shell_session_resources") - if resources is not None and not isinstance(resources, _SessionResources): - resources = None - if resources is None: - msg = ( - "Shell session resources are unavailable. Ensure `before_agent` ran successfully " - "before invoking the shell tool." + def _ensure_resources(self, state: ShellToolState) -> _SessionResources: # noqa: ARG002 + """Always return live resources from middleware cache. + + State is ignored — session is managed internally to support restart and HIL resume. + """ + return self._get_or_create_resources() + + def _get_or_create_resources(self) -> _SessionResources: + if self._session is None: + self._ensure_session() + session = cast("ShellSession", self._session) + return _SessionResources( + session=session, + tempdir=self._tempdir, + policy=self._execution_policy, + ) + + def _ensure_session(self) -> None: + if self._session is not None: + return + workspace = self._workspace_root + tempdir: tempfile.TemporaryDirectory | None = None + if workspace is None: + tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX) + workspace = Path(tempdir.name) + else: + workspace = Path(workspace) + workspace.mkdir(parents=True, exist_ok=True) + + session = ShellSession( + workspace, + self._execution_policy, + self._shell_command, + self._environment or {}, + ) + session.start() + self._run_startup_commands(session) + self._session = session + self._tempdir = tempdir + + def _save_session_metadata(self, state: ShellToolState) -> None: + checkpoint = getattr(state, "checkpoint", None) + if checkpoint is None: + return + metadata = checkpoint.metadata + if metadata is None: + return + if self._session is None: + return + metadata["shell_session"] = { + "workspace": str(self._session._workspace), + "command": self._shell_command, + "env": dict(self._session._environment), + "policy": self._execution_policy.to_dict(), + } + + def restore_from_metadata(self, state: ShellToolState) -> None: + """Restore shell session from checkpoint metadata on HIL resume.""" + checkpoint = getattr(state, "checkpoint", None) + if checkpoint is None: + return + data = checkpoint.metadata.get("shell_session") + if not data: + return + try: + policy = BaseExecutionPolicy.from_dict(data["policy"]) + workspace = Path(data["workspace"]) + session = ShellSession( + workspace, + policy, + tuple(data["command"]), + data["env"], ) - raise ToolException(msg) - return resources + session.start() + self._session = session + self._tempdir = None + LOGGER.info("Restored shell session from checkpoint") + except Exception: + LOGGER.exception("Failed to restore shell session") # ← logs traceback def _create_resources(self) -> _SessionResources: workspace = self._workspace_root @@ -576,8 +678,17 @@ def _run_shell_tool( if payload.get("restart"): LOGGER.info("Restarting shell session on request.") try: - session.restart() - self._run_startup_commands(session) + session.stop(self._execution_policy.termination_timeout) + new_session = ShellSession( + session._workspace, + self._execution_policy, + self._shell_command, + self._environment or {}, + ) + new_session.start() + self._run_startup_commands(new_session) + self._session = new_session + resources.session = new_session except BaseException as err: LOGGER.exception("Restarting shell session failed; session remains unavailable.") msg = "Failed to restart shell session."