diff --git a/examples/mem0_tool_agent.py b/examples/mem0_tool_agent.py index 8369a05..cf049dd 100644 --- a/examples/mem0_tool_agent.py +++ b/examples/mem0_tool_agent.py @@ -1,21 +1,29 @@ """ Mem0 toolkit demo using SpoonReactAI. -This demo requires spoon-toolkit to be installed +This demo requires spoon-toolkit to be installed (pip install spoon-toolkit or local install). """ import asyncio -from typing import Any, Dict, List +from typing import Dict, List, Optional from pydantic import Field from spoon_ai.agents.spoon_react import SpoonReactAI from spoon_ai.chat import ChatBot +from spoon_ai.llm.manager import get_llm_manager +from spoon_ai.memory.utils import extract_memories, extract_first_memory_id from spoon_ai.tools.tool_manager import ToolManager from spoon_ai.tools.base import ToolResult -from spoon_toolkits.memory import AddMemoryTool, SearchMemoryTool, GetAllMemoryTool +from spoon_toolkits.memory import ( + AddMemoryTool, + SearchMemoryTool, + GetAllMemoryTool, + UpdateMemoryTool, + DeleteMemoryTool, +) -USER_ID = "defi_user_002" +USER_ID = "defi_user_005" class DeFiMemoryAgent(SpoonReactAI): @@ -26,14 +34,14 @@ class DeFiMemoryAgent(SpoonReactAI): def model_post_init(self, __context: Any = None) -> None: super().model_post_init(__context) - # Rebuild tools with the injected mem0_config for this agent memory_tools = [ AddMemoryTool(mem0_config=self.mem0_config), SearchMemoryTool(mem0_config=self.mem0_config), GetAllMemoryTool(mem0_config=self.mem0_config), + UpdateMemoryTool(mem0_config=self.mem0_config), + DeleteMemoryTool(mem0_config=self.mem0_config), ] self.available_tools = ToolManager(memory_tools) - # Refresh prompts so SpoonReactAI lists the newly provided tools if hasattr(self, "_refresh_prompts"): self._refresh_prompts() @@ -55,11 +63,10 @@ def build_agent(mem0_cfg: Dict[str, Any]) -> DeFiMemoryAgent: def print_memories(result: ToolResult, label: str) -> None: - if not isinstance(result, ToolResult): - print(f"[Mem0] {label}: error -> {result}") - return - memories: List[str] = result.output.get("memories", []) if result and result.output else [] + memories = extract_memories(result) print(f"[Mem0] {label}:") + if not memories: + print(" (none)") for m in memories: print(f" - {m}") @@ -77,12 +84,26 @@ async def phase_capture(agent: DeFiMemoryAgent) -> None: "and dislike Ethereum gas fees." ), } - ] + ], + "user_id": USER_ID, + "async_mode": False, }, ) + # Verify storage immediately after add to avoid read-after-write surprises + verified: ToolResult = ToolResult() + for attempt in range(3): + verified = await agent.available_tools.execute( + name="get_all_memory", + tool_input={"user_id": USER_ID, "limit": 5}, + ) + if extract_memories(verified): + break + await asyncio.sleep(0.5) + print_memories(verified, "Verification after Phase 1 store") + memories = await agent.available_tools.execute( name="search_memory", - tool_input={"query": "Solana meme coins high risk"}, + tool_input={"query": "Solana meme coins high risk", "user_id": USER_ID}, ) print_memories(memories, "After Phase 1 store") @@ -92,12 +113,12 @@ async def phase_recall(mem0_cfg: Dict[str, Any]) -> None: agent = build_agent(mem0_cfg) memories = await agent.available_tools.execute( name="search_memory", - tool_input={"query": "trading strategy solana meme"}, + tool_input={"query": "trading strategy solana meme", "user_id": USER_ID}, ) print_memories(memories, "Retrieved for Phase 2") -async def phase_update(agent: DeFiMemoryAgent) -> None: +async def phase_update(agent: DeFiMemoryAgent, memory_id: Optional[str]) -> None: print("\n=== Phase 3: Update preferences to safer Arbitrum yield ===") await agent.available_tools.execute( name="add_memory", @@ -109,27 +130,63 @@ async def phase_update(agent: DeFiMemoryAgent) -> None: "I lost too much money. I want to pivot to safe stablecoin yield farming on Arbitrum now." ), } - ] + ], + "user_id": USER_ID, + "async_mode": False, }, ) + update_result = await agent.available_tools.execute( + name="update_memory", + tool_input={ + "memory_id": memory_id, + "text": "User pivoted to safer Arbitrum stablecoin yield farming with low risk.", + "user_id": USER_ID, + }, + ) + print(f"[Mem0] Update result: {update_result}") memories = await agent.available_tools.execute( name="search_memory", - tool_input={"query": "stablecoin yield chain choice"}, + tool_input={"query": "stablecoin yield chain choice", "user_id": USER_ID}, ) print_memories(memories, "Retrieved after update (Phase 3)") +async def phase_cleanup(agent: DeFiMemoryAgent, memory_id: Optional[str]) -> None: + print("\n=== Phase 4: Clean up a memory entry ===") + delete_result = await agent.available_tools.execute( + name="delete_memory", + tool_input={"memory_id": memory_id, "user_id": USER_ID}, + ) + print(f"[Mem0] Delete result: {delete_result}") + remaining = await agent.available_tools.execute( + name="get_all_memory", + tool_input={"limit": 5, "user_id": USER_ID}, + ) + print_memories(remaining, "Remaining memories after delete") + + async def main() -> None: mem0_cfg = { "user_id": USER_ID, "metadata": {"project": "defi-investment-advisor"}, "async_mode": False, # synchronous writes so the next search sees new data } - agent = build_agent(mem0_cfg) - await phase_capture(agent) - await phase_recall(mem0_cfg) - await phase_update(agent) - + + try: + agent = build_agent(mem0_cfg) + await phase_capture(agent) + await phase_recall(mem0_cfg) + + all_memories = await agent.available_tools.execute( + name="get_all_memory", tool_input={"limit": 5, "user_id": USER_ID} + ) + print_memories(all_memories, "All memories before update/delete") + first_id = extract_first_memory_id(all_memories) + + await phase_update(agent, first_id) + await phase_cleanup(agent, first_id) + finally: + await get_llm_manager().cleanup() if __name__ == "__main__": diff --git a/spoon_ai/memory/utils.py b/spoon_ai/memory/utils.py new file mode 100644 index 0000000..c37470f --- /dev/null +++ b/spoon_ai/memory/utils.py @@ -0,0 +1,76 @@ +""" +Memory helpers shared across Mem0 demos and utilities. +""" + +from typing import Any, List, Optional + +__all__ = ["extract_memories", "extract_first_memory_id"] + + +def _unwrap_output(result: Any) -> Any: + """Extract the underlying payload from a ToolResult-like object or raw response.""" + if hasattr(result, "output"): + return getattr(result, "output") + return result + + +def extract_memories(result: Any) -> List[str]: + """ + Normalize Mem0 search/get responses into a list of memory strings. + Supports common shapes: {"memories": [...]}, {"results": [...]}, {"data": [...]}, list, or scalar. + """ + data = _unwrap_output(result) + if not data: + return [] + + if isinstance(data, dict): + if isinstance(data.get("memories"), list): + items = data.get("memories", []) + elif isinstance(data.get("results"), list): + items = data.get("results", []) + elif isinstance(data.get("data"), list): + items = data.get("data", []) + else: + items = [data] + elif isinstance(data, list): + items = data + else: + items = [data] + + extracted: List[str] = [] + for item in items: + if isinstance(item, str): + extracted.append(item) + elif isinstance(item, dict): + text = item.get("memory") or item.get("text") or item.get("content") or item.get("value") + if text: + extracted.append(str(text)) + return extracted + + +def extract_first_memory_id(result: Any) -> Optional[str]: + """ + Pull the first memory id from Mem0 responses. + Supports common id fields: id, _id, memory_id, uuid. + """ + data = _unwrap_output(result) + if not data: + return None + + candidates = [] + if isinstance(data, dict): + if isinstance(data.get("results"), list): + candidates = data["results"] + elif isinstance(data.get("memories"), list): + candidates = data["memories"] + elif isinstance(data.get("data"), list): + candidates = data["data"] + elif isinstance(data, list): + candidates = data + + for item in candidates: + if isinstance(item, dict): + mem_id = item.get("id") or item.get("_id") or item.get("memory_id") or item.get("uuid") + if mem_id: + return str(mem_id) + return None