Skip to content

Commit d79a2f1

Browse files
committed
fix(agents): ShellToolMiddleware session lost on HIL resume
1 parent ef85161 commit d79a2f1

File tree

3 files changed

+157
-9
lines changed

3 files changed

+157
-9
lines changed

libs/langchain_v1/langchain/agents/factory.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from langgraph.typing import ContextT # noqa: TC002
1717
from typing_extensions import NotRequired, Required, TypedDict, TypeVar
1818

19+
from langchain.agents.middleware.shell_tool import ShellToolMiddleware
1920
from langchain.agents.middleware.types import (
2021
AgentMiddleware,
2122
AgentState,
@@ -1300,6 +1301,19 @@ async def amodel_node(state: AgentState, runtime: Runtime[ContextT]) -> dict[str
13001301
end_destination=exit_node,
13011302
can_jump_to=_get_can_jump_to(middleware_w_before_agent[-1], "before_agent"),
13021303
)
1304+
last_mw = middleware_w_before_agent[-1]
1305+
if isinstance(last_mw, ShellToolMiddleware):
1306+
graph.add_node("restore_shell", last_mw.restore_from_metadata) # type: ignore[attr-defined]
1307+
1308+
graph.add_conditional_edges(
1309+
f"{last_mw.name}.before_agent", # type: ignore[attr-defined]
1310+
lambda state: state.get("resume_from") == "interrupt",
1311+
{
1312+
True: "restore_shell",
1313+
False: loop_entry_node,
1314+
},
1315+
)
1316+
graph.add_edge("restore_shell", loop_entry_node)
13031317

13041318
# Add before_model middleware edges
13051319
if middleware_w_before_model:

libs/langchain_v1/langchain/agents/middleware/_execution.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from collections.abc import Mapping, Sequence
1313
from dataclasses import dataclass, field
1414
from pathlib import Path
15+
from typing import Any
1516

1617
try: # pragma: no cover - optional dependency on POSIX platforms
1718
import resource
@@ -84,6 +85,38 @@ def spawn(
8485
) -> subprocess.Popen[str]:
8586
"""Launch the persistent shell process."""
8687

88+
def to_dict(self) -> dict[str, Any]:
    """Serialise the shared policy limits into a plain dict.

    Returns:
        A dict holding the concrete class name under ``"type"`` followed by
        the timeout and output-limit attributes read from this instance.
    """
    shared_fields = (
        "command_timeout",
        "startup_timeout",
        "termination_timeout",
        "max_output_lines",
        "max_output_bytes",
    )
    # The class name is stored first so `from_dict` can pick the concrete policy.
    payload: dict[str, Any] = {"type": self.__class__.__name__}
    for field_name in shared_fields:
        payload[field_name] = getattr(self, field_name)
    return payload
98+
99+
@classmethod
def from_dict(cls, data: dict[str, Any]) -> BaseExecutionPolicy:
    """Re-create a policy from the dict stored in checkpoint metadata.

    Args:
        data: Mapping in the shape produced by `to_dict`: a ``"type"`` entry
            naming the concrete policy class plus the timeout and
            output-limit settings.

    Returns:
        A new policy instance of the class named by ``data["type"]``.

    Raises:
        KeyError: If ``"type"`` or one of the three timeout keys is missing.
        ValueError: If ``data["type"]`` does not name a known policy class.
    """
    typ = data["type"]
    # NOTE(review): only the limits written by `to_dict` are restored; any
    # subclass-specific options fall back to that subclass's defaults.
    kwargs = {
        "command_timeout": data["command_timeout"],
        "startup_timeout": data["startup_timeout"],
        "termination_timeout": data["termination_timeout"],
        "max_output_lines": data.get("max_output_lines"),
        "max_output_bytes": data.get("max_output_bytes"),
    }

    if typ == "HostExecutionPolicy":
        return HostExecutionPolicy(**kwargs)
    if typ == "DockerExecutionPolicy":
        return DockerExecutionPolicy(**kwargs)
    if typ == "CodexSandboxExecutionPolicy":
        return CodexSandboxExecutionPolicy(**kwargs)

    # Fail loudly instead of defaulting to host execution: silently swapping
    # an unrecognised (possibly sandboxed) policy for an unsandboxed one would
    # weaken isolation without anyone noticing.
    msg = f"Unknown execution policy type in checkpoint metadata: {typ!r}"
    raise ValueError(msg)
119+
87120

88121
@dataclass
89122
class HostExecutionPolicy(BaseExecutionPolicy):

libs/langchain_v1/langchain/agents/middleware/shell_tool.py

Lines changed: 110 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import weakref
1717
from dataclasses import dataclass, field
1818
from pathlib import Path
19-
from typing import TYPE_CHECKING, Annotated, Any, Literal
19+
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
2020

2121
from langchain_core.messages import ToolMessage
2222
from langchain_core.tools.base import BaseTool, ToolException
@@ -394,6 +394,7 @@ def __init__(
394394
tool_description: str | None = None,
395395
shell_command: Sequence[str] | str | None = None,
396396
env: Mapping[str, Any] | None = None,
397+
name: str = "shell",
397398
) -> None:
398399
"""Initialize the middleware.
399400
@@ -412,8 +413,11 @@ def __init__(
412413
env: Optional environment variables to supply to the shell session. Values are
413414
coerced to strings before command execution. If omitted, the session inherits the
414415
parent process environment.
416+
name: Unique name for this middleware instance (default: "shell").
415417
"""
416418
super().__init__()
419+
self._session: ShellSession | None = None
420+
self._tempdir: tempfile.TemporaryDirectory | None = None
417421
self._workspace_root = Path(workspace_root) if workspace_root else None
418422
self._shell_command = self._normalize_shell_command(shell_command)
419423
self._environment = self._normalize_env(env)
@@ -425,13 +429,26 @@ def __init__(
425429
self._redaction_rules: tuple[ResolvedRedactionRule, ...] = tuple(
426430
rule.resolve() for rule in rules
427431
)
432+
self._name = name
428433
self._startup_commands = self._normalize_commands(startup_commands)
429434
self._shutdown_commands = self._normalize_commands(shutdown_commands)
430435

431436
description = tool_description or DEFAULT_TOOL_DESCRIPTION
432437
self._tool = _PersistentShellTool(self, description=description)
433438
self.tools = [self._tool]
434439

440+
@property
def name(self) -> str:
    """Return the name assigned to this middleware instance.

    The value is the ``name`` argument given at construction time, which
    defaults to ``"shell"``. The agent factory uses it as the prefix of the
    middleware's graph node names (for example ``shell.before_agent``).

    Returns:
        The configured name of this middleware.
    """
    configured_name = self._name
    return configured_name
451+
435452
@staticmethod
436453
def _normalize_commands(
437454
commands: tuple[str, ...] | list[str] | str | None,
@@ -467,9 +484,26 @@ def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
467484
return normalized
468485

469486
def before_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
    """Prepare the shell session before agent execution.

    Starts the persistent shell session if it is not already running, then
    asks `_save_session_metadata` to record the session's configuration so it
    can be restored after a human-in-the-loop (HIL) interrupt.

    The session is cached on the middleware instance rather than returned
    in state, so no state update is produced.

    Args:
        state: Current agent state, passed through to
            `_save_session_metadata`.
        runtime: LangGraph runtime (unused, but required by the interface).

    Returns:
        Always `None`; the session is managed internally.
    """
    self._ensure_session()
    self._save_session_metadata(state)
    return None
473507

474508
async def abefore_agent(self, state: ShellToolState, runtime: Runtime) -> dict[str, Any] | None:
475509
"""Async counterpart to `before_agent`."""
@@ -492,13 +526,80 @@ def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
492526
if resources is not None and not isinstance(resources, _SessionResources):
493527
resources = None
494528
if resources is None:
495-
msg = (
496-
"Shell session resources are unavailable. Ensure `before_agent` ran successfully "
497-
"before invoking the shell tool."
498-
)
499-
raise ToolException(msg)
529+
return self._get_or_create_resources()
500530
return resources
531+
def _get_or_create_resources(self) -> _SessionResources:
    """Bundle the cached session into a `_SessionResources` record.

    Starts the session first when none has been cached on this instance yet.

    Returns:
        A `_SessionResources` built from the instance's session, temporary
        directory (may be `None`) and execution policy.
    """
    session_missing = self._session is None
    if session_missing:
        self._ensure_session()
    # `_ensure_session` either populated `self._session` or raised, so the
    # cast only narrows the Optional type for the type checker.
    active_session = cast("ShellSession", self._session)
    return _SessionResources(
        session=active_session,
        tempdir=self._tempdir,
        policy=self._execution_policy,
    )
501540

541+
def _ensure_session(self) -> None:
    """Start the persistent shell session once and cache it on the instance.

    Does nothing when a session is already cached. Otherwise it resolves the
    workspace (a fresh temporary directory when no workspace root was
    configured), starts a `ShellSession` there, runs the startup commands and
    stores the session and temporary directory on ``self``.

    Raises:
        BaseException: Whatever `session.start()` or the startup commands
            raise is re-raised after the partially created resources have
            been released.
    """
    if self._session is not None:
        return

    tempdir: tempfile.TemporaryDirectory | None = None
    workspace = self._workspace_root
    if workspace is None:
        tempdir = tempfile.TemporaryDirectory(prefix=SHELL_TEMP_PREFIX)
        workspace = Path(tempdir.name)
    else:
        workspace.mkdir(parents=True, exist_ok=True)

    session = ShellSession(
        workspace,
        self._execution_policy,
        self._shell_command,
        self._environment or {},
    )
    try:
        session.start()
        self._run_startup_commands(session)
    except BaseException:
        # Without this cleanup a failed start-up would leave the shell
        # process and the temporary workspace behind with nothing holding a
        # reference to them.
        LOGGER.exception("Starting shell session failed; cleaning up resources.")
        try:
            # NOTE(review): assumes `ShellSession.stop(timeout)` as defined
            # elsewhere in this module — confirm the signature.
            session.stop(self._execution_policy.termination_timeout)
        finally:
            if tempdir is not None:
                tempdir.cleanup()
        raise

    # Publish only after start-up fully succeeded, so a half-initialised
    # session is never cached.
    self._session = session
    self._tempdir = tempdir
563+
564+
def _save_session_metadata(self, state: ShellToolState) -> None:
565+
checkpoint = getattr(state, "checkpoint", None)
566+
if checkpoint is None:
567+
return
568+
metadata = checkpoint.metadata
569+
if metadata is None:
570+
return
571+
if self._session is None:
572+
return
573+
metadata["shell_session"] = {
574+
"workspace": str(self._session._workspace),
575+
"command": self._shell_command,
576+
"env": dict(self._session._environment),
577+
"policy": self._execution_policy.to_dict(),
578+
}
579+
580+
def restore_from_metadata(self, state: ShellToolState) -> None:
    """Restore the shell session from checkpoint metadata on HIL resume.

    Reads the ``"shell_session"`` entry from the checkpoint metadata found in
    ``state`` and starts a new `ShellSession` with the recorded workspace,
    command, environment and policy. Restoration is best-effort: failures are
    logged with a traceback and otherwise ignored.

    Args:
        state: Agent state. The checkpoint is looked up as a mapping key when
            ``state`` is a dict and as an attribute otherwise.
    """
    # A session that is still cached on this instance must not be replaced:
    # overwriting it would orphan the running shell process.
    if self._session is not None:
        return

    def _lookup(container: Any, key: str) -> Any:
        if isinstance(container, dict):
            return container.get(key)
        return getattr(container, key, None)

    checkpoint = _lookup(state, "checkpoint")
    if checkpoint is None:
        return
    metadata = _lookup(checkpoint, "metadata")
    if metadata is None:
        return
    data = metadata.get("shell_session")
    if not data:
        return

    try:
        policy = BaseExecutionPolicy.from_dict(data["policy"])
        session = ShellSession(
            Path(data["workspace"]),
            policy,
            tuple(data["command"]),
            data["env"],
        )
        # NOTE(review): startup commands are not re-run here, and the recorded
        # workspace may have been a temporary directory — confirm both are
        # intended.
        session.start()
    except Exception:
        LOGGER.exception("Failed to restore shell session")
        return

    self._session = session
    # The restored workspace is not owned by a TemporaryDirectory object.
    self._tempdir = None
    LOGGER.info("Restored shell session from checkpoint")
502603
def _create_resources(self) -> _SessionResources:
503604
workspace = self._workspace_root
504605
tempdir: tempfile.TemporaryDirectory[str] | None = None

0 commit comments

Comments
 (0)