Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,28 @@ def test_create_memory_with_ttl(client):
assert isinstance(agent_engine, types.AgentEngine)
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)

metadata = {
"my_string_key": types.MemoryMetadataValue(
string_value="my_string_value"
),
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
"my_timestamp_key": types.MemoryMetadataValue(
timestamp_value=datetime.datetime(
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
)
),
}

operation = client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact",
scope={"user_id": "123"},
config=types.AgentEngineMemoryConfig(display_name="my_memory_fact", ttl="120s"),
config=types.AgentEngineMemoryConfig(
display_name="my_memory_fact",
ttl="120s",
metadata=metadata,
),
)
assert isinstance(operation, types.AgentEngineMemoryOperation)
assert operation.response.fact == "memory_fact"
Expand All @@ -42,6 +59,7 @@ def test_create_memory_with_ttl(client):
<= operation.response.expire_time
<= operation.response.create_time + datetime.timedelta(seconds=120.5)
)
assert operation.response.metadata == metadata
# Clean up resources.
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)

Expand All @@ -51,7 +69,7 @@ def test_create_memory_with_expire_time(client):
assert isinstance(agent_engine, types.AgentEngine)
assert isinstance(agent_engine.api_resource, types.ReasoningEngine)
expire_time = datetime.datetime(
2026, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
)

operation = client.agent_engines.memories.create(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import datetime
import pytest


Expand Down Expand Up @@ -145,6 +146,138 @@ def test_generate_memories_direct_memories_source(client):
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


def test_generate_memories_with_metadata(client):
agent_engine = client.agent_engines.create()
metadata = {
"my_string_key": types.MemoryMetadataValue(
string_value="my_string_value"
),
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
"my_timestamp_key": types.MemoryMetadataValue(
timestamp_value=datetime.datetime(
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
)
),
}
# Reuse the same content and scope for all generation requests to ensure
# that the same memory is updated.
direct_memories_source = types.GenerateMemoriesRequestDirectMemoriesSource(
direct_memories=[
types.GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(
fact="I am a software engineer."
),
]
)
scope = {"user_id": "test-user-id"}

operation = client.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope=scope,
direct_memories_source=direct_memories_source,
config=types.GenerateAgentEngineMemoriesConfig(
metadata=metadata
),
)
assert len(operation.response.generated_memories) >= 1
memory = client.agent_engines.memories.get(
name=operation.response.generated_memories[0].memory.name
)
assert memory.metadata == metadata

# Overwrite the metadata.
overwrite_metadata = {
"my_string_key": types.MemoryMetadataValue(string_value="new_value"),
}
operation = client.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope=scope,
direct_memories_source=direct_memories_source,
config=types.GenerateAgentEngineMemoriesConfig(
metadata=overwrite_metadata,
metadata_merge_strategy=types.MemoryMetadataMergeStrategy.OVERWRITE,
),
)
assert len(operation.response.generated_memories) >= 1
assert (
operation.response.generated_memories[0].action
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
)
memory = client.agent_engines.memories.get(
name=operation.response.generated_memories[0].memory.name
)
assert memory.metadata == overwrite_metadata

# Merge the metadata.
new_metadata = {
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
}
operation = client.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope=scope,
direct_memories_source=direct_memories_source,
config=types.GenerateAgentEngineMemoriesConfig(
metadata=new_metadata,
metadata_merge_strategy=types.MemoryMetadataMergeStrategy.MERGE,
),
)
assert len(operation.response.generated_memories) >= 1
assert (
operation.response.generated_memories[0].action
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
)
memory = client.agent_engines.memories.get(
name=operation.response.generated_memories[0].memory.name
)
assert memory.metadata == {**overwrite_metadata, **new_metadata}

# Restrict consolidation based on metadata values. For the first request,
# there's no existing memories that match the metadata, so a new memory is
# created.
restricted_metadata = {
"my_string_key": types.MemoryMetadataValue(string_value="new_value2"),
}
operation = client.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope=scope,
direct_memories_source=direct_memories_source,
config=types.GenerateAgentEngineMemoriesConfig(
metadata=restricted_metadata,
metadata_merge_strategy="REQUIRE_EXACT_MATCH",
),
)
assert len(operation.response.generated_memories) == 1
# Metadata doesn't match existing memory, so a new memory is created.
assert (
operation.response.generated_memories[0].action
== types.GenerateMemoriesResponseGeneratedMemoryAction.CREATED
)
memory = client.agent_engines.memories.get(
name=operation.response.generated_memories[0].memory.name
)
assert memory.metadata == restricted_metadata

# Send a second request where the metadata matches only one of the existing
# memories.
operation = client.agent_engines.memories.generate(
name=agent_engine.api_resource.name,
scope=scope,
direct_memories_source=direct_memories_source,
config=types.GenerateAgentEngineMemoriesConfig(
metadata=restricted_metadata,
metadata_merge_strategy="REQUIRE_EXACT_MATCH",
),
)
assert len(operation.response.generated_memories) == 1
assert (
operation.response.generated_memories[0].action
== types.GenerateMemoriesResponseGeneratedMemoryAction.UPDATED
)
assert operation.response.generated_memories[0].memory.name == memory.name

client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#
# pylint: disable=protected-access,bad-continuation,missing-function-docstring

import datetime
import pytest


Expand Down Expand Up @@ -115,6 +116,57 @@ def test_retrieve_memories_with_simple_retrieval_params(client):
agent_engine.delete(force=True)


