Skip to content

Commit 9cbed6d

Browse files
authored
feat: return ids when add memory (#119)
* feat: return ids when add memory * feat: add example for how to use tree.add * tests: add tests for tree memory add * tests: add tests for tree memory add
1 parent b4d5a88 commit 9cbed6d

File tree

5 files changed

+65
-13
lines changed

5 files changed

+65
-13
lines changed

examples/core_memories/tree_textual_memory.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def embed_memory_item(memory: str) -> list[float]:
186186
memory = reader.get_memory(scene_data, type="chat", info={"user_id": "1234", "session_id": "2222"})
187187

188188
for m_list in memory:
189-
my_tree_textual_memory.add(m_list)
189+
added_ids = my_tree_textual_memory.add(m_list)
190+
for i, id in enumerate(added_ids):
191+
print(f"{i}'th added result is:" + my_tree_textual_memory.get(id).memory)
190192
my_tree_textual_memory.memory_manager.wait_reorganizer()
191193

192194
time.sleep(60)
@@ -217,7 +219,7 @@ def embed_memory_item(memory: str) -> list[float]:
217219
doc_memory = reader.get_memory(doc_paths, "doc", info={"user_id": "1111", "session_id": "2222"})
218220

219221
for m_list in doc_memory:
220-
my_tree_textual_memory.add(m_list)
222+
added_ids = my_tree_textual_memory.add(m_list)
221223
my_tree_textual_memory.memory_manager.wait_reorganizer()
222224

223225
results = my_tree_textual_memory.search(

src/memos/memories/textual/tree.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(self, config: TreeTextMemoryConfig):
5757
else:
5858
logger.info("No internet retriever configured")
5959

60-
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> None:
60+
def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]:
6161
"""Add memories.
6262
Args:
6363
memories: List of TextualMemoryItem objects or dictionaries to add.
@@ -67,7 +67,7 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> None:
6767
plan = plan_memory_operations(memory_items, metadata, self.graph_store)
6868
execute_plan(memory_items, metadata, plan, self.graph_store)
6969
"""
70-
self.memory_manager.add(memories)
70+
return self.memory_manager.add(memories)
7171

7272
def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
7373
self.memory_manager.replace_working_memory(memories)

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,18 @@ def __init__(
4949
)
5050
self._merged_threshold = merged_threshold
5151

52-
def add(self, memories: list[TextualMemoryItem]) -> None:
52+
def add(self, memories: list[TextualMemoryItem]) -> list[str]:
5353
"""
5454
Add new memories in parallel to different memory types (WorkingMemory, LongTermMemory, UserMemory).
5555
"""
56+
added_ids: list[str] = []
57+
5658
with ThreadPoolExecutor(max_workers=8) as executor:
57-
futures = [executor.submit(self._process_memory, memory) for memory in memories]
59+
futures = {executor.submit(self._process_memory, m): m for m in memories}
5860
for future in as_completed(futures):
5961
try:
60-
future.result()
62+
ids = future.result()
63+
added_ids.extend(ids)
6164
except Exception as e:
6265
logger.exception("Memory processing error: ", exc_info=e)
6366

@@ -72,6 +75,7 @@ def add(self, memories: list[TextualMemoryItem]) -> None:
7275
)
7376

7477
self._refresh_memory_size()
78+
return added_ids
7579

7680
def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
7781
"""
@@ -113,17 +117,23 @@ def _process_memory(self, memory: TextualMemoryItem):
113117
Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory).
114118
This method runs asynchronously to process each memory item.
115119
"""
120+
ids = []
121+
116122
# Add to WorkingMemory
117-
self._add_memory_to_db(memory, "WorkingMemory")
123+
working_id = self._add_memory_to_db(memory, "WorkingMemory")
124+
ids.append(working_id)
118125

119126
# Add to LongTermMemory and UserMemory
120127
if memory.metadata.memory_type in ["LongTermMemory", "UserMemory"]:
121-
self._add_to_graph_memory(
128+
added_id = self._add_to_graph_memory(
122129
memory=memory,
123130
memory_type=memory.metadata.memory_type,
124131
)
132+
ids.append(added_id)
133+
134+
return ids
125135

126-
def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str):
136+
def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str) -> str:
127137
"""
128138
Add a single memory item to the graph store, with FIFO logic for WorkingMemory.
129139
"""
@@ -135,6 +145,7 @@ def _add_memory_to_db(self, memory: TextualMemoryItem, memory_type: str):
135145

