diff --git a/pyproject.toml b/pyproject.toml index d00951f8f..8a8c334b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "python-dateutil>=2.9.0.post0", # For Vertext AI Session Service "python-dotenv>=1.0.0", # To manage environment variables "requests>=2.32.4", + "sqlalchemy-spanner>=1.14.0", # Spanner database session service "sqlalchemy>=2.0", # SQL database ORM "starlette>=0.46.2", # For FastAPI CLI "tenacity>=8.0.0", # For Retry management diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 2d88007bb..1d58b4969 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -18,10 +18,12 @@ from datetime import timezone import json import logging +import pickle from typing import Any from typing import Optional import uuid +from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType from sqlalchemy import Boolean from sqlalchemy import delete from sqlalchemy import Dialect @@ -104,6 +106,31 @@ def load_dialect_impl(self, dialect): return self.impl +class DynamicPickleType(TypeDecorator): + """Represents a type that can be pickled.""" + + impl = PickleType + + def load_dialect_impl(self, dialect): + if dialect.name == "spanner+spanner": + return dialect.type_descriptor(SpannerPickleType) + return self.impl + + def process_bind_param(self, value, dialect): + """Ensures the pickled value is a bytes object before passing it to the database dialect.""" + if value is not None: + if dialect.name == "spanner+spanner": + return pickle.dumps(value) + return value + + def process_result_value(self, value, dialect): + """Ensures the raw bytes from the database are unpickled back into a Python object.""" + if value is not None: + if dialect.name == "spanner+spanner": + return pickle.loads(value) + return value + + class Base(DeclarativeBase): """Base class for database tables.""" @@ -209,7 +236,7 @@ class StorageEvent(Base): PreciseTimestamp, default=func.now() ) content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True) - actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType) + actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column( Text, nullable=True