def test_retrieve_memories_with_metadata(client):
agent_engine = client.agent_engines.create()
metadata = {
"my_string_key": types.MemoryMetadataValue(
string_value="my_string_value"
),
"my_double_key": types.MemoryMetadataValue(double_value=123.456),
"my_boolean_key": types.MemoryMetadataValue(bool_value=True),
"my_timestamp_key": types.MemoryMetadataValue(
timestamp_value=datetime.datetime(
2027, 1, 1, 12, 30, 00, tzinfo=datetime.timezone.utc
)
),
}
scope = {"user_id": "123"}
client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact_1",
scope=scope,
)
operation = client.agent_engines.memories.create(
name=agent_engine.api_resource.name,
fact="memory_fact_2",
scope=scope,
config={"metadata": metadata},
)
memory_name2 = operation.response.name

results = client.agent_engines.memories.retrieve(
name=agent_engine.api_resource.name,
scope=scope,
config={
"filter_groups": [
{
"filters": [
{
"key": "my_string_key",
"value": {"string_value": "my_string_value"}
}
]
}
],
},
)
assert len(results) == 1
assert results[0].memory.name == memory_name2

# Clean up resources.
agent_engine.delete(force=True)


pytestmark = pytest_helper.setup(
file=__file__,
globals_for_file=globals(),
Expand Down
23 changes: 23 additions & 0 deletions vertexai/_genai/memories.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def _AgentEngineMemoryConfig_to_vertex(
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
)

if getv(from_object, ["metadata"]) is not None:
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))

return to_object


Expand Down Expand Up @@ -153,6 +156,16 @@ def _GenerateAgentEngineMemoriesConfig_to_vertex(
getv(from_object, ["disable_memory_revisions"]),
)

if getv(from_object, ["metadata"]) is not None:
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))

if getv(from_object, ["metadata_merge_strategy"]) is not None:
setv(
parent_object,
["metadataMergeStrategy"],
getv(from_object, ["metadata_merge_strategy"]),
)

return to_object


Expand Down Expand Up @@ -316,6 +329,13 @@ def _RetrieveAgentEngineMemoriesConfig_to_vertex(
if getv(from_object, ["filter"]) is not None:
setv(parent_object, ["filter"], getv(from_object, ["filter"]))

if getv(from_object, ["filter_groups"]) is not None:
setv(
parent_object,
["filterGroups"],
[item for item in getv(from_object, ["filter_groups"])],
)

return to_object


Expand Down Expand Up @@ -413,6 +433,9 @@ def _UpdateAgentEngineMemoryConfig_to_vertex(
parent_object, ["topics"], [item for item in getv(from_object, ["topics"])]
)

if getv(from_object, ["metadata"]) is not None:
setv(parent_object, ["metadata"], getv(from_object, ["metadata"]))

if getv(from_object, ["update_mask"]) is not None:
setv(
parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"])
Expand Down
22 changes: 22 additions & 0 deletions vertexai/_genai/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,17 @@
from .common import MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicOrDict
from .common import MemoryBankCustomizationConfigMemoryTopicOrDict
from .common import MemoryBankCustomizationConfigOrDict
from .common import MemoryConjunctionFilter
from .common import MemoryConjunctionFilterDict
from .common import MemoryConjunctionFilterOrDict
from .common import MemoryDict
from .common import MemoryFilter
from .common import MemoryFilterDict
from .common import MemoryFilterOrDict
from .common import MemoryMetadataMergeStrategy
from .common import MemoryMetadataValue
from .common import MemoryMetadataValueDict
from .common import MemoryMetadataValueOrDict
from .common import MemoryOrDict
from .common import MemoryRevision
from .common import MemoryRevisionDict
Expand Down Expand Up @@ -613,6 +623,7 @@
from .common import ObservabilityEvalCase
from .common import ObservabilityEvalCaseDict
from .common import ObservabilityEvalCaseOrDict
from .common import Operator
from .common import OptimizeConfig
from .common import OptimizeConfigDict
from .common import OptimizeConfigOrDict
Expand Down Expand Up @@ -1523,6 +1534,15 @@
"RetrieveMemoriesRequestSimpleRetrievalParams",
"RetrieveMemoriesRequestSimpleRetrievalParamsDict",
"RetrieveMemoriesRequestSimpleRetrievalParamsOrDict",
"MemoryMetadataValue",
"MemoryMetadataValueDict",
"MemoryMetadataValueOrDict",
"MemoryFilter",
"MemoryFilterDict",
"MemoryFilterOrDict",
"MemoryConjunctionFilter",
"MemoryConjunctionFilterDict",
"MemoryConjunctionFilterOrDict",
"RetrieveAgentEngineMemoriesConfig",
"RetrieveAgentEngineMemoriesConfigDict",
"RetrieveAgentEngineMemoriesConfigOrDict",
Expand Down Expand Up @@ -1909,6 +1929,7 @@
"IdentityType",
"AgentServerMode",
"ManagedTopicEnum",
"Operator",
"Language",
"MachineConfig",
"State",
Expand All @@ -1917,6 +1938,7 @@
"RubricContentType",
"EvaluationRunState",
"OptimizeTarget",
"MemoryMetadataMergeStrategy",
"GenerateMemoriesResponseGeneratedMemoryAction",
"PromptOptimizerMethod",
"PromptData",
Expand Down
Loading
Loading