Skip to content

Commit 2595824

Browse files
guillaumeblaquierecopybara-github
authored andcommitted
feat: add endpoint to generate memory from session
Merge #2900 In relation with #2416 COPYBARA_INTEGRATE_REVIEW=#2900 from guillaumeblaquiere:add-session-to-memory 0507de4 PiperOrigin-RevId: 808658162
1 parent 6b49391 commit 2595824

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from fastapi import FastAPI
3232
from fastapi import HTTPException
3333
from fastapi import Query
34+
from fastapi import Response
3435
from fastapi.middleware.cors import CORSMiddleware
3536
from fastapi.responses import RedirectResponse
3637
from fastapi.responses import StreamingResponse
@@ -210,6 +211,13 @@ class RunEvalRequest(common.BaseModel):
210211
eval_metrics: list[EvalMetric]
211212

212213

214+
class UpdateMemoryRequest(common.BaseModel):
215+
"""Request to add a session to the memory service."""
216+
217+
session_id: str
218+
"""The ID of the session to add to memory."""
219+
220+
213221
class RunEvalResult(common.BaseModel):
214222
eval_set_file: str
215223
eval_set_id: str
@@ -1144,6 +1152,41 @@ async def delete_artifact(
11441152
filename=artifact_name,
11451153
)
11461154

1155+
@app.patch("/apps/{app_name}/users/{user_id}/memory")
1156+
async def patch_memory(
1157+
app_name: str, user_id: str, update_memory_request: UpdateMemoryRequest
1158+
) -> None:
1159+
"""Adds all events from a given session to the memory service.
1160+
1161+
Args:
1162+
app_name: The name of the application.
1163+
user_id: The ID of the user.
1164+
update_memory_request: The memory request for the update
1165+
1166+
Raises:
1167+
HTTPException: If the memory service is not configured or the request is invalid.
1168+
"""
1169+
if not self.memory_service:
1170+
raise HTTPException(
1171+
status_code=400, detail="Memory service is not configured."
1172+
)
1173+
if (
1174+
update_memory_request is None
1175+
or update_memory_request.session_id is None
1176+
):
1177+
raise HTTPException(
1178+
status_code=400, detail="Update memory request is invalid."
1179+
)
1180+
1181+
session = await self.session_service.get_session(
1182+
app_name=app_name,
1183+
user_id=user_id,
1184+
session_id=update_memory_request.session_id,
1185+
)
1186+
if not session:
1187+
raise HTTPException(status_code=404, detail="Session not found")
1188+
await self.memory_service.add_session_to_memory(session)
1189+
11471190
@app.post("/run", response_model_exclude_none=True)
11481191
async def run_agent(req: RunAgentRequest) -> list[Event]:
11491192
session = await self.session_service.get_session(

tests/unittests/cli/test_fast_api.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import time
2323
from typing import Any
2424
from typing import Optional
25+
from unittest.mock import AsyncMock
2526
from unittest.mock import MagicMock
2627
from unittest.mock import patch
2728

@@ -344,7 +345,7 @@ async def delete_artifact(self, app_name, user_id, session_id, filename):
344345
@pytest.fixture
345346
def mock_memory_service():
346347
"""Create a mock memory service."""
347-
return MagicMock()
348+
return AsyncMock()
348349

349350

350351
@pytest.fixture
@@ -939,5 +940,18 @@ def test_a2a_disabled_by_default(test_app):
939940
logger.info("A2A disabled by default test passed")
940941

941942

943+
def test_patch_memory(test_app, create_test_session, mock_memory_service):
944+
"""Test adding a session to memory."""
945+
info = create_test_session
946+
url = f"/apps/{info['app_name']}/users/{info['user_id']}/memory"
947+
payload = {"session_id": info["session_id"]}
948+
response = test_app.patch(url, json=payload)
949+
950+
# Verify the response
951+
assert response.status_code == 200
952+
mock_memory_service.add_session_to_memory.assert_called_once()
953+
logger.info("Add session to memory test completed successfully")
954+
955+
942956
if __name__ == "__main__":
943957
pytest.main(["-xvs", __file__])

0 commit comments

Comments
 (0)