Skip to content

Commit 83fd045

Browse files
DeanChensjcopybara-github
authored andcommitted
feat: Migrate VertexAiMemoryBankService to use Agent Engine SDK
PiperOrigin-RevId: 813104746
1 parent ce9c39f commit 83fd045

File tree

3 files changed

+86
-134
lines changed

3 files changed

+86
-134
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ dependencies = [
3232
"click>=8.1.8, <9.0.0", # For CLI tools
3333
"fastapi>=0.115.0, <1.0.0", # FastAPI framework
3434
"google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery
35-
"google-cloud-aiplatform[agent_engines]>=1.95.1, <2.0.0", # For VertexAI integrations, e.g. example store.
35+
"google-cloud-aiplatform[agent_engines]>=1.112.0, <2.0.0",# For VertexAI integrations, e.g. example store.
3636
"google-cloud-bigtable>=2.32.0", # For Bigtable database
3737
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
3838
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database

src/google/adk/memory/vertex_ai_memory_bank_service.py

Lines changed: 29 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,13 @@
1414

1515
from __future__ import annotations
1616

17-
import json
1817
import logging
19-
from typing import Any
20-
from typing import Dict
2118
from typing import Optional
2219
from typing import TYPE_CHECKING
2320

24-
from google.genai import Client
2521
from google.genai import types
2622
from typing_extensions import override
23+
import vertexai
2724

2825
from .base_memory_service import BaseMemoryService
2926
from .base_memory_service import SearchMemoryResponse
@@ -59,8 +56,6 @@ def __init__(
5956

6057
@override
6158
async def add_session_to_memory(self, session: Session):
62-
api_client = self._get_api_client()
63-
6459
if not self._agent_engine_id:
6560
raise ValueError('Agent Engine ID is required for Memory Bank.')
6661

@@ -72,62 +67,53 @@ async def add_session_to_memory(self, session: Session):
7267
events.append({
7368
'content': event.content.model_dump(exclude_none=True, mode='json')
7469
})
75-
request_dict = {
76-
'direct_contents_source': {
77-
'events': events,
78-
},
79-
'scope': {
80-
'app_name': session.app_name,
81-
'user_id': session.user_id,
82-
},
83-
}
84-
8570
if events:
86-
api_response = await api_client.async_request(
87-
http_method='POST',
88-
path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
89-
request_dict=request_dict,
71+
client = self._get_api_client()
72+
operation = client.agent_engines.memories.generate(
73+
name='reasoningEngines/' + self._agent_engine_id,
74+
direct_contents_source={'events': events},
75+
scope={
76+
'app_name': session.app_name,
77+
'user_id': session.user_id,
78+
},
79+
config={'wait_for_completion': False},
9080
)
9181
logger.info('Generate memory response received.')
92-
logger.debug('Generate memory response: %s', api_response)
82+
logger.debug('Generate memory response: %s', operation)
9383
else:
9484
logger.info('No events to add to memory.')
9585

9686
@override
9787
async def search_memory(self, *, app_name: str, user_id: str, query: str):
98-
api_client = self._get_api_client()
99-
100-
api_response = await api_client.async_request(
101-
http_method='POST',
102-
path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve',
103-
request_dict={
104-
'scope': {
105-
'app_name': app_name,
106-
'user_id': user_id,
107-
},
108-
'similarity_search_params': {
109-
'search_query': query,
110-
},
88+
if not self._agent_engine_id:
89+
raise ValueError('Agent Engine ID is required for Memory Bank.')
90+
91+
client = self._get_api_client()
92+
retrieved_memories_iterator = client.agent_engines.memories.retrieve(
93+
name='reasoningEngines/' + self._agent_engine_id,
94+
scope={
95+
'app_name': app_name,
96+
'user_id': user_id,
97+
},
98+
similarity_search_params={
99+
'search_query': query,
111100
},
112101
)
113-
api_response = _convert_api_response(api_response)
114-
logger.info('Search memory response received.')
115-
logger.debug('Search memory response: %s', api_response)
116102

117-
if not api_response or not api_response.get('retrievedMemories', None):
118-
return SearchMemoryResponse()
103+
logger.info('Search memory response received.')
119104

120105
memory_events = []
121-
for memory in api_response.get('retrievedMemories', []):
106+
for retrieved_memory in retrieved_memories_iterator:
122107
# TODO: add more complex error handling
108+
logger.debug('Retrieved memory: %s', retrieved_memory)
123109
memory_events.append(
124110
MemoryEntry(
125111
author='user',
126112
content=types.Content(
127-
parts=[types.Part(text=memory.get('memory').get('fact'))],
113+
parts=[types.Part(text=retrieved_memory.memory.fact)],
128114
role='user',
129115
),
130-
timestamp=memory.get('updateTime'),
116+
timestamp=retrieved_memory.memory.update_time.isoformat(),
131117
)
132118
)
133119
return SearchMemoryResponse(memories=memory_events)
@@ -141,17 +127,7 @@ def _get_api_client(self):
141127
Returns:
142128
An API client for the given project and location.
143129
"""
144-
client = Client(
145-
vertexai=True, project=self._project, location=self._location
146-
)
147-
return client._api_client
148-
149-
150-
def _convert_api_response(api_response) -> Dict[str, Any]:
151-
"""Converts the API response to a JSON object based on the type."""
152-
if hasattr(api_response, 'body'):
153-
return json.loads(api_response.body)
154-
return api_response
130+
return vertexai.Client(project=self._project, location=self._location)
155131

156132

157133
def _should_filter_out_event(content: types.Content) -> bool:

tests/unittests/memory/test_vertex_ai_memory_bank_service.py

Lines changed: 56 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import re
16-
from typing import Any
15+
from datetime import datetime
1716
from unittest import mock
1817

1918
from google.adk.events.event import Event
@@ -70,48 +69,6 @@
7069
)
7170

7271

73-
RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$'
74-
GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$'
75-
76-
77-
class MockApiClient:
78-
"""Mocks the API Client."""
79-
80-
def __init__(self) -> None:
81-
"""Initializes MockClient."""
82-
self.async_request = mock.AsyncMock()
83-
self.async_request.side_effect = self._mock_async_request
84-
85-
async def _mock_async_request(
86-
self, http_method: str, path: str, request_dict: dict[str, Any]
87-
):
88-
"""Mocks the API Client request method."""
89-
if http_method == 'POST':
90-
if re.match(GENERATE_MEMORIES_REGEX, path):
91-
return {}
92-
elif re.match(RETRIEVE_MEMORIES_REGEX, path):
93-
if (
94-
request_dict.get('scope', None)
95-
and request_dict['scope'].get('app_name', None) == MOCK_APP_NAME
96-
):
97-
return {
98-
'retrievedMemories': [
99-
{
100-
'memory': {
101-
'fact': 'test_content',
102-
},
103-
'updateTime': '2024-12-12T12:12:12.123456Z',
104-
},
105-
],
106-
}
107-
else:
108-
return {'retrievedMemories': []}
109-
else:
110-
raise ValueError(f'Unsupported path: {path}')
111-
else:
112-
raise ValueError(f'Unsupported http method: {http_method}')
113-
114-
11572
def mock_vertex_ai_memory_bank_service():
11673
"""Creates a mock Vertex AI Memory Bank service for testing."""
11774
return VertexAiMemoryBankService(
@@ -122,67 +79,86 @@ def mock_vertex_ai_memory_bank_service():
12279

12380

12481
@pytest.fixture
125-
def mock_get_api_client():
126-
api_client = MockApiClient()
82+
def mock_vertexai_client():
12783
with mock.patch(
128-
'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client',
129-
return_value=api_client,
130-
):
131-
yield api_client
84+
'google.adk.memory.vertex_ai_memory_bank_service.vertexai.Client'
85+
) as mock_client_constructor:
86+
mock_client = mock.MagicMock()
87+
mock_client.agent_engines.memories.generate = mock.MagicMock()
88+
mock_client.agent_engines.memories.retrieve = mock.MagicMock()
89+
mock_client_constructor.return_value = mock_client
90+
yield mock_client
13291

13392

13493
@pytest.mark.asyncio
135-
@pytest.mark.usefixtures('mock_get_api_client')
136-
async def test_add_session_to_memory(mock_get_api_client):
94+
async def test_add_session_to_memory(mock_vertexai_client):
13795
memory_service = mock_vertex_ai_memory_bank_service()
13896
await memory_service.add_session_to_memory(MOCK_SESSION)
13997

140-
mock_get_api_client.async_request.assert_awaited_once_with(
141-
http_method='POST',
142-
path='reasoningEngines/123/memories:generate',
143-
request_dict={
144-
'direct_contents_source': {
145-
'events': [
146-
{
147-
'content': {
148-
'parts': [
149-
{'text': 'test_content'},
150-
],
151-
},
152-
},
153-
],
154-
},
155-
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
98+
mock_vertexai_client.agent_engines.memories.generate.assert_called_once_with(
99+
name='reasoningEngines/123',
100+
direct_contents_source={
101+
'events': [
102+
{
103+
'content': {
104+
'parts': [{'text': 'test_content'}],
105+
}
106+
}
107+
]
156108
},
109+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
110+
config={'wait_for_completion': False},
157111
)
158112

159113

160114
@pytest.mark.asyncio
161-
@pytest.mark.usefixtures('mock_get_api_client')
162-
async def test_add_empty_session_to_memory(mock_get_api_client):
115+
async def test_add_empty_session_to_memory(mock_vertexai_client):
163116
memory_service = mock_vertex_ai_memory_bank_service()
164117
await memory_service.add_session_to_memory(MOCK_SESSION_WITH_EMPTY_EVENTS)
165118

166-
mock_get_api_client.async_request.assert_not_called()
119+
mock_vertexai_client.agent_engines.memories.generate.assert_not_called()
167120

168121

169122
@pytest.mark.asyncio
170-
@pytest.mark.usefixtures('mock_get_api_client')
171-
async def test_search_memory(mock_get_api_client):
123+
async def test_search_memory(mock_vertexai_client):
124+
retrieved_memory = mock.MagicMock()
125+
retrieved_memory.memory.fact = 'test_content'
126+
retrieved_memory.memory.update_time = datetime(
127+
2024, 12, 12, 12, 12, 12, 123456
128+
)
129+
130+
mock_vertexai_client.agent_engines.memories.retrieve.return_value = [
131+
retrieved_memory
132+
]
172133
memory_service = mock_vertex_ai_memory_bank_service()
173134

174135
result = await memory_service.search_memory(
175136
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
176137
)
177138

178-
mock_get_api_client.async_request.assert_awaited_once_with(
179-
http_method='POST',
180-
path='reasoningEngines/123/memories:retrieve',
181-
request_dict={
182-
'scope': {'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
183-
'similarity_search_params': {'search_query': 'query'},
184-
},
139+
mock_vertexai_client.agent_engines.memories.retrieve.assert_called_once_with(
140+
name='reasoningEngines/123',
141+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
142+
similarity_search_params={'search_query': 'query'},
185143
)
186144

187145
assert len(result.memories) == 1
188146
assert result.memories[0].content.parts[0].text == 'test_content'
147+
148+
149+
@pytest.mark.asyncio
150+
async def test_search_memory_empty_results(mock_vertexai_client):
151+
mock_vertexai_client.agent_engines.memories.retrieve.return_value = []
152+
memory_service = mock_vertex_ai_memory_bank_service()
153+
154+
result = await memory_service.search_memory(
155+
app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='query'
156+
)
157+
158+
mock_vertexai_client.agent_engines.memories.retrieve.assert_called_once_with(
159+
name='reasoningEngines/123',
160+
scope={'app_name': MOCK_APP_NAME, 'user_id': MOCK_USER_ID},
161+
similarity_search_params={'search_query': 'query'},
162+
)
163+
164+
assert len(result.memories) == 0

0 commit comments

Comments
 (0)