Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions examples/agent_mem_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
from typing import List

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.memory import Mem0MemoryStore
from haystack.dataclasses import ChatMessage
from haystack.tools import tool
from haystack.tools.tool import Tool


@tool
def save_user_preference(preference_type: str, preference_value: str) -> str:
    """Persist a single user preference so the agent can recall it later."""
    # The agent only sees the returned confirmation text, so it must mention the saved values.
    confirmation = f"{preference_type} = {preference_value}"
    return "✅ Saved preference: " + confirmation


@tool
def get_recommendation(category: str) -> str:
    """Return a canned, preference-flavored recommendation for the given category."""
    # Known categories map to fixed suggestions; anything else gets the fallback reply.
    if category == "food":
        return "Based on your preferences, try the Mediterranean cuisine!"
    if category == "music":
        return "I recommend some jazz playlists for you!"
    if category == "books":
        return "You might enjoy science fiction novels!"
    return "I'll learn your preferences to give better recommendations!"


# Build a Mem0-backed memory store scoped to one user; the API key comes from the environment.
user_memory = Mem0MemoryStore(user_id="test_123", api_key=os.getenv("MEM0_API_KEY"))
# Restrict memory retrieval to movie-related memories for this session.
user_memory.set_search_criteria(filters={"categories": {"contains": "movie"}})


# Assemble the memory-aware agent with the preference tools.
assistant = Agent(
    chat_generator=OpenAIChatGenerator(),
    memory_store=user_memory,
    tools=[save_user_preference, get_recommendation],
    system_prompt="""
    You are a personal assistant with memory capabilities.
    Use the provided memories to personalize your responses and remember user context.
    When users share preferences, use the save_user_preference tool.
    When asked for recommendations, use the get_recommendation tool.
    Be conversational and reference previous interactions when relevant.
    """,
    exit_conditions=["text"],
    state_schema={"text": {"type": str}},
)

# Warm up and run a single turn through the agent.
assistant.warm_up()
result = assistant.run(messages=[ChatMessage.from_user("Recommend me a movie to watch on Friday night.")])

print(result["messages"])
23 changes: 20 additions & 3 deletions haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from haystack import logging, tracing
from haystack.components.generators.chat.types import ChatGenerator
from haystack.components.memory import Mem0MemoryStore
from haystack.components.tools import ToolInvoker
from haystack.core.component.component import component
from haystack.core.errors import PipelineRuntimeError
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
streaming_callback: Optional[StreamingCallbackT] = None,
raise_on_tool_invocation_failure: bool = False,
tool_invoker_kwargs: Optional[dict[str, Any]] = None,
memory_store: Optional[Mem0MemoryStore] = None,
) -> None:
"""
Initialize the agent component.
Expand All @@ -123,6 +125,8 @@ def __init__(
:param raise_on_tool_invocation_failure: Should the agent raise an exception when a tool invocation fails?
If set to False, the exception will be turned into a chat message and passed to the LLM.
:param tool_invoker_kwargs: Additional keyword arguments to pass to the ToolInvoker.
:param memory_store: The memory store to use for the agent. MemoryStore can be configured with
MemoryConfig to provide user_id and database configuration.
:raises TypeError: If the chat_generator does not support tools parameter in its run method.
:raises ValueError: If the exit_conditions are not valid.
"""
Expand Down Expand Up @@ -162,7 +166,7 @@ def __init__(
self.max_agent_steps = max_agent_steps
self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure
self.streaming_callback = streaming_callback

self.memory_store = memory_store
output_types = {"last_message": ChatMessage}
for param, config in self.state_schema.items():
output_types[param] = config["type"]
Expand Down Expand Up @@ -216,6 +220,7 @@ def to_dict(self) -> dict[str, Any]:
streaming_callback=serialize_callable(self.streaming_callback) if self.streaming_callback else None,
raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure,
tool_invoker_kwargs=self.tool_invoker_kwargs,
memory_store=self.memory_store.to_dict(),
)

