logger = logging.getLogger('google_adk.' + __name__)

# Internal marker meaning "drop this value". A dedicated object is used instead
# of None so that genuine None values stored in session state are preserved.
_REMOVED = object()


def _is_async_generator(obj: Any) -> bool:
    """Return True when ``obj`` is an async generator object.

    Only the object produced by *calling* an async generator function
    matches; the function itself, coroutines and plain values do not.
    """
    return inspect.isasyncgen(obj)


def _filter_value(obj: Any, path: str) -> Any:
    """Recursive worker behind ``_filter_non_serializable_objects``.

    Args:
        obj: The value to inspect.
        path: Location of ``obj`` inside the structure (used in log output).

    Returns:
        ``_REMOVED`` when ``obj`` is an async generator, a rebuilt container
        for dicts/lists/tuples, and ``obj`` itself for every other value.
    """
    if _is_async_generator(obj):
        # Lazy %-formatting: the message is only rendered if it is emitted.
        logger.warning(
            "Removing async generator from session state at %s. "
            "Async generators cannot be persisted in session state.",
            path,
        )
        return _REMOVED

    if isinstance(obj, dict):
        kept_entries = {}
        for key, value in obj.items():
            kept = _filter_value(value, f"{path}.{key}")
            if kept is not _REMOVED:
                kept_entries[key] = kept
        return kept_entries

    if isinstance(obj, (list, tuple)):
        kept_items = []
        for index, item in enumerate(obj):
            kept = _filter_value(item, f"{path}[{index}]")
            if kept is not _REMOVED:
                kept_items.append(kept)
        if hasattr(type(obj), "_make"):
            # namedtuple constructors take positional fields, not an iterable.
            # If an element was dropped the arity no longer matches, so fall
            # back to a plain tuple instead of raising.
            if len(kept_items) == len(obj):
                return type(obj)._make(kept_items)
            return tuple(kept_items)
        return type(obj)(kept_items)

    # Any other type is assumed to be serializable and is returned as-is.
    return obj


def _filter_non_serializable_objects(obj: Any, path: str = "root") -> Any:
    """Recursively strip async generators out of a data structure.

    Dicts, lists and tuples are rebuilt without the offending entries; all
    other values (including ``None``) are kept and returned by reference,
    so the result is NOT a deep copy of the leaves.

    Args:
        obj: The object to filter.
        path: The current path in the object tree (for logging).

    Returns:
        The filtered structure, or ``None`` when ``obj`` itself is an async
        generator.
    """
    filtered = _filter_value(obj, path)
    return None if filtered is _REMOVED else filtered


def safe_deepcopy_session(session):
    """Copy a session without failing on async generators in its state.

    The session is shallow-copied first; ``state`` is then replaced by a
    deep copy of its filtered content and ``events`` by a deep copy of the
    original events. Any other attribute is still shared with ``session``.

    Args:
        session: The session object to copy.

    Returns:
        The copied session, with async generators removed from ``state``.
    """
    session_copy = copy.copy(session)

    if hasattr(session_copy, 'state'):
        # Always replace the state, even when it is empty: after the shallow
        # copy an empty dict is still the very same object as the original's
        # state, so writes to the copy would leak back into the source.
        session_copy.state = copy.deepcopy(
            _filter_non_serializable_objects(session_copy.state, "state")
        )

    if hasattr(session_copy, 'events'):
        # Events are not filtered; they are deep-copied as-is.
        session_copy.events = copy.deepcopy(session.events)

    return session_copy
._session_copy_utils import safe_deepcopy_session +from ._session_copy_utils import _filter_non_serializable_objects logger = logging.getLogger('google_adk.' + __name__) @@ -93,11 +95,15 @@ def _create_session_impl( if session_id and session_id.strip() else str(uuid.uuid4()) ) + + # Filter out non-serializable objects from the state before creating the session + filtered_state = _filter_non_serializable_objects(state or {}, "initial_state") + session = Session( app_name=app_name, user_id=user_id, id=session_id, - state=state or {}, + state=filtered_state, last_update_time=time.time(), ) @@ -107,7 +113,7 @@ def _create_session_impl( self.sessions[app_name][user_id] = {} self.sessions[app_name][user_id][session_id] = session - copied_session = copy.deepcopy(session) + copied_session = safe_deepcopy_session(session) return self._merge_state(app_name, user_id, copied_session) @override @@ -158,7 +164,7 @@ def _get_session_impl( return None session = self.sessions[app_name][user_id].get(session_id) - copied_session = copy.deepcopy(session) + copied_session = safe_deepcopy_session(session) if config: if config.num_recent_events: @@ -222,7 +228,7 @@ def _list_sessions_impl( sessions_without_events = [] for session in self.sessions[app_name][user_id].values(): - copied_session = copy.deepcopy(session) + copied_session = safe_deepcopy_session(session) copied_session.events = [] copied_session = self._merge_state(app_name, user_id, copied_session) sessions_without_events.append(copied_session) diff --git a/tests/unittests/sessions/test_async_generator_session_fix.py b/tests/unittests/sessions/test_async_generator_session_fix.py new file mode 100644 index 0000000000..cabb4009e1 --- /dev/null +++ b/tests/unittests/sessions/test_async_generator_session_fix.py @@ -0,0 +1,160 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
async def _sample_async_generator() -> AsyncGenerator[str, None]:
    """Yield two fixed messages; stands in for a streaming tool.

    The name deliberately does not start with ``test`` so pytest does not try
    to collect this helper as a test case.
    """
    yield "test_message_1"
    yield "test_message_2"


