1616import weakref
1717from dataclasses import dataclass , field
1818from pathlib import Path
19- from typing import TYPE_CHECKING , Annotated , Any , Literal
19+ from typing import TYPE_CHECKING , Annotated , Any , Literal , cast
2020
2121from langchain_core .messages import ToolMessage
2222from langchain_core .tools .base import BaseTool , ToolException
@@ -394,6 +394,7 @@ def __init__(
394394 tool_description : str | None = None ,
395395 shell_command : Sequence [str ] | str | None = None ,
396396 env : Mapping [str , Any ] | None = None ,
397+ name : str = "shell" ,
397398 ) -> None :
398399 """Initialize the middleware.
399400
@@ -412,8 +413,11 @@ def __init__(
412413 env: Optional environment variables to supply to the shell session. Values are
413414 coerced to strings before command execution. If omitted, the session inherits the
414415 parent process environment.
416+ name: Unique name for this middleware instance (default: "shell").
415417 """
416418 super ().__init__ ()
419+ self ._session : ShellSession | None = None
420+ self ._tempdir : tempfile .TemporaryDirectory | None = None
417421 self ._workspace_root = Path (workspace_root ) if workspace_root else None
418422 self ._shell_command = self ._normalize_shell_command (shell_command )
419423 self ._environment = self ._normalize_env (env )
@@ -425,13 +429,26 @@ def __init__(
425429 self ._redaction_rules : tuple [ResolvedRedactionRule , ...] = tuple (
426430 rule .resolve () for rule in rules
427431 )
432+ self ._name = name
428433 self ._startup_commands = self ._normalize_commands (startup_commands )
429434 self ._shutdown_commands = self ._normalize_commands (shutdown_commands )
430435
431436 description = tool_description or DEFAULT_TOOL_DESCRIPTION
432437 self ._tool = _PersistentShellTool (self , description = description )
433438 self .tools = [self ._tool ]
434439
440+ @property
441+ def name (self ) -> str :
442+ """Unique name of the middleware instance.
443+
444+ Used by LangGraph to identify nodes in the execution graph
445+ (e.g., `shell.before_agent`). Defaults to `"shell"`.
446+
447+ Returns:
448+ The configured name of this middleware.
449+ """
450+ return self ._name
451+
435452 @staticmethod
436453 def _normalize_commands (
437454 commands : tuple [str , ...] | list [str ] | str | None ,
@@ -467,9 +484,26 @@ def _normalize_env(env: Mapping[str, Any] | None) -> dict[str, str] | None:
467484 return normalized
468485
469486 def before_agent (self , state : ShellToolState , runtime : Runtime ) -> dict [str , Any ] | None : # noqa: ARG002
470- """Start the shell session and run startup commands."""
471- resources = self ._create_resources ()
472- return {"shell_session_resources" : resources }
487+ """"Prepare the shell session before agent execution.
488+
489+ Starts the persistent shell session if not already running and saves
490+ its configuration to checkpoint metadata for reliable restoration
491+ after human-in-the-loop (HIL) interrupts.
492+
493+ This method is called automatically by LangGraph before the agent
494+ runs. The session is cached in the middleware instance — not returned
495+ in state — to avoid serialization issues.
496+
497+ Args:
498+ state: Current agent state, used to access checkpoint metadata.
499+ runtime: LangGraph runtime (unused, but required by interface).
500+
501+ Returns:
502+ None: Session is managed internally.
503+ """
504+ self ._ensure_session ()
505+ self ._save_session_metadata (state )
506+ return None
473507
474508 async def abefore_agent (self , state : ShellToolState , runtime : Runtime ) -> dict [str , Any ] | None :
475509 """Async counterpart to `before_agent`."""
@@ -492,13 +526,80 @@ def _ensure_resources(self, state: ShellToolState) -> _SessionResources:
492526 if resources is not None and not isinstance (resources , _SessionResources ):
493527 resources = None
494528 if resources is None :
495- msg = (
496- "Shell session resources are unavailable. Ensure `before_agent` ran successfully "
497- "before invoking the shell tool."
498- )
499- raise ToolException (msg )
529+ return self ._get_or_create_resources ()
500530 return resources
531+ def _get_or_create_resources (self ) -> _SessionResources :
532+ if self ._session is None :
533+ self ._ensure_session ()
534+ session = cast ("ShellSession" , self ._session )
535+ return _SessionResources (
536+ session = session ,
537+ tempdir = self ._tempdir ,
538+ policy = self ._execution_policy ,
539+ )
501540
541+ def _ensure_session (self ) -> None :
542+ if self ._session is not None :
543+ return
544+ workspace = self ._workspace_root
545+ tempdir : tempfile .TemporaryDirectory | None = None
546+ if workspace is None :
547+ tempdir = tempfile .TemporaryDirectory (prefix = SHELL_TEMP_PREFIX )
548+ workspace = Path (tempdir .name )
549+ else :
550+ workspace = Path (workspace )
551+ workspace .mkdir (parents = True , exist_ok = True )
552+
553+ session = ShellSession (
554+ workspace ,
555+ self ._execution_policy ,
556+ self ._shell_command ,
557+ self ._environment or {},
558+ )
559+ session .start ()
560+ self ._run_startup_commands (session )
561+ self ._session = session
562+ self ._tempdir = tempdir
563+
564+ def _save_session_metadata (self , state : ShellToolState ) -> None :
565+ checkpoint = getattr (state , "checkpoint" , None )
566+ if checkpoint is None :
567+ return
568+ metadata = checkpoint .metadata
569+ if metadata is None :
570+ return
571+ if self ._session is None :
572+ return
573+ metadata ["shell_session" ] = {
574+ "workspace" : str (self ._session ._workspace ),
575+ "command" : self ._shell_command ,
576+ "env" : dict (self ._session ._environment ),
577+ "policy" : self ._execution_policy .to_dict (),
578+ }
579+
580+ def restore_from_metadata (self , state : ShellToolState ) -> None :
581+ """Restore shell session from checkpoint metadata on HIL resume."""
582+ checkpoint = getattr (state , "checkpoint" , None )
583+ if checkpoint is None :
584+ return
585+ data = checkpoint .metadata .get ("shell_session" )
586+ if not data :
587+ return
588+ try :
589+ policy = BaseExecutionPolicy .from_dict (data ["policy" ])
590+ workspace = Path (data ["workspace" ])
591+ session = ShellSession (
592+ workspace ,
593+ policy ,
594+ tuple (data ["command" ]),
595+ data ["env" ],
596+ )
597+ session .start ()
598+ self ._session = session
599+ self ._tempdir = None
600+ LOGGER .info ("Restored shell session from checkpoint" )
601+ except Exception :
602+ LOGGER .exception ("Failed to restore shell session" ) # ← logs traceback
502603 def _create_resources (self ) -> _SessionResources :
503604 workspace = self ._workspace_root
504605 tempdir : tempfile .TemporaryDirectory [str ] | None = None
0 commit comments