import os

from haystack.components.agents import Agent
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.memory import Mem0MemoryStore
from haystack.dataclasses import ChatMessage
from haystack.tools import tool


@tool
def save_user_preference(preference_type: str, preference_value: str) -> str:
    """Save user preferences that should be remembered"""
    return f"✅ Saved preference: {preference_type} = {preference_value}"


@tool
def get_recommendation(category: str) -> str:
    """Get personalized recommendations based on user preferences"""
    recommendations = {
        "food": "Based on your preferences, try the Mediterranean cuisine!",
        "music": "I recommend some jazz playlists for you!",
        "books": "You might enjoy science fiction novels!",
    }
    return recommendations.get(category, "I'll learn your preferences to give better recommendations!")


# User initializes the memory store with config and user_id
memory_store = Mem0MemoryStore(user_id="test_123", api_key=os.getenv("MEM0_API_KEY"))
# User may use class method to set search criteria
memory_store.set_search_criteria(filters={"categories": {"contains": "movie"}})


# Agent Setup
agent = Agent(
    chat_generator=OpenAIChatGenerator(),
    memory_store=memory_store,
    tools=[save_user_preference, get_recommendation],
    system_prompt="""
    You are a personal assistant with memory capabilities.
    Use the provided memories to personalize your responses and remember user context.
    When users share preferences, use the save_user_preference tool.
    When asked for recommendations, use the get_recommendation tool.
    Be conversational and reference previous interactions when relevant.
    """,
    exit_conditions=["text"],
    state_schema={"text": {"type": str}},
)

# Run the Agent
agent.warm_up()
response = agent.run(messages=[ChatMessage.from_user("Recommend me a movie to watch on Friday night.")])

