Skip to content
15 changes: 15 additions & 0 deletions libs/langchain_v1/langchain/agents/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from langgraph.typing import ContextT # noqa: TC002
from typing_extensions import NotRequired, Required, TypedDict, TypeVar

from langchain.agents.middleware.shell_tool import ShellToolMiddleware
from langchain.agents.middleware.types import (
AgentMiddleware,
AgentState,
Expand Down Expand Up @@ -1309,6 +1310,20 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
end_destination=exit_node,
can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
)
last_mw = middleware_w_before_agent[-1]
if isinstance(last_mw, ShellToolMiddleware):
mw: ShellToolMiddleware = last_mw
graph.add_node("restore_shell", mw.restore_from_metadata)

graph.add_conditional_edges(
f"{mw.name}.before_agent",
lambda state: state.get("resume_from") == "interrupt",
{
True: "restore_shell",
False: loop_entry_node,
},
)
graph.add_edge("restore_shell", loop_entry_node)

# Add before_model middleware edges
if middleware_w_before_model:
Expand Down
33 changes: 33 additions & 0 deletions libs/langchain_v1/langchain/agents/middleware/_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

try: # pragma: no cover - optional dependency on POSIX platforms
import resource
Expand Down Expand Up @@ -84,6 +85,38 @@ def spawn(
) -> subprocess.Popen[str]:
"""Launch the persistent shell process."""

def to_dict(self) -> dict[str, Any]:
    """Serialise the shared policy settings into a plain dict.

    Only the class name (under `"type"`) and the five timeout/output-limit
    attributes read below are captured; any other attributes an instance
    carries are not included.

    Returns:
        A dict with the keys `type`, `command_timeout`, `startup_timeout`,
        `termination_timeout`, `max_output_lines` and `max_output_bytes`.
    """
    shared_attrs = (
        "command_timeout",
        "startup_timeout",
        "termination_timeout",
        "max_output_lines",
        "max_output_bytes",
    )
    # The class name acts as the discriminator when the dict is read back.
    payload: dict[str, Any] = {"type": type(self).__name__}
    payload.update((attr, getattr(self, attr)) for attr in shared_attrs)
    return payload

@classmethod
def from_dict(cls, data: dict[str, Any]) -> BaseExecutionPolicy:
    """Re-create a policy from the dict stored in checkpoint metadata.

    Args:
        data: Mapping with a `type` discriminator (the policy class name),
            the three timeout values, and optionally `max_output_lines` and
            `max_output_bytes` (passed as `None` when absent).

    Returns:
        A new instance of the policy class named by `data["type"]`.

    Raises:
        KeyError: If `type` or one of the timeout keys is missing.
        ValueError: If `type` does not name a supported policy class.
    """
    policy_classes: dict[str, type[BaseExecutionPolicy]] = {
        "HostExecutionPolicy": HostExecutionPolicy,
        "DockerExecutionPolicy": DockerExecutionPolicy,
        "CodexSandboxExecutionPolicy": CodexSandboxExecutionPolicy,
    }

    type_name = data["type"]
    policy_cls = policy_classes.get(type_name)
    if policy_cls is None:
        # Fail loudly: silently substituting the host policy for an unknown
        # (possibly sandboxed) policy type would weaken isolation on resume.
        msg = f"Unknown execution policy type: {type_name!r}"
        raise ValueError(msg)

    return policy_cls(
        command_timeout=data["command_timeout"],
        startup_timeout=data["startup_timeout"],
        termination_timeout=data["termination_timeout"],
        max_output_lines=data.get("max_output_lines"),
        max_output_bytes=data.get("max_output_bytes"),
    )


@dataclass
class HostExecutionPolicy(BaseExecutionPolicy):
Expand Down
143 changes: 127 additions & 16 deletions libs/langchain_v1/langchain/agents/middleware/shell_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import weakref
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Literal
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast

from langchain_core.messages import ToolMessage
from langchain_core.tools.base import BaseTool, ToolException
Expand Down Expand Up @@ -394,6 +394,7 @@ def __init__(
tool_description: str | None = None,
shell_command: Sequence[str] | str | None = None,
env: Mapping[str, Any] | None = None,
name: str = "shell",
) -> None:
"""Initialize the middleware.

Expand All @@ -412,8 +413,11 @@ def __init__(
env: Optional environment variables to supply to the shell session. Values are
coerced to strings before command execution. If omitted, the session inherits the
parent process environment.
name: Unique name for this middleware instance (default: "shell").
"""
super().__init__()
self._session: ShellSession | None = None
self._tempdir: tempfile.TemporaryDirectory | None = None
self._workspace_root = Path(workspace_root) if workspace_root else None
self._shell_command = self._normalize_shell_command(shell_command)
self._environment = self._normalize_env(env)
Expand All @@ -425,13 +429,26 @@ def __init__(
self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
rule.resolve() for rule in rules
)
self._name = name
self._startup_commands = self._normalize_commands(startup_commands)
self._shutdown_commands = self._normalize_commands(shutdown_commands)

description = tool_description or DEFAULT_TOOL_DESCRIPTION
self._tool = _PersistentShellTool(self, description=description)
self.tools = [self._tool]

@property
def name(self) -> str:
    """Return the identifier this middleware instance was configured with.

    The value is the `name` argument supplied at construction time, which
    defaults to `"shell"`. It is used as the prefix of graph node names such
    as `"<name>.before_agent"`.

    Returns:
        The configured name of this middleware.
    """
    return self._name