136146
# Insert node into graph
137147
self.graph_store.add_node(working_memory.id, working_memory.memory, metadata)
148+
return working_memory.id
138149

139150
def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str):
140151
"""
@@ -159,7 +170,7 @@ def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str):
159170
)
160171

161172
if similar_nodes and similar_nodes[0]["score"] > self._merged_threshold:
162-
self._merge(memory, similar_nodes)
173+
return self._merge(memory, similar_nodes)
163174
else:
164175
node_id = str(uuid.uuid4())
165176
# Step 2: Add new node to graph
@@ -172,8 +183,9 @@ def _add_to_graph_memory(self, memory: TextualMemoryItem, memory_type: str):
172183
after_node=[node_id],
173184
)
174185
)
186+
return node_id
175187

176-
def _merge(self, source_node: TextualMemoryItem, similar_nodes: list[dict]) -> None:
188+
def _merge(self, source_node: TextualMemoryItem, similar_nodes: list[dict]) -> str:
177189
"""
178190
TODO: Add node traceability support by optionally preserving source nodes and linking them with MERGED_FROM edges.
179191
@@ -200,7 +212,9 @@ def _merge(self, source_node: TextualMemoryItem, similar_nodes: list[dict]) -> N
200212
merged_background = f"{original_meta.background}\n⟵MERGED⟶\n{source_meta.background}"
201213
merged_embedding = self.embedder.embed([merged_text])[0]
202214

203-
merged_confidence = float((original_meta.confidence + source_meta.confidence) / 2)
215+
original_conf = original_meta.confidence or 0.0
216+
source_conf = source_meta.confidence or 0.0
217+
merged_confidence = float((original_conf + source_conf) / 2)
204218
merged_usage = list(set((original_meta.usage or []) + (source_meta.usage or [])))
205219

206220
# Create new merged node
@@ -243,6 +257,7 @@ def _merge(self, source_node: TextualMemoryItem, similar_nodes: list[dict]) -> N
243257
after_node=[merged_id],
244258
)
245259
)
260+
return merged_id
246261

247262
def _inherit_edges(self, from_id: str, to_id: str) -> None:
248263
"""

tests/memories/textual/test_tree.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,27 @@ def test_drop_creates_backup_and_cleans(mock_tree_text_memory):
138138
mock_tree_text_memory.dump.assert_called_once()
139139
mock_tree_text_memory._cleanup_old_backups.assert_called_once()
140140
mock_tree_text_memory.graph_store.drop_database.assert_called_once()
141+
142+
143+
def test_add_returns_ids(mock_tree_text_memory):
144+
# Mock the memory_manager.add to return specific IDs
145+
dummy_ids = ["id1", "id2"]
146+
mock_tree_text_memory.memory_manager.add = MagicMock(return_value=dummy_ids)
147+
148+
mock_items = [
149+
TextualMemoryItem(
150+
id=str(uuid.uuid4()),
151+
memory="Memory 1",
152+
metadata=TreeNodeTextualMemoryMetadata(updated_at=None),
153+
),
154+
TextualMemoryItem(
155+
id=str(uuid.uuid4()),
156+
memory="Memory 2",
157+
metadata=TreeNodeTextualMemoryMetadata(updated_at=None),
158+
),
159+
]
160+
161+
result = mock_tree_text_memory.add(mock_items)
162+
163+
assert result == dummy_ids
164+
mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items)

tests/memories/textual/test_tree_manager.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,14 @@ def test_ensure_structure_path_reuses_existing(memory_manager, mock_graph_store)
145145
meta = TreeNodeTextualMemoryMetadata(key="hobby")
146146
node_id = memory_manager._ensure_structure_path("UserMemory", meta)
147147
assert node_id == "existing_node_id"
148+
149+
150+
def test_add_returns_written_node_ids(memory_manager):
151+
memory = TextualMemoryItem(
152+
memory="test memory",
153+
metadata=TreeNodeTextualMemoryMetadata(embedding=[0.1] * 5, memory_type="UserMemory"),
154+
)
155+
ids = memory_manager.add([memory])
156+
assert isinstance(ids, list)
157+
assert all(isinstance(i, str) for i in ids)
158+
assert len(ids) > 0

0 commit comments

Comments
 (0)