@classmethod
Expand Down Expand Up @@ -490,10 +495,19 @@ def run( # noqa: PLR0915
:raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`.
:raises BreakpointException: If an agent breakpoint is triggered.
"""

agent_memory = []

# Retrieve memories from the memory store
if self.memory_store:
agent_memory = self.memory_store.search_memories(query=messages[-1].text)

combined_messages = messages + agent_memory

# We pop parent_snapshot from kwargs to avoid passing it into State.
parent_snapshot = kwargs.pop("parent_snapshot", None)
agent_inputs = {
"messages": messages,
"messages": combined_messages,
"streaming_callback": streaming_callback,
"break_point": break_point,
"snapshot": snapshot,
Expand All @@ -507,7 +521,7 @@ def run( # noqa: PLR0915
)
else:
exe_context = self._initialize_fresh_execution(
messages=messages,
messages=combined_messages,
streaming_callback=streaming_callback,
requires_async=False,
system_prompt=system_prompt,
Expand Down Expand Up @@ -611,6 +625,9 @@ def run( # noqa: PLR0915
result = {**exe_context.state.data}
if msgs := result.get("messages"):
result["last_message"] = msgs[-1]

# Add the new conversation as memories to the memory store
self.memory_store.add_memories(result["messages"])
return result

async def run_async(
Expand Down
135 changes: 135 additions & 0 deletions haystack/components/memory/mem0_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
# SPDX-License-Identifier: Apache-2.0

import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional

from haystack import default_from_dict, default_to_dict
from haystack.dataclasses.chat_message import ChatMessage, ChatRole
from haystack.lazy_imports import LazyImport

with LazyImport(message="Run 'pip install mem0ai'") as mem0_import:
from mem0 import MemoryClient


class Mem0MemoryStore:
    """
    A memory store implementation using Mem0 as the backend.

    :param user_id: Identifier of the user whose memories this store manages.
    :param api_key: Mem0 API key (if not provided, uses the MEM0_API_KEY environment variable).
    :param memory_config: Optional configuration dictionary for the Mem0 client (e.g. a
        self-hosted vector-store setup). When omitted, the hosted Mem0 platform client
        is created from the API key.
    """

    def __init__(self, user_id: str, api_key: Optional[str] = None, memory_config: Optional[dict[str, Any]] = None):
        mem0_import.check()
        self.api_key = api_key or os.getenv("MEM0_API_KEY")
        if not self.api_key:
            raise ValueError("Mem0 API key must be provided either as parameter or MEM0_API_KEY environment variable")

        self.user_id = user_id
        # Kept so to_dict()/from_dict() can round-trip the full construction state.
        self.memory_config = memory_config

        # If a client config is provided, use it to initialize the Mem0 client
        if memory_config:
            self.client = MemoryClient.from_config(memory_config)
        else:
            self.client = MemoryClient(api_key=self.api_key)

        # Default search criteria applied by search_memories() when no explicit
        # arguments are passed; populated via set_search_criteria().
        self.search_criteria: Optional[dict[str, Any]] = None

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the store configuration to a dictionary.

        NOTE(review): the API key is serialized in plaintext — consider a secret-handling
        mechanism before persisting this dictionary.
        """
        # Keyword names must match __init__ parameters so from_dict() can rebuild the store.
        return default_to_dict(self, user_id=self.user_id, api_key=self.api_key, memory_config=self.memory_config)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Mem0MemoryStore":
        """Deserialize the store from a dictionary produced by to_dict()."""
        return default_from_dict(cls, data)

    def add_memories(self, messages: list["ChatMessage"]) -> list[str]:
        """
        Add ChatMessage memories to Mem0.

        :param messages: List of ChatMessage objects; messages without text are skipped.
        :returns: List of memory IDs for the added messages.
        :raises RuntimeError: If a Mem0 add call fails.
        """
        added_ids = []

        for message in messages:
            # Skip non-textual messages (e.g. pure tool-call payloads).
            if not message.text:
                continue
            # NOTE(review): every message is stored with role "user", even assistant
            # replies — confirm this is intended before relying on role-based recall.
            mem0_message = [{"role": "user", "content": message.text}]

            try:
                # Mem0 primarily uses user_id as the main identifier; the message's
                # metadata is stored alongside for later filtering. infer=False stores
                # the raw text instead of letting Mem0 extract facts from it.
                result = self.client.add(
                    messages=mem0_message, user_id=self.user_id, metadata=message.meta, infer=False
                )
                # Mem0 returns different response formats, handle both
                memory_id = result.get("id") or result.get("memory_id") or str(result)
                added_ids.append(memory_id)
            except Exception as e:
                raise RuntimeError(f"Failed to add memory message: {e}") from e

        return added_ids

    def set_search_criteria(
        self, query: Optional[str] = None, filters: Optional[dict[str, Any]] = None, top_k: Optional[int] = None
    ):
        """
        Set default search criteria used by search_memories() when it is called
        without explicit arguments.

        :param query: Default text query.
        :param filters: Default Mem0 filters.
        :param top_k: Default maximum number of results.
        """
        self.search_criteria = {"query": query, "filters": filters, "top_k": top_k}

    def search_memories(
        self, query: Optional[str] = None, filters: Optional[dict[str, Any]] = None, top_k: Optional[int] = None
    ) -> list["ChatMessage"]:
        """
        Search for memories in Mem0.

        Explicit arguments take precedence over the defaults registered via
        set_search_criteria().

        :param query: Text query to search for. If not provided here or via the stored
            criteria, all memories for the user are returned.
        :param filters: Additional filters to apply on search. For more details on mem0 filters, see https://mem0.ai/docs/search/
        :param top_k: Maximum number of results to return (falls back to the stored
            criteria, then to 10).
        :returns: List of ChatMessage memories matching the criteria.
        :raises RuntimeError: If the Mem0 search call fails.
        """
        # Fall back to stored criteria; guard against set_search_criteria() never being called.
        criteria = self.search_criteria or {}
        search_query = query if query is not None else criteria.get("query")
        search_filters = filters or criteria.get("filters") or {}
        search_top_k = top_k or criteria.get("top_k") or 10

        # Always scope results to this store's user; append user filters only when
        # present so we never send an empty {} clause to the API.
        and_clauses: list[dict[str, Any]] = [{"user_id": self.user_id}]
        if search_filters:
            and_clauses.append(search_filters)
        mem0_filters = {"AND": and_clauses}

        try:
            if not search_query:
                results = self.client.get_all(filters=mem0_filters, top_k=search_top_k)
            else:
                results = self.client.search(
                    query=search_query, limit=search_top_k, filters=mem0_filters, user_id=self.user_id
                )
            # Memories are surfaced to the agent as assistant messages, carrying the
            # stored metadata through meta.
            return [ChatMessage.from_assistant(text=result["memory"], meta=result["metadata"]) for result in results]

        except Exception as e:
            raise RuntimeError(f"Failed to search memories: {e}") from e

    # mem0 doesn't allow passing filter to delete endpoint,
    # we can delete all memories for a user by passing the user_id
    def delete_all_memories(self, user_id: Optional[str] = None):
        """
        Delete all memory records for a user from Mem0.

        :param user_id: User identifier for scoping the deletion; defaults to this
            store's user_id.
        :raises RuntimeError: If the Mem0 delete call fails.
        """
        try:
            self.client.delete_all(user_id=user_id or self.user_id)
        except Exception as e:
            raise RuntimeError(f"Failed to delete memories for user {user_id}: {e}") from e
Loading