12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import re
16
- from typing import Any
15
+ from datetime import datetime
17
16
from unittest import mock
18
17
19
18
from google .adk .events .event import Event
70
69
)
71
70
72
71
73
- RETRIEVE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:retrieve$'
74
- GENERATE_MEMORIES_REGEX = r'^reasoningEngines/([^/]+)/memories:generate$'
75
-
76
-
77
- class MockApiClient :
78
- """Mocks the API Client."""
79
-
80
- def __init__ (self ) -> None :
81
- """Initializes MockClient."""
82
- self .async_request = mock .AsyncMock ()
83
- self .async_request .side_effect = self ._mock_async_request
84
-
85
- async def _mock_async_request (
86
- self , http_method : str , path : str , request_dict : dict [str , Any ]
87
- ):
88
- """Mocks the API Client request method."""
89
- if http_method == 'POST' :
90
- if re .match (GENERATE_MEMORIES_REGEX , path ):
91
- return {}
92
- elif re .match (RETRIEVE_MEMORIES_REGEX , path ):
93
- if (
94
- request_dict .get ('scope' , None )
95
- and request_dict ['scope' ].get ('app_name' , None ) == MOCK_APP_NAME
96
- ):
97
- return {
98
- 'retrievedMemories' : [
99
- {
100
- 'memory' : {
101
- 'fact' : 'test_content' ,
102
- },
103
- 'updateTime' : '2024-12-12T12:12:12.123456Z' ,
104
- },
105
- ],
106
- }
107
- else :
108
- return {'retrievedMemories' : []}
109
- else :
110
- raise ValueError (f'Unsupported path: { path } ' )
111
- else :
112
- raise ValueError (f'Unsupported http method: { http_method } ' )
113
-
114
-
115
72
def mock_vertex_ai_memory_bank_service ():
116
73
"""Creates a mock Vertex AI Memory Bank service for testing."""
117
74
return VertexAiMemoryBankService (
@@ -122,67 +79,86 @@ def mock_vertex_ai_memory_bank_service():
122
79
123
80
124
81
@pytest .fixture
125
- def mock_get_api_client ():
126
- api_client = MockApiClient ()
82
+ def mock_vertexai_client ():
127
83
with mock .patch (
128
- 'google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService._get_api_client' ,
129
- return_value = api_client ,
130
- ):
131
- yield api_client
84
+ 'google.adk.memory.vertex_ai_memory_bank_service.vertexai.Client'
85
+ ) as mock_client_constructor :
86
+ mock_client = mock .MagicMock ()
87
+ mock_client .agent_engines .memories .generate = mock .MagicMock ()
88
+ mock_client .agent_engines .memories .retrieve = mock .MagicMock ()
89
+ mock_client_constructor .return_value = mock_client
90
+ yield mock_client
132
91
133
92
134
93
@pytest .mark .asyncio
135
- @pytest .mark .usefixtures ('mock_get_api_client' )
136
- async def test_add_session_to_memory (mock_get_api_client ):
94
+ async def test_add_session_to_memory (mock_vertexai_client ):
137
95
memory_service = mock_vertex_ai_memory_bank_service ()
138
96
await memory_service .add_session_to_memory (MOCK_SESSION )
139
97
140
- mock_get_api_client .async_request .assert_awaited_once_with (
141
- http_method = 'POST' ,
142
- path = 'reasoningEngines/123/memories:generate' ,
143
- request_dict = {
144
- 'direct_contents_source' : {
145
- 'events' : [
146
- {
147
- 'content' : {
148
- 'parts' : [
149
- {'text' : 'test_content' },
150
- ],
151
- },
152
- },
153
- ],
154
- },
155
- 'scope' : {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
98
+ mock_vertexai_client .agent_engines .memories .generate .assert_called_once_with (
99
+ name = 'reasoningEngines/123' ,
100
+ direct_contents_source = {
101
+ 'events' : [
102
+ {
103
+ 'content' : {
104
+ 'parts' : [{'text' : 'test_content' }],
105
+ }
106
+ }
107
+ ]
156
108
},
109
+ scope = {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
110
+ config = {'wait_for_completion' : False },
157
111
)
158
112
159
113
160
114
@pytest .mark .asyncio
161
- @pytest .mark .usefixtures ('mock_get_api_client' )
162
- async def test_add_empty_session_to_memory (mock_get_api_client ):
115
+ async def test_add_empty_session_to_memory (mock_vertexai_client ):
163
116
memory_service = mock_vertex_ai_memory_bank_service ()
164
117
await memory_service .add_session_to_memory (MOCK_SESSION_WITH_EMPTY_EVENTS )
165
118
166
- mock_get_api_client . async_request .assert_not_called ()
119
+ mock_vertexai_client . agent_engines . memories . generate .assert_not_called ()
167
120
168
121
169
122
@pytest .mark .asyncio
170
- @pytest .mark .usefixtures ('mock_get_api_client' )
171
- async def test_search_memory (mock_get_api_client ):
123
+ async def test_search_memory (mock_vertexai_client ):
124
+ retrieved_memory = mock .MagicMock ()
125
+ retrieved_memory .memory .fact = 'test_content'
126
+ retrieved_memory .memory .update_time = datetime (
127
+ 2024 , 12 , 12 , 12 , 12 , 12 , 123456
128
+ )
129
+
130
+ mock_vertexai_client .agent_engines .memories .retrieve .return_value = [
131
+ retrieved_memory
132
+ ]
172
133
memory_service = mock_vertex_ai_memory_bank_service ()
173
134
174
135
result = await memory_service .search_memory (
175
136
app_name = MOCK_APP_NAME , user_id = MOCK_USER_ID , query = 'query'
176
137
)
177
138
178
- mock_get_api_client .async_request .assert_awaited_once_with (
179
- http_method = 'POST' ,
180
- path = 'reasoningEngines/123/memories:retrieve' ,
181
- request_dict = {
182
- 'scope' : {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
183
- 'similarity_search_params' : {'search_query' : 'query' },
184
- },
139
+ mock_vertexai_client .agent_engines .memories .retrieve .assert_called_once_with (
140
+ name = 'reasoningEngines/123' ,
141
+ scope = {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
142
+ similarity_search_params = {'search_query' : 'query' },
185
143
)
186
144
187
145
assert len (result .memories ) == 1
188
146
assert result .memories [0 ].content .parts [0 ].text == 'test_content'
147
+
148
+
149
+ @pytest .mark .asyncio
150
+ async def test_search_memory_empty_results (mock_vertexai_client ):
151
+ mock_vertexai_client .agent_engines .memories .retrieve .return_value = []
152
+ memory_service = mock_vertex_ai_memory_bank_service ()
153
+
154
+ result = await memory_service .search_memory (
155
+ app_name = MOCK_APP_NAME , user_id = MOCK_USER_ID , query = 'query'
156
+ )
157
+
158
+ mock_vertexai_client .agent_engines .memories .retrieve .assert_called_once_with (
159
+ name = 'reasoningEngines/123' ,
160
+ scope = {'app_name' : MOCK_APP_NAME , 'user_id' : MOCK_USER_ID },
161
+ similarity_search_params = {'search_query' : 'query' },
162
+ )
163
+
164
+ assert len (result .memories ) == 0
0 commit comments