diff --git a/libs/langgraph-checkpoint-aws/README.md b/libs/langgraph-checkpoint-aws/README.md index 56ea921d..878a7b85 100644 --- a/libs/langgraph-checkpoint-aws/README.md +++ b/libs/langgraph-checkpoint-aws/README.md @@ -53,9 +53,37 @@ config = {"configurable": {"thread_id": session_id}} graph.invoke(1, config) ``` + +when invoking the graph asynchronously + +```python +from langgraph.graph import StateGraph +from langgraph_checkpoint_aws.async_saver import AsyncBedrockSessionSaver + +# Initialize the saver +session_saver = AsyncBedrockSessionSaver( + region_name="us-west-2", # Your AWS region + credentials_profile_name="default", # Optional: AWS credentials profile +) + +# Create a session +session_create_response = await session_saver.session_client.create_session() +session_id = session_response.session_id + +# Use with LangGraph +builder = StateGraph(int) +builder.add_node("add_one", lambda x: x + 1) +builder.set_entry_point("add_one") +builder.set_finish_point("add_one") + +graph = builder.compile(checkpointer=session_saver) +config = {"configurable": {"thread_id": session_id}} +graph.ainvoke(1, config) +``` + ## Configuration Options -`BedrockSessionSaver` accepts the following parameters: +`BedrockSessionSaver` and `AsyncBedrockSessionSaver` accepts the following parameters: ```python def __init__( diff --git a/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/async_saver.py b/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/async_saver.py new file mode 100644 index 00000000..2235724f --- /dev/null +++ b/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/async_saver.py @@ -0,0 +1,573 @@ +import datetime +import json +from collections.abc import AsyncIterator, Sequence +from typing import Any, Optional + +from botocore.config import Config +from botocore.exceptions import ClientError +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + get_checkpoint_id, +) +from pydantic import SecretStr + +from langgraph_checkpoint_aws.async_session import AsyncBedrockAgentRuntimeSessionClient +from langgraph_checkpoint_aws.constants import CHECKPOINT_PREFIX +from langgraph_checkpoint_aws.models import ( + BedrockSessionContentBlock, + CreateInvocationRequest, + GetInvocationStepRequest, + InvocationStep, + InvocationStepPayload, + ListInvocationStepsRequest, + PutInvocationStepRequest, + SessionCheckpoint, + SessionPendingWrite, +) +from langgraph_checkpoint_aws.utils import ( + construct_checkpoint_tuple, + create_session_checkpoint, + deserialize_data, + generate_checkpoint_id, + generate_write_id, + process_write_operations, + process_writes_invocation_content_blocks, + transform_pending_task_writes, +) + + +class AsyncBedrockSessionSaver(BaseCheckpointSaver): + """Asynchronously saves and retrieves checkpoints using Amazon Bedrock Agent Runtime sessions. + + This class provides async functionality to persist checkpoint data and writes to Bedrock Agent Runtime sessions. + It handles creating invocations, managing checkpoint data, and tracking pending writes. + + Args: + region_name: AWS region name + credentials_profile_name: AWS credentials profile name + aws_access_key_id: AWS access key ID + aws_secret_access_key: AWS secret access key + aws_session_token: AWS session token + endpoint_url: Custom endpoint URL for the Bedrock service + config: Botocore config object + """ + + def __init__( + self, + region_name: Optional[str] = None, + credentials_profile_name: Optional[str] = None, + aws_access_key_id: Optional[SecretStr] = None, + aws_secret_access_key: Optional[SecretStr] = None, + aws_session_token: Optional[SecretStr] = None, + endpoint_url: Optional[str] = None, + config: Optional[Config] = None, + ) -> None: + super().__init__() + self.session_client = AsyncBedrockAgentRuntimeSessionClient( + region_name=region_name, + credentials_profile_name=credentials_profile_name, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_session_token=aws_session_token, + endpoint_url=endpoint_url, + config=config, + ) + + async def _create_session_invocation(self, thread_id: str, invocation_id: str): + """Asynchronously create a new invocation if one doesn't already exist. + + Args: + thread_id: The session identifier + invocation_id: The unique invocation identifier + + Raises: + ClientError: If creation fails for reasons other than the invocation already existing + """ + try: + await self.session_client.create_invocation( + CreateInvocationRequest( + session_identifier=thread_id, + invocation_id=invocation_id, + ) + ) + except ClientError as e: + if e.response["Error"]["Code"] != "ConflictException": + raise e + + async def _get_checkpoint_pending_writes( + self, thread_id: str, checkpoint_ns: str, checkpoint_id: str + ) -> list[SessionPendingWrite]: + """Asynchronously retrieve pending write operations for a given checkpoint from the Bedrock session. + + This method retrieves any pending write operations that were stored for a specific checkpoint. + It first gets the most recent invocation step, then retrieves the full details of that step, + and finally parses the content blocks to reconstruct the PendingWrite objects. + + Args: + thread_id: Session thread identifier used to locate the checkpoint data + checkpoint_ns: Namespace that groups related checkpoints together + checkpoint_id: Unique identifier for the specific checkpoint to retrieve + + Returns: + List of PendingWrite objects containing task_id, channel, value, task_path and write_idx. + Returns empty list if no pending writes are found. + """ + # Generate unique ID for the write operation + writes_id = generate_write_id(checkpoint_ns, checkpoint_id) + + try: + # Retrieve most recent invocation step (limit 1) for this writes_id + invocation_steps = await self.session_client.list_invocation_steps( + ListInvocationStepsRequest( + session_identifier=thread_id, + invocation_identifier=writes_id, + max_results=1, + ) + ) + invocation_step_summaries = invocation_steps.invocation_step_summaries + + # Return empty list if no steps found + if len(invocation_step_summaries) == 0: + return [] + + # Get complete details for the most recent step + invocation_step = await self.session_client.get_invocation_step( + GetInvocationStepRequest( + session_identifier=thread_id, + invocation_identifier=writes_id, + invocation_step_id=invocation_step_summaries[0].invocation_step_id, + ) + ) + + return process_writes_invocation_content_blocks( + invocation_step.invocation_step.payload.content_blocks, self.serde + ) + + except ClientError as e: + # Return empty list if resource not found, otherwise re-raise error + if e.response["Error"]["Code"] == "ResourceNotFoundException": + return [] + raise e + + async def _save_invocation_step( + self, + thread_id: str, + invocation_identifier: str, + invocation_step_id: Optional[str], + payload: InvocationStepPayload, + ) -> None: + """Asynchronously persist an invocation step and its payload to the Bedrock session store. + + This method stores a single invocation step along with its associated payload data + in the Bedrock session. The step is timestamped with the current UTC time. + + Args: + thread_id: Unique identifier for the session thread + invocation_identifier: Identifier for the specific invocation + invocation_step_id: Unique identifier for this step within the invocation + payload: InvocationStepPayload object containing the content blocks to store + + Returns: + None + """ + await self.session_client.put_invocation_step( + PutInvocationStepRequest( + session_identifier=thread_id, + invocation_identifier=invocation_identifier, + invocation_step_id=invocation_step_id, + invocation_step_time=datetime.datetime.now(datetime.timezone.utc), + payload=payload, + ) + ) + + async def _find_most_recent_checkpoint_step( + self, thread_id: str, invocation_id: str + ) -> Optional[InvocationStep]: + """Asynchronously retrieve the most recent checkpoint step from a session's invocation history. + + Iterates through all invocation steps in reverse chronological order until it finds + a step with a checkpoint payload type. Uses pagination to handle large result sets. + + Args: + thread_id: The unique identifier for the session thread + invocation_id: The identifier for the specific invocation to search + + Returns: + InvocationStep object if a checkpoint is found, None otherwise + """ + next_token = None + while True: + # Get batch of invocation steps using pagination token if available + invocation_steps = await self.session_client.list_invocation_steps( + ListInvocationStepsRequest( + session_identifier=thread_id, + invocation_identifier=invocation_id, + next_token=next_token, + ) + ) + + # Return None if no steps found in this batch + if len(invocation_steps.invocation_step_summaries) == 0: + return None + + # Check each step in the batch for checkpoint type + for invocation_step_summary in invocation_steps.invocation_step_summaries: + invocation_step = await self.session_client.get_invocation_step( + GetInvocationStepRequest( + session_identifier=thread_id, + invocation_identifier=invocation_id, + invocation_step_id=invocation_step_summary.invocation_step_id, + ) + ) + + # Parse the step payload and check if it's a checkpoint + step_payload = json.loads( + invocation_step.invocation_step.payload.content_blocks[0].text + ) + if step_payload["step_type"] == CHECKPOINT_PREFIX: + return invocation_step.invocation_step + + # Get token for next batch of results + next_token = invocation_steps.next_token + if next_token is None: + return None + + async def _get_checkpoint_step( + self, thread_id: str, invocation_id: str, checkpoint_id: Optional[str] = None + ) -> Optional[InvocationStep]: + """Asynchronously retrieve checkpoint step data. + + Args: + thread_id: Session thread identifier + invocation_id: Invocation identifier + checkpoint_id: Optional checkpoint identifier + + Returns: + InvocationStep if found, None otherwise + """ + if checkpoint_id is None: + step = await self._find_most_recent_checkpoint_step( + thread_id, invocation_id + ) + if step is None: + return None + return step + + response = await self.session_client.get_invocation_step( + GetInvocationStepRequest( + session_identifier=thread_id, + invocation_identifier=invocation_id, + invocation_step_id=checkpoint_id, + ) + ) + return response.invocation_step + + async def _get_task_sends( + self, thread_id: str, checkpoint_ns: str, parent_checkpoint_id: Optional[str] + ) -> list: + """Asynchronously get sorted task sends for parent checkpoint. + + Args: + thread_id: Session thread identifier + checkpoint_ns: Checkpoint namespace + parent_checkpoint_id: Parent checkpoint identifier + + Returns: + Sorted list of task sends + """ + if not parent_checkpoint_id: + return [] + + pending_writes = await self._get_checkpoint_pending_writes( + thread_id, checkpoint_ns, parent_checkpoint_id + ) + return transform_pending_task_writes(pending_writes) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Asynchronously retrieve a checkpoint tuple from the Bedrock session. + + This function retrieves checkpoint data from the session, processes it and returns + a structured CheckpointTuple containing the checkpoint state and metadata. + + Args: + config (RunnableConfig): Configuration containing thread_id and optional checkpoint_ns. + + Returns: + Optional[CheckpointTuple]: Structured checkpoint data if found, None otherwise. + """ + session_thread_id = config["configurable"]["thread_id"] + checkpoint_namespace = config["configurable"].get("checkpoint_ns", "") + checkpoint_identifier = get_checkpoint_id(config) + + invocation_id = generate_checkpoint_id(checkpoint_namespace) + + try: + invocation_step = await self._get_checkpoint_step( + session_thread_id, invocation_id, checkpoint_identifier + ) + if invocation_step is None: + return None + + session_checkpoint = SessionCheckpoint( + **json.loads(invocation_step.payload.content_blocks[0].text) + ) + + pending_write_ops = await self._get_checkpoint_pending_writes( + session_thread_id, + checkpoint_namespace, + invocation_step.invocation_step_id, + ) + + task_sends = await self._get_task_sends( + session_thread_id, + checkpoint_namespace, + session_checkpoint.parent_checkpoint_id, + ) + + return construct_checkpoint_tuple( + session_thread_id, + checkpoint_namespace, + session_checkpoint, + pending_write_ops, + task_sends, + self.serde, + ) + + except ClientError as err: + if err.response["Error"]["Code"] != "ResourceNotFoundException": + raise err + return None + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a new checkpoint in the Bedrock session. + + This method persists checkpoint data and metadata to a Bedrock Agent Runtime session. + It serializes the checkpoint data, creates a session invocation, and saves an invocation + step containing the checkpoint information. + + Args: + config (RunnableConfig): Configuration containing thread_id and checkpoint namespace + checkpoint (Checkpoint): The checkpoint data to store, containing state and channel values + metadata (CheckpointMetadata): Metadata associated with the checkpoint like timestamps + new_versions (ChannelVersions): Version information for communication channels + + Returns: + RunnableConfig: Updated configuration with thread_id, checkpoint_ns and checkpoint_id + """ + session_checkpoint = create_session_checkpoint( + checkpoint, config, metadata, self.serde, new_versions + ) + + # Create session invocation to store checkpoint + checkpoint_invocation_identifier = generate_checkpoint_id( + session_checkpoint.checkpoint_ns + ) + await self._create_session_invocation( + session_checkpoint.thread_id, checkpoint_invocation_identifier + ) + await self._save_invocation_step( + session_checkpoint.thread_id, + checkpoint_invocation_identifier, + session_checkpoint.checkpoint_id, + InvocationStepPayload( + content_blocks=[ + BedrockSessionContentBlock( + text=session_checkpoint.model_dump_json() + ), + ] + ), + ) + + return RunnableConfig( + configurable={ + "thread_id": session_checkpoint.thread_id, + "checkpoint_ns": session_checkpoint.checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + ) + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Asynchronously store write operations in the Bedrock session. + + This method handles storing write operations by: + 1. Creating a new invocation for the writes + 2. Retrieving existing pending writes + 3. Building new content blocks for writes that don't exist + 4. Preserving existing writes that aren't being updated + 5. Saving all content blocks in a new invocation step + + Args: + config (RunnableConfig): Configuration containing thread_id, checkpoint_ns and checkpoint_id + writes (Sequence[tuple[str, Any]]): Sequence of (channel, value) tuples to write + task_id (str): Identifier for the task performing the writes + task_path (str, optional): Path information for the task. Defaults to empty string. + + Returns: + None + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = config["configurable"]["checkpoint_id"] + + # Generate unique identifier for this write operation + writes_invocation_identifier = generate_write_id(checkpoint_ns, checkpoint_id) + + # Create new session invocation + await self._create_session_invocation(thread_id, writes_invocation_identifier) + + # Get existing pending writes for this checkpoint + current_pending_writes = await self._get_checkpoint_pending_writes( + thread_id, checkpoint_ns, checkpoint_id + ) + + content_blocks, new_writes = process_write_operations( + writes, + task_id, + current_pending_writes, + thread_id, + checkpoint_ns, + checkpoint_id, + task_path, + self.serde, + ) + + # Save content blocks if any exist + if content_blocks and new_writes: + await self._save_invocation_step( + thread_id, + writes_invocation_identifier, + None, # Let service generate the step id + InvocationStepPayload(content_blocks=content_blocks), + ) + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """Asynchronously list checkpoints matching the given criteria. + + Args: + config: Optional configuration to filter by + filter: Optional dictionary of filter criteria + before: Optional configuration to get checkpoints before + limit: Optional maximum number of checkpoints to return + + Returns: + AsyncIterator of matching CheckpointTuple objects + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns") + + invocation_identifier = None + + # Get invocation ID only if checkpoint_ns is provided + if checkpoint_ns is not None: + invocation_identifier = generate_checkpoint_id(checkpoint_ns) + + # List all invocation steps with pagination + matching_checkpoints = [] + next_token = None + + while True: + try: + response = await self.session_client.list_invocation_steps( + ListInvocationStepsRequest( + session_identifier=thread_id, + invocation_identifier=invocation_identifier, + next_token=next_token, + ) + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ResourceNotFoundException": + return + else: + raise e + + # Check if there are more pages + next_token = response.next_token + + # Process current page + for step in response.invocation_step_summaries: + if before and step.invocation_step_id >= get_checkpoint_id(before): + continue + + # Get full step details to access metadata + step_detail = await self.session_client.get_invocation_step( + GetInvocationStepRequest( + session_identifier=thread_id, + invocation_identifier=step.invocation_id, + invocation_step_id=step.invocation_step_id, + ) + ) + + payload = json.loads( + step_detail.invocation_step.payload.content_blocks[0].text + ) + + # Append checkpoints and ignore writes + if payload["step_type"] != CHECKPOINT_PREFIX: + continue + + session_checkpoint = SessionCheckpoint(**payload) + + # Apply metadata filter + if filter: + metadata = ( + deserialize_data(self.serde, session_checkpoint.metadata) + if session_checkpoint.metadata + else {} + ) + if not all(metadata.get(k) == v for k, v in filter.items()): + continue + + # Append checkpoints + matching_checkpoints.append(session_checkpoint) + + if limit and len(matching_checkpoints) >= limit: + next_token = None + break + + if next_token is None: + break + + # Yield checkpoint tuples + for checkpoint in matching_checkpoints: + pending_write_ops = await self._get_checkpoint_pending_writes( + thread_id, + checkpoint.checkpoint_ns, + checkpoint.checkpoint_id, + ) + + task_sends = await self._get_task_sends( + thread_id, checkpoint.checkpoint_ns, checkpoint.parent_checkpoint_id + ) + + yield construct_checkpoint_tuple( + thread_id, + checkpoint.checkpoint_ns, + checkpoint, + pending_write_ops, + task_sends, + self.serde, + ) diff --git a/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/async_session.py b/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/async_session.py new file mode 100644 index 00000000..44dcfd70 --- /dev/null +++ b/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/async_session.py @@ -0,0 +1,139 @@ +from typing import Any, Optional + +import boto3 +from botocore.config import Config +from pydantic import SecretStr + +from langgraph_checkpoint_aws.models import ( + CreateInvocationRequest, + CreateInvocationResponse, + CreateSessionRequest, + CreateSessionResponse, + DeleteSessionRequest, + EndSessionRequest, + EndSessionResponse, + GetInvocationStepRequest, + GetInvocationStepResponse, + GetSessionRequest, + GetSessionResponse, + ListInvocationsRequest, + ListInvocationsResponse, + ListInvocationStepsRequest, + ListInvocationStepsResponse, + PutInvocationStepRequest, + PutInvocationStepResponse, +) +from langgraph_checkpoint_aws.utils import ( + process_aws_client_args, + run_boto3_in_executor, + to_boto_params, +) + + +class AsyncBedrockAgentRuntimeSessionClient: + """ + Asynchronous client for AWS Bedrock Agent Runtime API using standard boto3 with async executor. + """ + + def __init__( + self, + region_name: Optional[str] = None, + credentials_profile_name: Optional[str] = None, + aws_access_key_id: Optional[SecretStr] = None, + aws_secret_access_key: Optional[SecretStr] = None, + aws_session_token: Optional[SecretStr] = None, + endpoint_url: Optional[str] = None, + config: Optional[Config] = None, + ): + """ + Initialize AsyncBedrockAgentRuntime with AWS configuration + """ + _session_kwargs, self._client_kwargs = process_aws_client_args( + region_name, + credentials_profile_name, + aws_access_key_id, + aws_secret_access_key, + aws_session_token, + endpoint_url, + config, + ) + + # Create a standard boto3 session + self.session = boto3.Session(**_session_kwargs) + # Pre-create the client to avoid creating it for each operation + self.client = self.session.client( + "bedrock-agent-runtime", **self._client_kwargs + ) + + async def create_session( + self, request: Optional[CreateSessionRequest] = None + ) -> CreateSessionResponse: + """Create a new session asynchronously""" + params = to_boto_params(request) if request else {} + response = await run_boto3_in_executor(self.client.create_session, **params) + return CreateSessionResponse(**response) + + async def get_session(self, request: GetSessionRequest) -> GetSessionResponse: + """Get details of an existing session asynchronously""" + response = await run_boto3_in_executor( + self.client.get_session, **to_boto_params(request) + ) + return GetSessionResponse(**response) + + async def end_session(self, request: EndSessionRequest) -> EndSessionResponse: + """End an existing session asynchronously""" + response = await run_boto3_in_executor( + self.client.end_session, **to_boto_params(request) + ) + return EndSessionResponse(**response) + + async def delete_session(self, request: DeleteSessionRequest) -> None: + """Delete an existing session asynchronously""" + await run_boto3_in_executor( + self.client.delete_session, **to_boto_params(request) + ) + + async def create_invocation( + self, request: CreateInvocationRequest + ) -> CreateInvocationResponse: + """Create a new invocation asynchronously""" + response = await run_boto3_in_executor( + self.client.create_invocation, **to_boto_params(request) + ) + return CreateInvocationResponse(**response) + + async def list_invocations( + self, request: ListInvocationsRequest + ) -> ListInvocationsResponse: + """List invocations for a session asynchronously""" + response = await run_boto3_in_executor( + self.client.list_invocations, **to_boto_params(request) + ) + return ListInvocationsResponse(**response) + + async def put_invocation_step( + self, request: PutInvocationStepRequest + ) -> PutInvocationStepResponse: + """Put a step in an invocation asynchronously""" + response = await run_boto3_in_executor( + self.client.put_invocation_step, **to_boto_params(request) + ) + return PutInvocationStepResponse(**response) + + async def get_invocation_step( + self, request: GetInvocationStepRequest + ) -> GetInvocationStepResponse: + """Get a step in an invocation asynchronously""" + response = await run_boto3_in_executor( + self.client.get_invocation_step, **to_boto_params(request) + ) + return GetInvocationStepResponse(**response) + + async def list_invocation_steps( + self, request: ListInvocationStepsRequest + ) -> ListInvocationStepsResponse: + """List steps in an invocation asynchronously""" + response = await run_boto3_in_executor( + self.client.list_invocation_steps, **to_boto_params(request) + ) + return ListInvocationStepsResponse(**response) diff --git a/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/utils.py b/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/utils.py index e09d9d5c..44028d99 100644 --- a/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/utils.py +++ b/libs/langgraph-checkpoint-aws/langgraph_checkpoint_aws/utils.py @@ -1,9 +1,12 @@ +import asyncio import base64 import hashlib import json import uuid from collections.abc import Sequence -from typing import Any, Optional, Tuple, Union, cast +from contextvars import copy_context +from functools import partial +from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast from botocore.config import Config from langchain_core.runnables import RunnableConfig @@ -26,6 +29,8 @@ SessionPendingWrite, ) +T = TypeVar("T") + def to_boto_params(model: BaseModel) -> dict: """ @@ -444,3 +449,15 @@ def create_client_config(config: Optional[Config] = None) -> Config: new_user_agent = f"{existing_user_agent} md/sdk_user_agent/{SDK_USER_AGENT}".strip() return Config(user_agent_extra=new_user_agent, **config_kwargs) + + +async def run_boto3_in_executor(func: Callable[..., T], *args: Any, **kwargs: Any) -> T: + """Run a boto3 function in an executor to prevent blocking the event loop.""" + + return await asyncio.get_running_loop().run_in_executor( + None, + cast( + "Callable[..., T]", + partial(copy_context().run, lambda: func(*args, **kwargs)), + ), + ) diff --git a/libs/langgraph-checkpoint-aws/poetry.lock b/libs/langgraph-checkpoint-aws/poetry.lock index c78b5cb5..3c136fcf 100644 --- a/libs/langgraph-checkpoint-aws/poetry.lock +++ b/libs/langgraph-checkpoint-aws/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "annotated-types" @@ -1204,6 +1204,26 @@ tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.26.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.9" +groups = ["test"] +files = [ + {file = "pytest_asyncio-0.26.0-py3-none-any.whl", hash = "sha256:7b51ed894f4fbea1340262bdae5135797ebbe21d8638978e35d31c6d19f72fb0"}, + {file = "pytest_asyncio-0.26.0.tar.gz", hash = "sha256:c4df2a697648241ff39e7f0e4a73050b03f123f760673956cf0d72a4990e312f"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" +typing-extensions = {version = ">=4.12", markers = "python_version < \"3.10\""} + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "6.0.0" @@ -1473,11 +1493,12 @@ version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" -groups = ["main", "dev", "test_integration", "typing"] +groups = ["main", "dev", "test", "test_integration", "typing"] files = [ {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +markers = {test = "python_version == \"3.9\""} [[package]] name = "urllib3" @@ -1486,7 +1507,7 @@ description = "HTTP library with thread-safe connection pooling, file post, and optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" groups = ["main", "test_integration"] -markers = "python_version < \"3.10\"" +markers = "python_version == \"3.9\"" files = [ {file = "urllib3-1.26.20-py2.py3-none-any.whl", hash = "sha256:0ed14ccfbf1c30a9072c7ca157e4319b70d65f623e91e7b32fadb2853431016e"}, {file = "urllib3-1.26.20.tar.gz", hash = "sha256:40c2dc0c681e47eb8f90e7e27bf6ff7df2e677421fd46756da1161c39ca70d32"}, @@ -1632,4 +1653,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.9,<4.0" -content-hash = "ec4602a5086ce90702d1b44800851a755d5ecd64973d1d4d0164c1d50f030967" +content-hash = "fd83ddf3fafbbf9e0a2380ed8a60b4e87425e9973f41c0f0151e6e778fa58a5e" diff --git a/libs/langgraph-checkpoint-aws/pyproject.toml b/libs/langgraph-checkpoint-aws/pyproject.toml index 8faefe6f..60cc0f66 100644 --- a/libs/langgraph-checkpoint-aws/pyproject.toml +++ b/libs/langgraph-checkpoint-aws/pyproject.toml @@ -35,10 +35,14 @@ optional = true [tool.poetry.group.test.dependencies] pytest = ">=7.4.3" pytest-cov = ">=4.1.0" +pytest-asyncio = ">=0.26.0" [tool.poetry.group.test_integration] optional = true +[tool.pytest.ini_options] +asyncio_default_fixture_loop_scope = "function" + [tool.poetry.group.test_integration.dependencies] langchain-aws = ">=0.2.14" diff --git a/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_async_saver.py b/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_async_saver.py new file mode 100644 index 00000000..faef590f --- /dev/null +++ b/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_async_saver.py @@ -0,0 +1,140 @@ +import datetime +from typing import Literal + +import pytest +from langchain_aws import ChatBedrock +from langchain_core.tools import tool +from langgraph.checkpoint.base import Checkpoint, uuid6 +from langgraph.prebuilt import create_react_agent + +from langgraph_checkpoint_aws.async_saver import AsyncBedrockSessionSaver +from langgraph_checkpoint_aws.models import DeleteSessionRequest, EndSessionRequest + + +@tool +def get_weather(city: Literal["nyc", "sf"]): + """Use this to get weather information.""" + if city == "nyc": + return "It might be cloudy in nyc" + elif city == "sf": + return "It's always sunny in sf" + else: + raise AssertionError("Unknown city") + + +class TestAsyncBedrockMemorySaver: + @pytest.fixture + def tools(self): + # Setup tools + return [get_weather] + + @pytest.fixture + def model(self): + # Setup model + return ChatBedrock( + model="anthropic.claude-3-sonnet-20240229-v1:0", region="us-west-2" + ) + + @pytest.fixture + def session_saver(self): + # Return the instantiated object + return AsyncBedrockSessionSaver(region_name="us-west-2") + + @pytest.fixture + def boto_session_client(self, session_saver): + # Return the async client wrapper + return session_saver.session_client + + @pytest.mark.asyncio + async def test_weather_tool_responses(self): + # Test weather tool directly + assert get_weather.invoke("sf") == "It's always sunny in sf" + assert get_weather.invoke("nyc") == "It might be cloudy in nyc" + + @pytest.mark.asyncio + async def test_checkpoint_save_and_retrieve( + self, boto_session_client, session_saver + ): + # Create session + session_response = await boto_session_client.create_session() + session_id = session_response.session_id + assert session_id, "Session ID should not be empty" + + config = {"configurable": {"thread_id": session_id, "checkpoint_ns": ""}} + checkpoint = Checkpoint( + v=1, + id=str(uuid6(clock_seq=-2)), + ts=datetime.datetime.now(datetime.timezone.utc).isoformat(), + channel_values={"key": "value"}, + channel_versions={}, + versions_seen={}, + pending_sends=[], + ) + checkpoint_metadata = {"source": "input", "step": 1, "writes": {"key": "value"}} + + try: + saved_config = await session_saver.aput( + config, + checkpoint, + checkpoint_metadata, + {}, + ) + assert saved_config == { + "configurable": { + "checkpoint_id": checkpoint["id"], + "checkpoint_ns": "", + "thread_id": session_id, + } + } + + checkpoint_tuple = await session_saver.aget_tuple(saved_config) + assert checkpoint_tuple.checkpoint == checkpoint + assert checkpoint_tuple.metadata == checkpoint_metadata + assert checkpoint_tuple.config == saved_config + + finally: + # Create proper request objects + await boto_session_client.end_session( + EndSessionRequest(session_identifier=session_id) + ) + await boto_session_client.delete_session( + DeleteSessionRequest(session_identifier=session_id) + ) + + @pytest.mark.asyncio + async def test_weather_query_and_checkpointing( + self, boto_session_client, tools, model, session_saver + ): + # Create session + session_response = await boto_session_client.create_session() + session_id = session_response.session_id + assert session_id, "Session ID should not be empty" + try: + # Create graph and config + graph = create_react_agent(model, tools=tools, checkpointer=session_saver) + config = {"configurable": {"thread_id": session_id}} + + # Test weather query + response = await graph.ainvoke( + {"messages": [("human", "what's the weather in sf")]}, config + ) + assert response, "Response should not be empty" + + # Test checkpoint retrieval + checkpoint = await session_saver.aget(config) + assert checkpoint, "Checkpoint should not be empty" + + # Test checkpoint listing + checkpoint_tuples = [tup async for tup in session_saver.alist(config)] + assert checkpoint_tuples, "Checkpoint tuples should not be empty" + assert isinstance(checkpoint_tuples, list), ( + "Checkpoint tuples should be a list" + ) + finally: + # Create proper request objects + await boto_session_client.end_session( + EndSessionRequest(session_identifier=session_id) + ) + await boto_session_client.delete_session( + DeleteSessionRequest(session_identifier=session_id) + ) diff --git a/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_saver.py b/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_saver.py index 8f54ff00..e2dab0d7 100644 --- a/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_saver.py +++ b/libs/langgraph-checkpoint-aws/tests/integration_tests/saver/test_saver.py @@ -113,9 +113,9 @@ def test_weather_query_and_checkpointing( # Test checkpoint listing checkpoint_tuples = list(session_saver.list(config)) assert checkpoint_tuples, "Checkpoint tuples should not be empty" - assert isinstance( - checkpoint_tuples, list - ), "Checkpoint tuples should be a list" + assert isinstance(checkpoint_tuples, list), ( + "Checkpoint tuples should be a list" + ) finally: boto_session_client.end_session(sessionIdentifier=session_id) boto_session_client.delete_session(sessionIdentifier=session_id) diff --git a/libs/langgraph-checkpoint-aws/tests/unit_tests/conftest.py b/libs/langgraph-checkpoint-aws/tests/unit_tests/conftest.py index f3980e28..a0f1ed34 100644 --- a/libs/langgraph-checkpoint-aws/tests/unit_tests/conftest.py +++ b/libs/langgraph-checkpoint-aws/tests/unit_tests/conftest.py @@ -1,7 +1,7 @@ import base64 import datetime import json -from unittest.mock import MagicMock, Mock +from unittest.mock import AsyncMock, MagicMock, Mock from uuid import uuid4 import pytest @@ -228,7 +228,7 @@ def sample_session_checkpoint(sample_invocation_step_summary): thread_id=sample_invocation_step_summary["sessionId"], checkpoint_ns=sample_invocation_step_summary["invocationId"], checkpoint_id=sample_invocation_step_summary["invocationStepId"], - checkpoint={}, + checkpoint=("json", b"e30="), metadata=json.dumps({"key": "value"}), parent_checkpoint_id=None, channel_values={}, diff --git a/libs/langgraph-checkpoint-aws/tests/unit_tests/test_async_saver.py b/libs/langgraph-checkpoint-aws/tests/unit_tests/test_async_saver.py new file mode 100644 index 00000000..3e77bd0c --- /dev/null +++ b/libs/langgraph-checkpoint-aws/tests/unit_tests/test_async_saver.py @@ -0,0 +1,991 @@ +import datetime +import json +from unittest.mock import ANY, AsyncMock, Mock, patch + +import pytest +from botocore.config import Config +from botocore.exceptions import ClientError +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import CheckpointTuple +from langgraph.constants import ERROR +from pydantic import SecretStr + +from langgraph_checkpoint_aws.async_saver import ( + AsyncBedrockAgentRuntimeSessionClient, + AsyncBedrockSessionSaver, +) +from langgraph_checkpoint_aws.models import ( + GetInvocationStepResponse, + InvocationStep, + ListInvocationStepsResponse, +) + + +class TestAsyncBedrockSessionSaver: + @pytest.fixture + def session_saver(self, mock_boto_client): + with patch("boto3.Session") as mock_aioboto_session: + mock_aioboto_session.return_value.client.return_value = mock_boto_client + yield AsyncBedrockSessionSaver() + + @pytest.fixture + def runnable_config(self): + return RunnableConfig( + configurable={ + "thread_id": "test_thread_id", + "checkpoint_ns": "test_namespace", + } + ) + + @pytest.mark.asyncio + async def test__create_session_invocation_success( + self, mock_boto_client, session_saver, sample_create_invocation_response + ): + # Arrange + thread_id = "test_thread_id" + invocation_id = "test_invocation_id" + mock_boto_client.create_invocation.return_value = ( + sample_create_invocation_response + ) + + # Act + await session_saver._create_session_invocation(thread_id, invocation_id) + + # Assert + mock_boto_client.create_invocation.assert_called_once() + + @pytest.mark.asyncio + async def test__create_session_invocation_conflict( + self, mock_boto_client, session_saver + ): + # Arrange + error_response = {"Error": {"Code": "ConflictException", "Message": "Conflict"}} + mock_boto_client.create_invocation.side_effect = ClientError( + error_response=error_response, + operation_name="CreateInvocation", + ) + thread_id = "test_thread_id" + invocation_id = "test_invocation_id" + + # Act - should not raise an exception + await session_saver._create_session_invocation(thread_id, invocation_id) + + # Assert + mock_boto_client.create_invocation.assert_called_once() + + @pytest.mark.asyncio + async def test__create_session_invocation_raises_error( + self, mock_boto_client, session_saver + ): + # Arrange + thread_id = "test_thread_id" + invocation_id = "test_invocation_id" + + error_response = {"Error": {"Code": "SomeOtherError", "Message": "Other error"}} + mock_boto_client.create_invocation.side_effect = ClientError( + error_response=error_response, + operation_name="CreateInvocation", + ) + + # Act & Assert + with pytest.raises(ClientError) as exc_info: + await session_saver._create_session_invocation(thread_id, invocation_id) + + assert exc_info.value.response["Error"]["Code"] == "SomeOtherError" + mock_boto_client.create_invocation.assert_called_once() + + @pytest.mark.asyncio + async def test__get_checkpoint_pending_writes_success( + self, + mock_boto_client, + session_saver, + sample_session_pending_write, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + thread_id = "test_thread" + checkpoint_ns = "test_namespace" + checkpoint_id = "test_checkpoint" + + # serialize payload + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_pending_write.model_dump_json() + mock_boto_client.list_invocation_steps.return_value = ( + sample_list_invocation_steps_response + ) + mock_boto_client.get_invocation_step.return_value = ( + sample_get_invocation_step_response + ) + + # Act + result = await session_saver._get_checkpoint_pending_writes( + thread_id, checkpoint_ns, checkpoint_id + ) + + # Assert + assert len(result) == 1 + mock_boto_client.list_invocation_steps.assert_called_once() + mock_boto_client.get_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test__get_checkpoint_pending_writes_no_invocation_steps( + self, + mock_boto_client, + session_saver, + sample_list_invocation_steps_response, + ): + # Arrange + sample_list_invocation_steps_response["invocationStepSummaries"] = [] + mock_boto_client.list_invocation_steps.return_value = ( + sample_list_invocation_steps_response + ) + + # Act + result = await session_saver._get_checkpoint_pending_writes( + "thread_id", "ns", "checkpoint_id" + ) + + # Assert + assert result == [] + mock_boto_client.list_invocation_steps.assert_called_once() + + @pytest.mark.asyncio + async def test__get_checkpoint_pending_writes_resource_not_found( + self, mock_boto_client, session_saver + ): + # Arrange + error_response = { + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Resource not found", + } + } + mock_boto_client.list_invocation_steps.side_effect = ClientError( + error_response=error_response, + operation_name="ListInvocationSteps", + ) + + # Act + result = await session_saver._get_checkpoint_pending_writes( + "thread_id", "ns", "checkpoint_id" + ) + + # Assert + assert result == [] + mock_boto_client.list_invocation_steps.assert_called_once() + + @pytest.mark.asyncio + async def test__get_checkpoint_pending_writes_client_error( + self, mock_boto_client, session_saver, sample_invocation_step_payload + ): + # Arrange + error_response = {"Error": {"Code": "SomeError", "Message": "Error occurred"}} + mock_boto_client.list_invocation_steps.side_effect = ClientError( + error_response=error_response, + operation_name="ListInvocationSteps", + ) + + # Act & Assert + with pytest.raises(ClientError): + await session_saver._get_checkpoint_pending_writes( + "thread_id", "ns", "checkpoint_id" + ) + + mock_boto_client.list_invocation_steps.assert_called_once() + + @pytest.mark.asyncio + async def test__save_invocation_step_success( + self, + mock_boto_client, + session_saver, + sample_invocation_step_payload, + sample_put_invocation_step_response, + ): + # Arrange + thread_id = "test_thread_id" + invocation_identifier = "test_invocation_identifier" + invocation_step_id = "test_invocation_step_id" + mock_boto_client.put_invocation_step.return_value = ( + sample_put_invocation_step_response + ) + + # Act + with patch("datetime.datetime") as mock_datetime: + invocation_step_time = datetime.datetime.now(datetime.timezone.utc) + mock_datetime.now.return_value = invocation_step_time + await session_saver._save_invocation_step( + thread_id, + invocation_identifier, + invocation_step_id, + sample_invocation_step_payload, + ) + + # Assert + mock_boto_client.put_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test__save_invocation_step_client_error( + self, mock_boto_client, session_saver, sample_invocation_step_payload + ): + # Arrange + error_response = {"Error": {"Code": "SomeError", "Message": "Error occurred"}} + mock_boto_client.put_invocation_step.side_effect = ClientError( + error_response=error_response, + operation_name="PutInvocationStep", + ) + + # Act & Assert + with pytest.raises(ClientError): + await session_saver._save_invocation_step( + "thread_id", "inv_id", "step_id", sample_invocation_step_payload + ) + + mock_boto_client.put_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test__find_most_recent_checkpoint_step_success( + self, + mock_boto_client, + session_saver, + sample_session_checkpoint, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + + # serialize payload + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_checkpoint.model_dump_json() + mock_boto_client.list_invocation_steps.return_value = ( + sample_list_invocation_steps_response + ) + mock_boto_client.get_invocation_step.return_value = ( + sample_get_invocation_step_response + ) + + # Act + result = await session_saver._find_most_recent_checkpoint_step( + thread_id, checkpoint_ns + ) + + # Assert + assert result is not None + mock_boto_client.list_invocation_steps.assert_called_once() + mock_boto_client.get_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test__find_most_recent_checkpoint_step_skips_writes( + self, + mock_boto_client, + session_saver, + sample_session_pending_write, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + + # serialize payload + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_pending_write.model_dump_json() + mock_boto_client.list_invocation_steps.return_value = ( + sample_list_invocation_steps_response + ) + mock_boto_client.get_invocation_step.return_value = ( + sample_get_invocation_step_response + ) + + # Act + result = await session_saver._find_most_recent_checkpoint_step( + thread_id, checkpoint_ns + ) + + # Assert + assert result is None + mock_boto_client.list_invocation_steps.assert_called_once() + mock_boto_client.get_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test__find_most_recent_checkpoint_step_no_invocation_steps( + self, + mock_boto_client, + session_saver, + sample_list_invocation_steps_response, + ): + # Arrange + sample_list_invocation_steps_response["invocationStepSummaries"] = [] + mock_boto_client.list_invocation_steps.return_value = ( + sample_list_invocation_steps_response + ) + + # Act + result = await session_saver._find_most_recent_checkpoint_step( + "thread_id", "ns" + ) + + # Assert + assert result is None + mock_boto_client.list_invocation_steps.assert_called_once() + + @pytest.mark.asyncio + async def test__get_checkpoint_step_with_checkpoint_id( + self, + mock_boto_client, + session_saver, + sample_get_invocation_step_response, + ): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + checkpoint_id = "test_checkpoint_id" + session_saver._find_most_recent_checkpoint_step = Mock() + mock_boto_client.get_invocation_step.return_value = ( + sample_get_invocation_step_response + ) + + # Act + await session_saver._get_checkpoint_step( + thread_id, checkpoint_ns, checkpoint_id + ) + + # Assert + session_saver._find_most_recent_checkpoint_step.assert_not_called() + mock_boto_client.get_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test__get_checkpoint_step_without_checkpoint_id( + self, + mock_boto_client, + session_saver, + sample_invocation_step_payload, + sample_get_invocation_step_response, + ): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + session_saver._find_most_recent_checkpoint_step = AsyncMock( + return_value=sample_invocation_step_payload + ) + + # Act + result = await session_saver._get_checkpoint_step(thread_id, checkpoint_ns) + + # Assert + assert result == sample_invocation_step_payload + session_saver._find_most_recent_checkpoint_step.assert_called_once_with( + thread_id, + checkpoint_ns, + ) + mock_boto_client.get_invocation_step.assert_not_called() + + @pytest.mark.asyncio + async def test__get_checkpoint_step_empty_without_checkpoint_id( + self, + mock_boto_client, + session_saver, + sample_invocation_step_payload, + sample_get_invocation_step_response, + ): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + session_saver._find_most_recent_checkpoint_step = AsyncMock(return_value=None) + + # Act + result = await session_saver._get_checkpoint_step(thread_id, checkpoint_ns) + + # Assert + assert result is None + session_saver._find_most_recent_checkpoint_step.assert_called_once_with( + thread_id, + checkpoint_ns, + ) + mock_boto_client.get_invocation_step.assert_not_called() + + @pytest.mark.asyncio + async def test__get_task_sends_without_parent_checkpoint_id( + self, session_saver, sample_session_checkpoint + ): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + + # Act + result = await session_saver._get_task_sends(thread_id, checkpoint_ns, None) + + # Assert + assert result == [] + + @pytest.mark.asyncio + async def test__get_task_sends( + self, session_saver, sample_session_pending_write_with_sends + ): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + parent_checkpoint_id = "test_parent_checkpoint_id" + + session_saver._get_checkpoint_pending_writes = AsyncMock( + return_value=sample_session_pending_write_with_sends + ) + + # Act + result = await session_saver._get_task_sends( + thread_id, checkpoint_ns, parent_checkpoint_id + ) + + # Assert + assert result == [ + ["2", "__pregel_tasks", ["json", b"eyJrMiI6ICJ2MiJ9"], "/test2/path2", 1], + ["3", "__pregel_tasks", ["json", b"eyJrMyI6ICJ2MyJ9"], "/test3/path3", 1], + ] + session_saver._get_checkpoint_pending_writes.assert_called_once_with( + thread_id, checkpoint_ns, parent_checkpoint_id + ) + + @pytest.mark.asyncio + async def test__get_task_sends_empty(self, session_saver): + # Arrange + thread_id = "test_thread_id" + checkpoint_ns = "test_namespace" + parent_checkpoint_id = "test_parent_checkpoint_id" + + session_saver._get_checkpoint_pending_writes = AsyncMock(return_value=[]) + + # Act + result = await session_saver._get_task_sends( + thread_id, checkpoint_ns, parent_checkpoint_id + ) + + # Assert + assert result == [] + session_saver._get_checkpoint_pending_writes.assert_called_once_with( + thread_id, checkpoint_ns, parent_checkpoint_id + ) + + @pytest.mark.asyncio + @patch("langgraph_checkpoint_aws.async_saver.construct_checkpoint_tuple") + async def test_aget_tuple_success( + self, + mock_construct_checkpoint, + session_saver, + runnable_config, + sample_get_invocation_step_response, + sample_session_pending_write_with_sends, + sample_session_checkpoint, + ): + # Arrange + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_checkpoint.model_dump_json() + + # Mock all required internal methods + session_saver._generate_checkpoint_id = AsyncMock( + return_value="test_checkpoint_id" + ) + session_saver._get_checkpoint_step = AsyncMock( + return_value=InvocationStep( + **sample_get_invocation_step_response["invocationStep"] + ) + ) + session_saver._get_checkpoint_pending_writes = AsyncMock( + return_value=sample_session_pending_write_with_sends + ) + session_saver._get_task_sends = AsyncMock(return_value=[]) + mock_construct_checkpoint.return_value = AsyncMock(spec=CheckpointTuple) + + # Act + result = await session_saver.aget_tuple(runnable_config) + + # Assert + assert isinstance(result, CheckpointTuple) + + @pytest.mark.asyncio + async def test_aget_tuple_success_empty(self, session_saver, runnable_config): + # Arrange + session_saver._get_checkpoint_step = AsyncMock(return_value=None) + + # Act + result = await session_saver.aget_tuple(runnable_config) + + # Assert + assert result is None + session_saver._get_checkpoint_step.assert_called_once() + + @pytest.mark.asyncio + async def test_aget_tuple_resource_not_found_error( + self, session_saver, runnable_config + ): + # Arrange + error_response = { + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Resource not found", + } + } + session_saver._get_checkpoint_step = AsyncMock( + side_effect=ClientError( + error_response=error_response, + operation_name="ListInvocationSteps", + ) + ) + + # Act + result = await session_saver.aget_tuple(runnable_config) + + # Assert + assert result is None + session_saver._get_checkpoint_step.assert_called_once() + + @pytest.mark.asyncio + async def test_aget_tuple_error(self, session_saver, runnable_config): + # Arrange + error_response = { + "Error": {"Code": "SomeOtherError", "Message": "Some other error"} + } + session_saver._get_checkpoint_step = AsyncMock( + side_effect=ClientError( + error_response=error_response, + operation_name="ListInvocationSteps", + ) + ) + + # Act and Assert + with pytest.raises(ClientError): + await session_saver.aget_tuple(runnable_config) + + session_saver._get_checkpoint_step.assert_called_once() + + @pytest.mark.asyncio + async def test_aput_success( + self, + session_saver, + runnable_config, + sample_checkpoint, + sample_checkpoint_metadata, + ): + # Arrange + session_saver._create_session_invocation = AsyncMock() + session_saver._save_invocation_step = AsyncMock() + + # Act + await session_saver.aput( + runnable_config, sample_checkpoint, sample_checkpoint_metadata, {} + ) + + # Assert + session_saver._create_session_invocation.assert_called_once_with( + runnable_config["configurable"]["thread_id"], + "72f4457f-e6bb-e1db-49ee-06cd9901904f", + ) + session_saver._save_invocation_step.assert_called_once_with( + runnable_config["configurable"]["thread_id"], + "72f4457f-e6bb-e1db-49ee-06cd9901904f", + "checkpoint_123", + ANY, + ) + + @pytest.mark.asyncio + async def test_aput_writes_success( + self, + session_saver, + runnable_config, + sample_checkpoint, + sample_checkpoint_metadata, + ): + # Arrange + task_id = "test_task_id" + task_path = "test_task_path" + writes = [("__pregel_pull", "__start__"), ("__pregel_pull", "add_one")] + runnable_config["configurable"]["checkpoint_id"] = "test_checkpoint_id" + + session_saver._create_session_invocation = AsyncMock() + session_saver._save_invocation_step = AsyncMock() + session_saver._get_checkpoint_pending_writes = AsyncMock(return_value=[]) + + # Act + await session_saver.aput_writes(runnable_config, writes, task_id, task_path) + + # Assert + session_saver._create_session_invocation.assert_called_once_with( + runnable_config["configurable"]["thread_id"], + "ea473b95-7b9c-fe52-df2c-3a7353d3148b", + ) + session_saver._save_invocation_step.assert_called_once_with( + runnable_config["configurable"]["thread_id"], + "ea473b95-7b9c-fe52-df2c-3a7353d3148b", + None, + ANY, + ) + + @pytest.mark.asyncio + async def test_aput_writes_skip_existing_writes( + self, + session_saver, + runnable_config, + sample_checkpoint, + sample_checkpoint_metadata, + sample_session_pending_write, + ): + # Arrange + task_id = "test_task_id" + task_path = "test_task_path" + writes = [("__pregel_pull", "__start__")] + runnable_config["configurable"]["checkpoint_id"] = "test_checkpoint_id" + + session_saver._create_session_invocation = AsyncMock() + session_saver._save_invocation_step = AsyncMock() + + sample_session_pending_write.task_id = task_id + sample_session_pending_write.write_idx = 0 + + session_saver._get_checkpoint_pending_writes = AsyncMock( + return_value=[sample_session_pending_write] + ) + + # Act + await session_saver.aput_writes(runnable_config, writes, task_id, task_path) + + # Assert + session_saver._create_session_invocation.assert_called_once_with( + runnable_config["configurable"]["thread_id"], + "ea473b95-7b9c-fe52-df2c-3a7353d3148b", + ) + session_saver._save_invocation_step.assert_not_called() + + @pytest.mark.asyncio + async def test_aput_writes_override_existing_writes( + self, + session_saver, + runnable_config, + sample_checkpoint, + sample_checkpoint_metadata, + sample_session_pending_write, + ): + # Arrange + task_id = "test_task_id" + task_path = "test_task_path" + writes = [(ERROR, "__start__")] + runnable_config["configurable"]["checkpoint_id"] = "test_checkpoint_id" + + session_saver._create_session_invocation = AsyncMock() + session_saver._save_invocation_step = AsyncMock() + + sample_session_pending_write.task_id = task_id + sample_session_pending_write.write_idx = 0 + + session_saver._get_checkpoint_pending_writes = AsyncMock( + return_value=[sample_session_pending_write] + ) + + # Act + await session_saver.aput_writes(runnable_config, writes, task_id, task_path) + + # Assert + session_saver._create_session_invocation.assert_called_once_with( + runnable_config["configurable"]["thread_id"], + "ea473b95-7b9c-fe52-df2c-3a7353d3148b", + ) + session_saver._save_invocation_step.assert_called_once_with( + runnable_config["configurable"]["thread_id"], + "ea473b95-7b9c-fe52-df2c-3a7353d3148b", + None, + ANY, + ) + + @pytest.mark.asyncio + @patch("langgraph_checkpoint_aws.async_saver.construct_checkpoint_tuple") + async def test_alist_success( + self, + mock_construct_checkpoint, + session_saver, + runnable_config, + sample_session_checkpoint, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_checkpoint.model_dump_json() + + # Mock all required internal methods + session_saver._generate_checkpoint_id = AsyncMock( + return_value="test_checkpoint_id" + ) + session_saver.session_client.get_invocation_step = AsyncMock( + return_value=GetInvocationStepResponse( + **sample_get_invocation_step_response + ) + ) + session_saver.session_client.list_invocation_steps = AsyncMock( + return_value=ListInvocationStepsResponse( + **sample_list_invocation_steps_response + ) + ) + session_saver._get_checkpoint_pending_writes = AsyncMock(return_value=[]) + session_saver._get_task_sends = AsyncMock(return_value=[]) + mock_construct_checkpoint.return_value = AsyncMock(spec=CheckpointTuple) + + # Act + result = [ + checkpoint async for checkpoint in session_saver.alist(runnable_config) + ] + + # Assert + assert len(list(result)) == 1 + + @pytest.mark.asyncio + async def test_alist_skips_writes( + self, + session_saver, + runnable_config, + sample_session_pending_write, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_pending_write.model_dump_json() + + # Mock all required internal methods + session_saver._generate_checkpoint_id = AsyncMock( + return_value="test_checkpoint_id" + ) + session_saver.session_client.get_invocation_step = AsyncMock( + return_value=GetInvocationStepResponse( + **sample_get_invocation_step_response + ) + ) + session_saver.session_client.list_invocation_steps = AsyncMock( + return_value=ListInvocationStepsResponse( + **sample_list_invocation_steps_response + ) + ) + + # Act + result = [ + checkpoint async for checkpoint in session_saver.alist(runnable_config) + ] + + # Assert + assert len(list(result)) == 0 + + @pytest.mark.asyncio + @patch("langgraph_checkpoint_aws.async_saver.construct_checkpoint_tuple") + async def test_alist_with_limit( + self, + mock_construct_checkpoint, + session_saver, + runnable_config, + sample_session_checkpoint, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_checkpoint.model_dump_json() + + # Mock all required internal methods + session_saver._generate_checkpoint_id = AsyncMock( + return_value="test_checkpoint_id" + ) + session_saver.session_client.get_invocation_step = AsyncMock( + return_value=GetInvocationStepResponse( + **sample_get_invocation_step_response + ) + ) + # Duplicate list response + sample_list_invocation_steps_response["invocationStepSummaries"] *= 10 + session_saver.session_client.list_invocation_steps = AsyncMock( + return_value=ListInvocationStepsResponse( + **sample_list_invocation_steps_response + ) + ) + session_saver._get_checkpoint_pending_writes = AsyncMock(return_value=[]) + session_saver._get_task_sends = AsyncMock(return_value=[]) + mock_construct_checkpoint.return_value = AsyncMock(spec=CheckpointTuple) + + # Act + result = [ + checkpoint + async for checkpoint in session_saver.alist(runnable_config, limit=3) + ] + + # Assert + assert len(list(result)) == 3 + + @pytest.mark.asyncio + async def test_alist_with_filter( + self, + session_saver, + runnable_config, + sample_session_checkpoint, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_checkpoint.model_dump_json() + + # Mock all required internal methods + session_saver._generate_checkpoint_id = AsyncMock( + return_value="test_checkpoint_id" + ) + session_saver.session_client.get_invocation_step = AsyncMock( + return_value=GetInvocationStepResponse( + **sample_get_invocation_step_response + ) + ) + session_saver.session_client.list_invocation_steps = AsyncMock( + return_value=ListInvocationStepsResponse( + **sample_list_invocation_steps_response + ) + ) + session_saver._get_checkpoint_pending_writes = AsyncMock(return_value=[]) + session_saver._get_task_sends = AsyncMock(return_value=[]) + session_saver._construct_checkpoint_tuple = AsyncMock( + return_value=AsyncMock(spec=CheckpointTuple) + ) + + # Act + result = [ + checkpoint + async for checkpoint in session_saver.alist( + runnable_config, filter={"key": "value1"} + ) + ] + + # Assert + assert len(list(result)) == 0 + + @pytest.mark.asyncio + async def test_alist_with_before( + self, + session_saver, + runnable_config, + sample_session_checkpoint, + sample_list_invocation_steps_response, + sample_get_invocation_step_response, + ): + # Arrange + before = RunnableConfig( + configurable={ + "checkpoint_id": sample_get_invocation_step_response["invocationStep"][ + "invocationStepId" + ] + } + ) + sample_session_checkpoint.metadata = json.dumps( + sample_session_checkpoint.metadata + ) + sample_get_invocation_step_response["invocationStep"]["payload"][ + "contentBlocks" + ][0]["text"] = sample_session_checkpoint.model_dump_json() + + # Mock all required internal methods + session_saver._generate_checkpoint_id = AsyncMock( + return_value="test_checkpoint_id" + ) + session_saver.session_client.get_invocation_step = AsyncMock( + return_value=GetInvocationStepResponse( + **sample_get_invocation_step_response + ) + ) + session_saver.session_client.list_invocation_steps = AsyncMock( + return_value=ListInvocationStepsResponse( + **sample_list_invocation_steps_response + ) + ) + + # Act + result = [ + checkpoint + async for checkpoint in session_saver.alist(runnable_config, before=before) + ] + + # Assert + assert len(list(result)) == 0 + + @pytest.mark.asyncio + async def test_alist_empty_response( + self, + session_saver, + runnable_config, + ): + # Arrange + session_saver.session_client.list_invocation_steps = AsyncMock( + return_value=ListInvocationStepsResponse(invocation_step_summaries=[]) + ) + + # Act + result = [ + checkpoint async for checkpoint in session_saver.alist(runnable_config) + ] + + # Assert + assert len(result) == 0 + session_saver.session_client.list_invocation_steps.assert_called_once() + + @pytest.mark.asyncio + async def test_alist_returns_empty_on_resource_not_found( + self, + session_saver, + runnable_config, + ): + # Arrange + error_response = { + "Error": { + "Code": "ResourceNotFoundException", + "Message": "Resource not found", + } + } + session_saver.session_client.list_invocation_steps = AsyncMock( + side_effect=ClientError( + error_response=error_response, + operation_name="ListInvocationSteps", + ) + ) + + # Act + result = [ + checkpoint async for checkpoint in session_saver.alist(runnable_config) + ] + + # Assert + assert len(result) == 0 + session_saver.session_client.list_invocation_steps.assert_called_once() + + @pytest.mark.asyncio + async def test_alist_error( + self, + session_saver, + runnable_config, + ): + # Arrange + error_response = { + "Error": {"Code": "SomeOtherError", "Message": "Some other error"} + } + session_saver.session_client.list_invocation_steps = AsyncMock( + side_effect=ClientError( + error_response=error_response, + operation_name="ListInvocationSteps", + ) + ) + + # Act and Assert + with pytest.raises(ClientError): + async for _ in session_saver.alist(runnable_config): + pass + + session_saver.session_client.list_invocation_steps.assert_called_once() diff --git a/libs/langgraph-checkpoint-aws/tests/unit_tests/test_async_session.py b/libs/langgraph-checkpoint-aws/tests/unit_tests/test_async_session.py new file mode 100644 index 00000000..b7e98d0a --- /dev/null +++ b/libs/langgraph-checkpoint-aws/tests/unit_tests/test_async_session.py @@ -0,0 +1,302 @@ +from unittest.mock import patch + +import pytest + +from langgraph_checkpoint_aws.async_saver import AsyncBedrockAgentRuntimeSessionClient +from langgraph_checkpoint_aws.models import ( + CreateInvocationRequest, + CreateInvocationResponse, + CreateSessionRequest, + CreateSessionResponse, + DeleteSessionRequest, + EndSessionRequest, + EndSessionResponse, + GetInvocationStepRequest, + GetInvocationStepResponse, + GetSessionRequest, + GetSessionResponse, + ListInvocationsRequest, + ListInvocationsResponse, + ListInvocationStepsRequest, + ListInvocationStepsResponse, + PutInvocationStepRequest, + PutInvocationStepResponse, +) + + +class TestAsyncBedrockAgentRuntimeSessionClient: + @pytest.fixture + def mock_session_client(self, mock_boto_client): + with patch("boto3.Session") as mock_aioboto_session: + mock_aioboto_session.return_value.client.return_value = mock_boto_client + yield AsyncBedrockAgentRuntimeSessionClient() + + class TestSession: + @pytest.mark.asyncio + async def test_create_async_session( + self, mock_session_client, mock_boto_client, sample_create_session_response + ): + # Arrange + mock_boto_client.create_session.return_value = ( + sample_create_session_response + ) + request = CreateSessionRequest() + + # Act + response = await mock_session_client.create_session(request) + + # Assert + assert isinstance(response, CreateSessionResponse) + mock_boto_client.create_session.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_with_user_attr( + self, mock_session_client, mock_boto_client, sample_create_session_response + ): + # Arrange + mock_boto_client.create_session.return_value = ( + sample_create_session_response + ) + request = CreateSessionRequest( + session_metadata={"key": "value"}, + encryption_key_arn="test-arn", + tags={"tag1": "value1"}, + ) + + # Act + response = await mock_session_client.create_session(request) + + # Assert + assert isinstance(response, CreateSessionResponse) + mock_boto_client.create_session.assert_called_once() + + @pytest.mark.asyncio + async def test_get_session( + self, + mock_session_client, + mock_boto_client, + sample_get_session_response, + sample_session_id, + ): + # Arrange + mock_boto_client.get_session.return_value = sample_get_session_response + request = GetSessionRequest(session_identifier=sample_session_id) + + # Act + response = await mock_session_client.get_session(request) + + # Assert + assert isinstance(response, GetSessionResponse) + mock_boto_client.get_session.assert_called_once() + + @pytest.mark.asyncio + async def test_end_session( + self, + mock_session_client, + mock_boto_client, + sample_get_session_response, + sample_session_id, + ): + # Arrange + mock_boto_client.end_session.return_value = sample_get_session_response + request = EndSessionRequest(session_identifier=sample_session_id) + + # Act + response = await mock_session_client.end_session(request) + + # Assert + assert isinstance(response, EndSessionResponse) + mock_boto_client.end_session.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_session( + self, mock_session_client, mock_boto_client, sample_session_id + ): + # Arrange + request = DeleteSessionRequest(session_identifier=sample_session_id) + + # Act + await mock_session_client.delete_session(request) + + # Assert + mock_boto_client.delete_session.assert_called_once() + + class TestInvocation: + @pytest.mark.asyncio + async def test_create_invocation( + self, + mock_session_client, + mock_boto_client, + sample_session_id, + sample_create_invocation_response, + ): + # Arrange + mock_boto_client.create_invocation.return_value = ( + sample_create_invocation_response + ) + request = CreateInvocationRequest(session_identifier=sample_session_id) + + # Act + response = await mock_session_client.create_invocation(request) + + # Assert + assert isinstance(response, CreateInvocationResponse) + mock_boto_client.create_invocation.assert_called_once() + + @pytest.mark.asyncio + async def test_create_invocation_with_user_attr( + self, + mock_session_client, + mock_boto_client, + sample_session_id, + sample_invocation_id, + sample_create_invocation_response, + ): + # Arrange + mock_boto_client.create_invocation.return_value = ( + sample_create_invocation_response + ) + request = CreateInvocationRequest( + session_identifier=sample_session_id, + invocation_id=sample_invocation_id, + description="Test invocation description", + ) + + # Act + response = await mock_session_client.create_invocation(request) + + # Assert + assert isinstance(response, CreateInvocationResponse) + mock_boto_client.create_invocation.assert_called_once() + + @pytest.mark.asyncio + async def test_list_invocation( + self, + mock_session_client, + mock_boto_client, + sample_session_id, + sample_list_invocation_response, + ): + # Arrange + mock_boto_client.list_invocations.return_value = ( + sample_list_invocation_response + ) + request = ListInvocationsRequest( + session_identifier=sample_session_id, max_results=1 + ) + + # Act + response = await mock_session_client.list_invocations(request) + + # Assert + assert isinstance(response, ListInvocationsResponse) + mock_boto_client.list_invocations.assert_called_once() + + class TestInvocationStep: + @pytest.mark.asyncio + async def test_put_invocation_step( + self, + mock_session_client, + mock_boto_client, + sample_session_id, + sample_invocation_id, + sample_invocation_step_id, + sample_timestamp, + sample_invocation_step_payload, + sample_put_invocation_step_response, + ): + # Arrange + mock_boto_client.put_invocation_step.return_value = ( + sample_put_invocation_step_response + ) + request = PutInvocationStepRequest( + session_identifier=sample_session_id, + invocation_identifier=sample_invocation_id, + invocation_step_id=sample_invocation_step_id, + invocation_step_time=sample_timestamp, + payload=sample_invocation_step_payload, + ) + + # Act + response = await mock_session_client.put_invocation_step(request) + + # Assert + assert isinstance(response, PutInvocationStepResponse) + mock_boto_client.put_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test_get_invocation_step( + self, + mock_session_client, + mock_boto_client, + sample_session_id, + sample_invocation_id, + sample_invocation_step_id, + sample_get_invocation_step_response, + ): + # Arrange + mock_boto_client.get_invocation_step.return_value = ( + sample_get_invocation_step_response + ) + request = GetInvocationStepRequest( + session_identifier=sample_session_id, + invocation_identifier=sample_invocation_id, + invocation_step_id=sample_invocation_step_id, + ) + + # Act + response = await mock_session_client.get_invocation_step(request) + + # Assert + assert isinstance(response, GetInvocationStepResponse) + mock_boto_client.get_invocation_step.assert_called_once() + + @pytest.mark.asyncio + async def test_list_invocation_steps( + self, + mock_session_client, + mock_boto_client, + sample_session_id, + sample_list_invocation_steps_response, + ): + # Arrange + mock_boto_client.list_invocation_steps.return_value = ( + sample_list_invocation_steps_response + ) + request = ListInvocationStepsRequest( + session_identifier=sample_session_id, + max_results=1, + ) + + # Act + response = await mock_session_client.list_invocation_steps(request) + + # Assert + assert isinstance(response, ListInvocationStepsResponse) + mock_boto_client.list_invocation_steps.assert_called_once() + + @pytest.mark.asyncio + async def test_list_invocation_steps_by_invocation( + self, + mock_session_client, + mock_boto_client, + sample_session_id, + sample_invocation_id, + sample_list_invocation_steps_response, + ): + # Arrange + mock_boto_client.list_invocation_steps.return_value = ( + sample_list_invocation_steps_response + ) + request = ListInvocationStepsRequest( + session_identifier=sample_session_id, + invocation_identifier=sample_invocation_id, + max_results=1, + ) + + # Act + response = await mock_session_client.list_invocation_steps(request) + + # Assert + assert isinstance(response, ListInvocationStepsResponse) + mock_boto_client.list_invocation_steps.assert_called_once()