print(response["messages"])
""" @@ -162,7 +166,7 @@ def __init__( self.max_agent_steps = max_agent_steps self.raise_on_tool_invocation_failure = raise_on_tool_invocation_failure self.streaming_callback = streaming_callback - + self.memory_store = memory_store output_types = {"last_message": ChatMessage} for param, config in self.state_schema.items(): output_types[param] = config["type"] @@ -216,6 +220,7 @@ def to_dict(self) -> dict[str, Any]: streaming_callback=serialize_callable(self.streaming_callback) if self.streaming_callback else None, raise_on_tool_invocation_failure=self.raise_on_tool_invocation_failure, tool_invoker_kwargs=self.tool_invoker_kwargs, + memory_store=self.memory_store.to_dict(), ) @classmethod @@ -490,10 +495,19 @@ def run( # noqa: PLR0915 :raises RuntimeError: If the Agent component wasn't warmed up before calling `run()`. :raises BreakpointException: If an agent breakpoint is triggered. """ + + agent_memory = [] + + # Retrieve memories from the memory store + if self.memory_store: + agent_memory = self.memory_store.search_memories(query=messages[-1].text) + + combined_messages = messages + agent_memory + # We pop parent_snapshot from kwargs to avoid passing it into State. 
parent_snapshot = kwargs.pop("parent_snapshot", None) agent_inputs = { - "messages": messages, + "messages": combined_messages, "streaming_callback": streaming_callback, "break_point": break_point, "snapshot": snapshot, @@ -507,7 +521,7 @@ def run( # noqa: PLR0915 ) else: exe_context = self._initialize_fresh_execution( - messages=messages, + messages=combined_messages, streaming_callback=streaming_callback, requires_async=False, system_prompt=system_prompt, @@ -611,6 +625,9 @@ def run( # noqa: PLR0915 result = {**exe_context.state.data} if msgs := result.get("messages"): result["last_message"] = msgs[-1] + + # Add the new conversation as memories to the memory store + self.memory_store.add_memories(result["messages"]) return result async def run_async( diff --git a/haystack/components/memory/mem0_store.py b/haystack/components/memory/mem0_store.py new file mode 100644 index 0000000000..9662cdd924 --- /dev/null +++ b/haystack/components/memory/mem0_store.py @@ -0,0 +1,135 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# SPDX-License-Identifier: Apache-2.0 + +import os +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Optional + +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses.chat_message import ChatMessage, ChatRole +from haystack.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install mem0ai'") as mem0_import: + from mem0 import MemoryClient + + +class Mem0MemoryStore: + """ + A memory store implementation using Mem0 as the backend. 
class Mem0MemoryStore:
    """
    A memory store implementation using Mem0 as the backend.

    :param user_id: Identifier used to scope every memory this store adds, searches, or deletes.
    :param api_key: Mem0 API key (if not provided, uses the MEM0_API_KEY environment variable).
    :param memory_config: Optional configuration dictionary for the Mem0 client; when provided,
        the client is built via ``MemoryClient.from_config`` instead of the plain constructor.
    """

    def __init__(self, user_id: str, api_key: Optional[str] = None, memory_config: Optional[dict[str, Any]] = None):
        mem0_import.check()
        self.api_key = api_key or os.getenv("MEM0_API_KEY")
        if not self.api_key:
            raise ValueError("Mem0 API key must be provided either as parameter or MEM0_API_KEY environment variable")

        self.user_id = user_id
        # Keep the raw config so to_dict()/from_dict() can round-trip the store faithfully.
        self.memory_config = memory_config

        # If a client config is provided, use it to initialize the Mem0 client
        if memory_config:
            self.client = MemoryClient.from_config(memory_config)
        else:
            self.client = MemoryClient(api_key=self.api_key)

        # Default search criteria; populated by set_search_criteria() and used as a
        # fallback by search_memories() when no explicit arguments are given.
        self.search_criteria: Optional[dict[str, Any]] = None

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the store configuration to a dictionary.

        The API key is deliberately NOT serialized: embedding it would leak a secret into
        pipeline YAML/JSON. On deserialization the key is re-read from MEM0_API_KEY.
        Only constructor parameters are emitted so that ``from_dict`` can rebuild the store.
        """
        return default_to_dict(self, user_id=self.user_id, memory_config=self.memory_config)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "Mem0MemoryStore":
        """Deserialize the store from a dictionary produced by ``to_dict``."""
        return default_from_dict(cls, data)

    def add_memories(self, messages: list[ChatMessage]) -> list[str]:
        """
        Add ChatMessage memories to Mem0.

        :param messages: List of ChatMessage objects; messages without text content are skipped.
        :returns: List of memory IDs for the added messages.
        :raises RuntimeError: If the Mem0 client fails to store a message.
        """
        added_ids: list[str] = []

        for message in messages:
            if not message.text:
                continue
            # Preserve the original speaker; Mem0 accepts "user"/"assistant" roles, so any
            # other role (system, tool) is mapped to "user".
            role = "assistant" if message.is_from(ChatRole.ASSISTANT) else "user"
            mem0_message = [{"role": role, "content": message.text}]

            try:
                # Mem0 primarily uses user_id as the main identifier; the message meta is
                # stored as Mem0 metadata so it can be used for filtering later.
                result = self.client.add(
                    messages=mem0_message, user_id=self.user_id, metadata=message.meta, infer=False
                )
                # Mem0 returns different response formats, handle both
                memory_id = result.get("id") or result.get("memory_id") or str(result)
                added_ids.append(memory_id)
            except Exception as e:
                raise RuntimeError(f"Failed to add memory message: {e}") from e

        return added_ids

    def set_search_criteria(
        self, query: Optional[str] = None, filters: Optional[dict[str, Any]] = None, top_k: Optional[int] = None
    ):
        """
        Set default search criteria for subsequent ``search_memories`` calls.

        :param query: Default text query.
        :param filters: Default Mem0 filters.
        :param top_k: Default maximum number of results.
        """
        self.search_criteria = {"query": query, "filters": filters, "top_k": top_k}

    def search_memories(
        self, query: Optional[str] = None, filters: Optional[dict[str, Any]] = None, top_k: Optional[int] = None
    ) -> list[ChatMessage]:
        """
        Search for memories in Mem0.

        Explicit arguments take precedence; any argument left unset falls back to the
        criteria stored via ``set_search_criteria``, then to built-in defaults.

        :param query: Text query to search for. If not provided, all memories matching the
            filters will be returned.
        :param filters: Additional filters to apply on search.
            For more details on mem0 filters, see https://mem0.ai/docs/search/
        :param top_k: Maximum number of results to return (defaults to 10).
        :returns: List of ChatMessage memories matching the criteria.
        :raises RuntimeError: If the Mem0 client call fails.
        """
        # Fall back to stored criteria without crashing when none were ever set.
        criteria = self.search_criteria or {}
        search_query = query or criteria.get("query")
        search_filters = filters or criteria.get("filters") or {}
        search_top_k = top_k or criteria.get("top_k") or 10

        # Always scope results to this store's user in addition to any custom filters.
        mem0_filters = {"AND": [{"user_id": self.user_id}, search_filters]}

        try:
            if not search_query:
                results = self.client.get_all(filters=mem0_filters, top_k=search_top_k)
            else:
                results = self.client.search(
                    query=search_query, limit=search_top_k, filters=mem0_filters, user_id=self.user_id
                )
            # Memories are re-injected into the chat as assistant messages; metadata may be
            # absent in some Mem0 responses, so default it to an empty dict.
            memories = [
                ChatMessage.from_assistant(text=result["memory"], meta=result.get("metadata") or {})
                for result in results
            ]

            return memories

        except Exception as e:
            raise RuntimeError(f"Failed to search memories: {e}") from e

    # mem0 doesn't allow passing filter to delete endpoint,
    # we can delete all memories for a user by passing the user_id
    def delete_all_memories(self, user_id: Optional[str] = None):
        """
        Delete all memory records for a user from Mem0.

        :param user_id: User identifier for scoping the deletion; defaults to this store's user.
        :raises RuntimeError: If the Mem0 client call fails.
        """
        try:
            self.client.delete_all(user_id=user_id or self.user_id)
        except Exception as e:
            raise RuntimeError(f"Failed to delete memories for user {user_id}: {e}") from e