Skip to content

Commit f9fc79d

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add metadata to memories
feat: Support metadata filtering for memory retrieval feat: Support metadata merge strategies for memory generation PiperOrigin-RevId: 855380428
1 parent b814aab commit f9fc79d

File tree

6 files changed

+425
-2
lines changed

6 files changed

+425
-2
lines changed

tests/unit/vertexai/genai/replays/test_create_agent_engine_memory.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,26 @@ def test_create_memory_with_ttl(client):
2525
assert isinstance(agent_engine, types.AgentEngine)
2626
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
2727

28+
metadata = {
29+
"my_string_key": types.MemoryMetadataValue(string_value="my_string_value"),
30+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
31+
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
32+
"my_timestamp_key": types.MemoryMetadataValue(
33+
timestamp_value=datetime.datetime(
34+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
35+
)
36+
),
37+
}
38+
2839
operation = client.agent_engines.memories.create(
2940
name=agent_engine.api_resource.name,
3041
fact="memory_fact",
3142
scope={"user_id": "123"},
32-
config=types.AgentEngineMemoryConfig(display_name="my_memory_fact", ttl="120s"),
43+
config=types.AgentEngineMemoryConfig(
44+
display_name="my_memory_fact",
45+
ttl="120s",
46+
metadata=metadata,
47+
),
3348
)
3449
assert isinstance(operation, types.AgentEngineMemoryOperation)
3550
assert operation.response.fact == "memory_fact"
@@ -42,6 +57,7 @@ def test_create_memory_with_ttl(client):
4257
<= operation.response.expire_time
4358
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
4459
)
60+
assert operation.response.metadata == metadata
4561
# Clean up resources.
4662
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
4763

@@ -51,7 +67,7 @@ def test_create_memory_with_expire_time(client):
5167
assert isinstance(agent_engine, types.AgentEngine)
5268
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
5369
expire_time = datetime.datetime(
54-
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
70+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
5571
)
5672

