Skip to content

Commit b4d5a88

Browse files
authored
feat: recursively cluster nodes to max_cluster_size (#105)
* fix n4j cypher query
* feat: add llm extra body
* feat: update memory extraction prompt and result parser
* fix: evaluation locomo search
* ci: fix format and update test
* feat: update result json parser
* feat: recursively cluster nodes to max_cluster_size
* fix: fix template
* feat: keep default min-group-size 3
1 parent a0990be commit b4d5a88

File tree

2 files changed

+46
-33
lines changed

2 files changed

+46
-33
lines changed

src/memos/memories/textual/tree_text_memory/organize/reorganizer.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from concurrent.futures import ThreadPoolExecutor, as_completed
77
from queue import PriorityQueue
88
from typing import Literal
9-
9+
from collections import Counter, defaultdict
1010
import numpy as np
1111
import schedule
1212

@@ -378,9 +378,7 @@ def _local_subcluster(self, cluster_nodes: list[GraphDBNode]) -> list[list[Graph
378378

379379
return result_subclusters
380380

381-
def _partition(
382-
self, nodes: list[GraphDBNode], min_cluster_size: int = 3
383-
) -> list[list[GraphDBNode]]:
381+
def _partition(self, nodes, min_cluster_size: int = 3, max_cluster_size: int = 20):
384382
"""
385383
Partition nodes by:
386384
1) Frequent tags (top N & above threshold)
@@ -394,8 +392,6 @@ def _partition(
394392
Returns:
395393
List of clusters, each as a list of GraphDBNode
396394
"""
397-
from collections import Counter, defaultdict
398-
399395
# 1) Count all tags
400396
tag_counter = Counter()
401397
for node in nodes:
@@ -407,7 +403,7 @@ def _partition(
407403
threshold_tags = {tag for tag, count in tag_counter.items() if count >= 50}
408404
frequent_tags = top_n_tags | threshold_tags
409405

410-
# Group nodes by tags, ensure each group is unique internally
406+
# Group nodes by tags
411407
tag_groups = defaultdict(list)
412408

413409
for node in nodes:
@@ -420,48 +416,67 @@ def _partition(
420416
assigned_ids = set()
421417
for tag, group in tag_groups.items():
422418
if len(group) >= min_cluster_size:
423-
filtered_tag_clusters.append(group)
424-
assigned_ids.update(n.id for n in group)
419+
# Split large groups into chunks of at most max_cluster_size
420+
for i in range(0, len(group), max_cluster_size):
421+
sub_group = group[i : i + max_cluster_size]
422+
filtered_tag_clusters.append(sub_group)
423+
assigned_ids.update(n.id for n in sub_group)
425424
else:
426-
logger.info(f"... dropped {tag} ...")
425+
logger.info(f"... dropped tag {tag} due to low size ...")
427426

428427
logger.info(
429428
f"[MixedPartition] Created {len(filtered_tag_clusters)} clusters from tags. "
430429
f"Nodes grouped by tags: {len(assigned_ids)} / {len(nodes)}"
431430
)
432431

433-
# 5) Remaining nodes -> embedding clustering
432+
# Remaining nodes -> embedding clustering
434433
remaining_nodes = [n for n in nodes if n.id not in assigned_ids]
435434
logger.info(
436435
f"[MixedPartition] Remaining nodes for embedding clustering: {len(remaining_nodes)}"
437436
)
438437

439438
embedding_clusters = []
440-
if remaining_nodes:
441-
x = np.array([n.metadata.embedding for n in remaining_nodes if n.metadata.embedding])
442-
k = max(1, min(len(remaining_nodes) // min_cluster_size, 20))
443-
if len(x) < k:
444-
k = len(x)
445439

446-
if 1 < k <= len(x):
440+
def recursive_clustering(nodes_list):
441+
"""Recursively split clusters until each is <= max_cluster_size."""
442+
if len(nodes_list) <= max_cluster_size:
443+
return [nodes_list]
444+
445+
# Try kmeans with k = ceil(len(nodes) / max_cluster_size)
446+
x = np.array([n.metadata.embedding for n in nodes_list if n.metadata.embedding])
447+
if len(x) < 2:
448+
return [nodes_list]
449+
450+
k = min(len(x), (len(nodes_list) + max_cluster_size - 1) // max_cluster_size)
451+
k = max(1, min(k, len(x)))
452+
453+
try:
447454
kmeans = MiniBatchKMeans(n_clusters=k, batch_size=256, random_state=42)
448455
labels = kmeans.fit_predict(x)
449456

450457
label_groups = defaultdict(list)
451-
for node, label in zip(remaining_nodes, labels, strict=False):
458+
for node, label in zip(nodes_list, labels, strict=False):
452459
label_groups[label].append(node)
453460

454-
embedding_clusters = list(label_groups.values())
455-
logger.info(
456-
f"[MixedPartition] Created {len(embedding_clusters)} clusters from embedding."
457-
)
458-
else:
459-
embedding_clusters = [remaining_nodes]
461+
result = []
462+
for sub_group in label_groups.values():
463+
result.extend(recursive_clustering(sub_group))
464+
return result
465+
except Exception as e:
466+
logger.warning(f"Clustering failed: {e}, falling back to single cluster.")
467+
return [nodes_list]
468+
469+
if remaining_nodes:
470+
clusters = recursive_clustering(remaining_nodes)
471+
embedding_clusters.extend(clusters)
472+
logger.info(
473+
f"[MixedPartition] Created {len(embedding_clusters)} clusters from embeddings."
474+
)
460475

461-
# Merge all & handle small clusters
476+
# Merge all clusters
462477
all_clusters = filtered_tag_clusters + embedding_clusters
463478

464-
# Optional: merge tiny clusters
479+
# Handle small clusters (< min_cluster_size)
465480
final_clusters = []
466481
small_nodes = []
467482
for group in all_clusters:

src/memos/templates/tree_reorganize_prompts.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,18 @@
2323
5. Summarize all child memory items into one memory item.
2424
2525
Language rules:
26-
- The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input memory items. **如果输入是中文,请输出中文**
26+
- The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input memory items. **如果输入是中文,请输出中文**
2727
- Keep `memory_type` in English.
2828
29-
Language rules:
30-
- The `key`, `value`, `tags`, `background` fields must match the language of the input conversation.
31-
3229
Return valid JSON:
3330
{
34-
"key": <string, a unique, concise memory title>,
31+
"key": <string, a concise title of the `value` field>,
3532
"memory_type": <string, Either "LongTermMemory" or "UserMemory">,
36-
"value": <A detailed, self-contained, and unambiguous memory statement — written in English if the input memory items are in English, or in Chinese if the input is in Chinese>,
33+
"value": <A detailed, self-contained, and unambiguous memory statement, only contain detailed, unaltered information extracted and consolidated from the input `value` fields, do not include summary content — written in English if the input memory items are in English, or in Chinese if the input is in Chinese>,
3734
"tags": <A list of relevant thematic keywords (e.g., ["deadline", "team", "planning"])>,
38-
"summary": <a natural paragraph summarizing the above memories from user's perspective, 120–200 words, same language as the input>
35+
"summary": <a natural paragraph summarizing the above memories from user's perspective, only contain information from the input `summary` fields, 120–200 words, same language as the input>
3936
}
37+
4038
"""
4139

4240
LOCAL_SUBCLUSTER_PROMPT = """You are a memory organization expert.

0 commit comments

Comments (0)