@staticmethod
def _normalize_commands(
commands: tuple[str, ...] | list[str] | str | None,
Expand Down Expand Up @@ -467,9 +484,26 @@ def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
return normalized

def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
    """Prepare the shell session before agent execution.

    Starts the persistent shell session if one is not already cached on this
    middleware instance, then records the session configuration through
    `_save_session_metadata` so it can be restored after a
    human-in-the-loop (HIL) interrupt.

    The session is kept on the middleware instance rather than returned in
    state, so nothing is written to the agent state here.

    Args:
        state: Current agent state, forwarded to `_save_session_metadata`.
        runtime: LangGraph runtime (unused, but required by the interface).

    Returns:
        Always `None`; the session is managed internally.
    """
    self._ensure_session()
    self._save_session_metadata(state)
    return None

async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
"""Async counterpart to `before_agent`."""
Expand All @@ -487,17 +521,85 @@ async def aafter_agent(self, state: ShellToolState, runtime: Runtime) -> None:
"""Async counterpart to `after_agent`."""
return self.after_agent(state, runtime)

def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
resources = state.get("shell_session_resources")
if resources is not None and not isinstance(resources, _SessionResources):
resources = None
if resources is None:
msg = (
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
"before invoking the shell tool."
def _ensure_resources(self, state: ShellToolState) -> _SessionResources: # noqa: ARG002
"""Always return live resources from middleware cache.

State is ignored — session is managed internally to support restart and HIL resume.
"""
return self._get_or_create_resources()

def _get_or_create_resources(self) -> _SessionResources:
    """Bundle the cached session, temp directory and policy into resources.

    Starts a session through `_ensure_session` first when none is cached.

    Returns:
        A new `_SessionResources` wrapping the cached session, the cached
        temporary directory handle (may be `None`) and the configured
        execution policy.
    """
    needs_start = self._session is None
    if needs_start:
        self._ensure_session()
    # NOTE(review): a fresh wrapper is built on every call. If
    # `_SessionResources` registers cleanup when it is garbage collected,
    # short-lived wrappers could tear down the cached session -- confirm.
    return _SessionResources(
        session=cast("ShellSession", self._session),
        tempdir=self._tempdir,
        policy=self._execution_policy,
    )

def _ensure_session(self) -> None:
if self._session is not None:
return
workspace = self._workspace_root
tempdir: tempfile.TemporaryDirectory | None = None
if workspace is None:
tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
workspace = Path(tempdir.name)
else:
workspace = Path(workspace)
workspace.mkdir(parents=True, exist_ok=True)

session = ShellSession(
workspace,
self._execution_policy,
self._shell_command,
self._environment or {},
)
session.start()
self._run_startup_commands(session)
self._session = session
self._tempdir = tempdir

def _save_session_metadata(self, state: ShellToolState) -> None:
checkpoint = getattr(state, "checkpoint", None)
if checkpoint is None:
return
metadata = checkpoint.metadata
if metadata is None:
return
if self._session is None:
return
metadata["shell_session"] = {
"workspace": str(self._session._workspace),
"command": self._shell_command,
"env": dict(self._session._environment),
"policy": self._execution_policy.to_dict(),
}

def restore_from_metadata(self, state: ShellToolState) -> None:
    """Restore the shell session from checkpoint metadata on HIL resume.

    Best-effort: when the state carries no checkpoint, no metadata, or no
    `"shell_session"` entry, nothing happens. When rebuilding or starting
    the session fails, the error is logged with its traceback and the
    currently cached session (if any) is left untouched. Startup commands
    are not re-run for a restored session.

    Args:
        state: Object whose `checkpoint.metadata["shell_session"]` entry
            holds the `workspace`, `command`, `env` and `policy` values.
    """
    checkpoint = getattr(state, "checkpoint", None)
    # Metadata may be missing or None; treat that like "nothing was saved".
    metadata = getattr(checkpoint, "metadata", None)
    data = metadata.get("shell_session") if metadata else None
    if not data:
        return

    try:
        policy = BaseExecutionPolicy.from_dict(data["policy"])
        workspace = Path(data["workspace"])
        session = ShellSession(
            workspace,
            policy,
            tuple(data["command"]),
            data["env"],
        )
        session.start()
    except Exception:
        LOGGER.exception("Failed to restore shell session")
        return

    previous_session = self._session
    previous_tempdir = self._tempdir
    self._session = session

    if previous_session is not None:
        # Stop the session being replaced so its shell process is not orphaned.
        try:
            previous_session.stop(self._execution_policy.termination_timeout)
        except Exception:
            LOGGER.exception("Failed to stop the superseded shell session")

    # Dropping the last reference to a TemporaryDirectory removes it from
    # disk, so keep the handle when the restored session lives inside it.
    if previous_tempdir is None or Path(previous_tempdir.name) != workspace:
        self._tempdir = None

    LOGGER.info("Restored shell session from checkpoint")

def _create_resources(self) -> _SessionResources:
workspace = self._workspace_root
Expand Down Expand Up @@ -576,8 +678,17 @@ def _run_shell_tool(
if payload.get("restart"):
LOGGER.info("Restarting shell session on request.")
try:
session.restart()
self._run_startup_commands(session)
session.stop(self._execution_policy.termination_timeout)
new_session = ShellSession(
session._workspace,
self._execution_policy,
self._shell_command,
self._environment or {},
)
new_session.start()
self._run_startup_commands(new_session)
self._session = new_session
resources.session = new_session
except BaseException as err:
LOGGER.exception("Restarting shell session failed; session remains unavailable.")
msg = "Failed to restart shell session."
Expand Down