5773
operation = client.agent_engines.memories.create(

tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import datetime
1718
import pytest
1819

1920

@@ -145,6 +146,134 @@ def test_generate_memories_direct_memories_source(client):
145146
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
146147

147148

149+
def test_generate_memories_with_metadata(client):
150+
agent_engine = client.agent_engines.create()
151+
metadata = {
152+
"my_string_key": types.MemoryMetadataValue(string_value="my_string_value"),
153+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
154+
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
155+
"my_timestamp_key": types.MemoryMetadataValue(
156+
timestamp_value=datetime.datetime(
157+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
158+
)
159+
),
160+
}
161+
# Reuse the same content and scope for all generation requests to ensure
162+
# that the same memory is updated.
163+
direct_memories_source = types.GenerateMemoriesRequestDirectMemoriesSource(
164+
direct_memories=[
165+
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
166+
fact="I am a software engineer."
167+
),
168+
]
169+
)
170+
scope = {"user_id": "test-user-id"}
171+
172+
operation = client.agent_engines.memories.generate(
173+
name=agent_engine.api_resource.name,
174+
scope=scope,
175+
direct_memories_source=direct_memories_source,
176+
config=types.GenerateAgentEngineMemoriesConfig(metadata=metadata),
177+
)
178+
assert len(operation.response.generated_memories) >= 1
179+
memory = client.agent_engines.memories.get(
180+
name=operation.response.generated_memories[0].memory.name
181+
)
182+
assert memory.metadata == metadata
183+
184+
# Overwrite the metadata.
185+
overwrite_metadata = {
186+
"my_string_key": types.MemoryMetadataValue(string_value="new_value"),
187+
}
188+
operation = client.agent_engines.memories.generate(
189+
name=agent_engine.api_resource.name,
190+
scope=scope,
191+
direct_memories_source=direct_memories_source,
192+
config=types.GenerateAgentEngineMemoriesConfig(
193+
metadata=overwrite_metadata,
194+
metadata_merge_strategy=types.MemoryMetadataMergeStrategy.OVERWRITE,
195+
),
196+
)
197+
assert len(operation.response.generated_memories) >= 1
198+
assert (
199+
operation.response.generated_memories[0].action
200+
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
201+
)
202+
memory = client.agent_engines.memories.get(
203+
name=operation.response.generated_memories[0].memory.name
204+
)
205+
assert memory.metadata == overwrite_metadata
206+
207+
# Merge the metadata.
208+
new_metadata = {
209+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
210+
}
211+
operation = client.agent_engines.memories.generate(
212+
name=agent_engine.api_resource.name,
213+
scope=scope,
214+
direct_memories_source=direct_memories_source,
215+
config=types.GenerateAgentEngineMemoriesConfig(
216+
metadata=new_metadata,
217+
metadata_merge_strategy=types.MemoryMetadataMergeStrategy.MERGE,
218+
),
219+
)
220+
assert len(operation.response.generated_memories) >= 1
221+
assert (
222+
operation.response.generated_memories[0].action
223+
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
224+
)
225+
memory = client.agent_engines.memories.get(
226+
name=operation.response.generated_memories[0].memory.name
227+
)
228+
assert memory.metadata == {**overwrite_metadata, **new_metadata}
229+
230+
# Restrict consolidation based on metadata values. For the first request,
231+
# there's no existing memories that match the metadata, so a new memory is
232+
# created.
233+
restricted_metadata = {
234+
"my_string_key": types.MemoryMetadataValue(string_value="new_value2"),
235+
}
236+
operation = client.agent_engines.memories.generate(
237+
name=agent_engine.api_resource.name,
238+
scope=scope,
239+
direct_memories_source=direct_memories_source,
240+
config=types.GenerateAgentEngineMemoriesConfig(
241+
metadata=restricted_metadata,
242+
metadata_merge_strategy="REQUIRE_EXACT_MATCH",
243+
),
244+
)
245+
assert len(operation.response.generated_memories) == 1
246+
# Metadata doesn't match existing memory, so a new memory is created.
247+
assert (
248+
operation.response.generated_memories[0].action
249+
== types.GenerateMemoriesResponseGeneratedMemoryAction.CREATED
250+
)
251+
memory = client.agent_engines.memories.get(
252+
name=operation.response.generated_memories[0].memory.name
253+
)
254+
assert memory.metadata == restricted_metadata
255+
256+
# Send a second request where the metadata matches only one of the existing
257+
# memories.
258+
operation = client.agent_engines.memories.generate(
259+
name=agent_engine.api_resource.name,
260+
scope=scope,
261+
direct_memories_source=direct_memories_source,
262+
config=types.GenerateAgentEngineMemoriesConfig(
263+
metadata=restricted_metadata,
264+
metadata_merge_strategy="REQUIRE_EXACT_MATCH",
265+
),
266+
)
267+
assert len(operation.response.generated_memories) == 1
268+
assert (
269+
operation.response.generated_memories[0].action
270+
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
271+
)
272+
assert operation.response.generated_memories[0].memory.name == memory.name
273+
274+
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
275+
276+
148277
pytestmark = pytest_helper.setup(
149278
file=__file__,
150279
globals_for_file=globals(),

tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import datetime
1718
import pytest
1819

1920

@@ -115,6 +116,55 @@ def test_retrieve_memories_with_simple_retrieval_params(client):
115116
agent_engine.delete(force=True)
116117

117118

119+
def test_retrieve_memories_with_metadata(client):
120+
agent_engine = client.agent_engines.create()
121+
metadata = {
122+
"my_string_key": types.MemoryMetadataValue(string_value="my_string_value"),
123+
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
124+
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
125+
"my_timestamp_key": types.MemoryMetadataValue(
126+
timestamp_value=datetime.datetime(
127+
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
128+
)
129+
),
130+
}
131+
scope = {"user_id": "123"}
132+
client.agent_engines.memories.create(
133+
name=agent_engine.api_resource.name,
134+
fact="memory_fact_1",
135+
scope=scope,
136+
)
137+
operation = client.agent_engines.memories.create(
138+
name=agent_engine.api_resource.name,
139+
fact="memory_fact_2",
140+
scope=scope,
141+
config={"metadata": metadata},
142+
)
143+
memory_name2 = operation.response.name
144+
145+
results = client.agent_engines.memories.retrieve(
146+
name=agent_engine.api_resource.name,
147+
scope=scope,
148+
config={
149+
"filter_groups": [
150+
{
151+
"filters": [
152+
{
153+
"key": "my_string_key",
154+
"value": {"string_value": "my_string_value"},
155+
}
156+
]
157+
}
158+
],
159+
},
160+
)
161+
assert len(results) == 1
162+
assert results[0].memory.name == memory_name2
163+
164+
# Clean up resources.
165+
agent_engine.delete(force=True)
166+
167+
118168
pytestmark = pytest_helper.setup(
119169
file=__file__,
120170
globals_for_file=globals(),

vertexai/_genai/memories.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ def _AgentEngineMemoryConfig_to_vertex(
7777
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
7878
)
7979

80+
if getv(from_object, ["metadata"]) is not None:
81+
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))
82+
8083
return to_object
8184