class TestAsyncGeneratorSessionHandling:
    """Tests for async generator handling in session state."""

    def test_is_async_generator_detection(self):
        """Only async generator *objects* are detected."""

        async def regular_async_func():
            return "not a generator"

        def regular_func():
            return "regular function"

        async_gen = _sample_async_generator()
        try:
            # An actual async generator object.
            assert _is_async_generator(async_gen) is True

            # Function objects, including the async generator function itself.
            assert _is_async_generator(_sample_async_generator) is False
            assert _is_async_generator(regular_async_func) is False
            assert _is_async_generator(regular_func) is False

            # Plain values.
            assert _is_async_generator("string") is False
            assert _is_async_generator(123) is False
            assert _is_async_generator([1, 2, 3]) is False
            assert _is_async_generator({"key": "value"}) is False
        finally:
            # Close the generator even when an assertion above fails.
            asyncio.run(async_gen.aclose())

    def test_filter_non_serializable_objects(self):
        """Async generators are removed from flat, nested and list data."""
        async_gen = _sample_async_generator()
        try:
            # Flat dict.
            state = {"async_tool": async_gen, "normal_data": "test_value"}
            filtered = _filter_non_serializable_objects(state)

            assert "normal_data" in filtered
            assert filtered["normal_data"] == "test_value"
            assert "async_tool" not in filtered

            # Nested dicts.
            nested_state = {
                "level1": {
                    "level2": {
                        "async_gen": async_gen,
                        "normal": "value",
                    },
                    "other": "data",
                },
                "top_level": "value",
            }

            filtered_nested = _filter_non_serializable_objects(nested_state)
            assert filtered_nested["level1"]["level2"]["normal"] == "value"
            assert "async_gen" not in filtered_nested["level1"]["level2"]
            assert filtered_nested["level1"]["other"] == "data"
            assert filtered_nested["top_level"] == "value"

            # List containing an async generator.
            list_state = {"tools": [async_gen, "normal_tool"]}
            filtered_list = _filter_non_serializable_objects(list_state)
            assert len(filtered_list["tools"]) == 1
            assert filtered_list["tools"][0] == "normal_tool"

            # The input structure itself must not be modified.
            assert state["async_tool"] is async_gen
        finally:
            asyncio.run(async_gen.aclose())

    @pytest.mark.asyncio
    async def test_session_creation_with_async_generator(self):
        """Session creation works with async generators in the state."""
        session_service = InMemorySessionService()
        async_gen = _sample_async_generator()
        try:
            # This should not raise an exception.
            session = await session_service.create_session(
                app_name="test_app",
                user_id="test_user",
                state={
                    "streaming_tool": async_gen,
                    "normal_data": "test_value",
                },
            )

            # The async generator should be filtered out.
            assert "streaming_tool" not in session.state
            assert "normal_data" in session.state
            assert session.state["normal_data"] == "test_value"
        finally:
            await async_gen.aclose()

    @pytest.mark.asyncio
    async def test_session_operations_with_filtered_state(self):
        """get_session and list_sessions work on a session whose state was filtered."""
        session_service = InMemorySessionService()
        async_gen = _sample_async_generator()
        try:
            # The async generator makes sure filtering really takes place.
            session = await session_service.create_session(
                app_name="test_app",
                user_id="test_user",
                state={
                    "streaming_tool": async_gen,
                    "normal_data": "test_value",
                },
            )

            # get_session
            retrieved_session = await session_service.get_session(
                app_name="test_app",
                user_id="test_user",
                session_id=session.id,
            )
            assert retrieved_session is not None
            assert retrieved_session.state["normal_data"] == "test_value"
            assert "streaming_tool" not in retrieved_session.state

            # list_sessions
            sessions_response = await session_service.list_sessions(
                app_name="test_app",
                user_id="test_user",
            )
            assert len(sessions_response.sessions) == 1
            assert sessions_response.sessions[0].id == session.id
        finally:
            await async_gen.aclose()

    @pytest.mark.asyncio
    async def test_safe_deepcopy_session(self):
        """safe_deepcopy_session drops async generators and detaches the state."""
        assert callable(safe_deepcopy_session)

        session_service = InMemorySessionService()
        session = await session_service.create_session(
            app_name="test_app",
            user_id="test_user",
            state={"normal_data": "test_value"},
        )

        async_gen = _sample_async_generator()
        try:
            # Simulate a tool storing a live stream in the state after creation.
            session.state["streaming_tool"] = async_gen
            copied_session = safe_deepcopy_session(session)

            assert "streaming_tool" not in copied_session.state
            assert copied_session.state["normal_data"] == "test_value"
            assert copied_session.id == session.id

            # The copy owns its state; the source session is left untouched.
            assert copied_session.state is not session.state
            assert session.state["streaming_tool"] is async_gen
        finally:
            await async_gen.aclose()