8285

@@ -153,6 +156,16 @@ def _GenerateAgentEngineMemoriesConfig_to_vertex(
153156
getv(from_object, ["disable_memory_revisions"]),
154157
)
155158

159+
if getv(from_object, ["metadata"]) is not None:
160+
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))
161+
162+
if getv(from_object, ["metadata_merge_strategy"]) is not None:
163+
setv(
164+
parent_object,
165+
["metadataMergeStrategy"],
166+
getv(from_object, ["metadata_merge_strategy"]),
167+
)
168+
156169
return to_object
157170

158171

@@ -316,6 +329,13 @@ def _RetrieveAgentEngineMemoriesConfig_to_vertex(
316329
if getv(from_object, ["filter"]) is not None:
317330
setv(parent_object, ["filter"], getv(from_object, ["filter"]))
318331

332+
if getv(from_object, ["filter_groups"]) is not None:
333+
setv(
334+
parent_object,
335+
["filterGroups"],
336+
[item for item in getv(from_object, ["filter_groups"])],
337+
)
338+
319339
return to_object
320340

321341

@@ -413,6 +433,9 @@ def _UpdateAgentEngineMemoryConfig_to_vertex(
413433
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
414434
)
415435

436+
if getv(from_object, ["metadata"]) is not None:
437+
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))
438+
416439
if getv(from_object, ["update_mask"]) is not None:
417440
setv(
418441
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])

vertexai/_genai/types/__init__.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,17 @@
578578
from .common import MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicOrDict
579579
from .common import MemoryBankCustomizationConfigMemoryTopicOrDict
580580
from .common import MemoryBankCustomizationConfigOrDict
581+
from .common import MemoryConjunctionFilter
582+
from .common import MemoryConjunctionFilterDict
583+
from .common import MemoryConjunctionFilterOrDict
581584
from .common import MemoryDict
585+
from .common import MemoryFilter
586+
from .common import MemoryFilterDict
587+
from .common import MemoryFilterOrDict
588+
from .common import MemoryMetadataMergeStrategy
589+
from .common import MemoryMetadataValue
590+
from .common import MemoryMetadataValueDict
591+
from .common import MemoryMetadataValueOrDict
582592
from .common import MemoryOrDict
583593
from .common import MemoryRevision
584594
from .common import MemoryRevisionDict
@@ -613,6 +623,7 @@
613623
from .common import ObservabilityEvalCase
614624
from .common import ObservabilityEvalCaseDict
615625
from .common import ObservabilityEvalCaseOrDict
626+
from .common import Operator
616627
from .common import OptimizeConfig
617628
from .common import OptimizeConfigDict
618629
from .common import OptimizeConfigOrDict
@@ -1523,6 +1534,15 @@
15231534
"RetrieveMemoriesRequestSimpleRetrievalParams",
15241535
"RetrieveMemoriesRequestSimpleRetrievalParamsDict",
15251536
"RetrieveMemoriesRequestSimpleRetrievalParamsOrDict",
1537+
"MemoryMetadataValue",
1538+
"MemoryMetadataValueDict",
1539+
"MemoryMetadataValueOrDict",
1540+
"MemoryFilter",
1541+
"MemoryFilterDict",
1542+
"MemoryFilterOrDict",
1543+
"MemoryConjunctionFilter",
1544+
"MemoryConjunctionFilterDict",
1545+
"MemoryConjunctionFilterOrDict",
15261546
"RetrieveAgentEngineMemoriesConfig",
15271547
"RetrieveAgentEngineMemoriesConfigDict",
15281548
"RetrieveAgentEngineMemoriesConfigOrDict",
@@ -1909,6 +1929,7 @@
19091929
"IdentityType",
19101930
"AgentServerMode",
19111931
"ManagedTopicEnum",
1932+
"Operator",
19121933
"Language",
19131934
"MachineConfig",
19141935
"State",
@@ -1917,6 +1938,7 @@
19171938
"RubricContentType",
19181939
"EvaluationRunState",
19191940
"OptimizeTarget",
1941+
"MemoryMetadataMergeStrategy",
19201942
"GenerateMemoriesResponseGeneratedMemoryAction",
19211943
"PromptOptimizerMethod",
19221944
"PromptData",

0 commit comments

Comments
 (0)