Skip to content

Commit 8f87b33

Browse files
authored
feat: chat bot api (#294)
* fix: add safeguard when parsing node memory
* feat: add filter as a parameter in tree-text searcher
* feat: add filter for user and long-term memory
* feat: add filter in working memory
* add filter in task-parser
* feat: only mix-retrieve for vector-recall; TODO: mix reranker
* feat: add 'session_id' as an optional parameter for product api
* feat: api 1.0 finish
* maintain: update gitignore
* maintain: update gitignore
* feat: add 'type' in TextualMemory Sources
* feat: add annotation to item
* fix: add session_id to product add
* fix: test
1 parent 7b245b9 commit 8f87b33

File tree

17 files changed

+336
-89
lines changed

17 files changed

+336
-89
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ tmp/
88
# evaluation data
99
*.csv
1010
*.jsonl
11+
**settings.json**
1112
evaluation/*tmp/
1213
evaluation/results
1314
evaluation/.env
@@ -19,7 +20,7 @@ evaluation/scripts/personamem
1920

2021
# benchmarks
2122
benchmarks/
22-
23+
2324
# Byte-compiled / optimized / DLL files
2425
__pycache__/
2526
*.py[cod]

src/memos/api/product_models.py

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,14 @@
11
import uuid
22

3-
from typing import Generic, Literal, TypeAlias, TypeVar
3+
from typing import Generic, Literal, TypeVar
44

55
from pydantic import BaseModel, Field
6-
from typing_extensions import TypedDict
76

7+
# Import message types from core types module
8+
from memos.types import MessageDict
89

9-
T = TypeVar("T")
10-
11-
12-
# ─── Message Types ──────────────────────────────────────────────────────────────
13-
14-
# Chat message roles
15-
MessageRole: TypeAlias = Literal["user", "assistant", "system"]
1610

17-
18-
# Message structure
19-
class MessageDict(TypedDict):
20-
"""Typed dictionary for chat message dictionaries."""
21-
22-
role: MessageRole
23-
content: str
11+
T = TypeVar("T")
2412

2513

2614
class BaseRequest(BaseModel):
@@ -86,6 +74,7 @@ class ChatRequest(BaseRequest):
8674
history: list[MessageDict] | None = Field(None, description="Chat history")
8775
internet_search: bool = Field(True, description="Whether to use internet search")
8876
moscube: bool = Field(False, description="Whether to use MemOSCube")
77+
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
8978

9079

9180
class ChatCompleteRequest(BaseRequest):
@@ -100,6 +89,7 @@ class ChatCompleteRequest(BaseRequest):
10089
base_prompt: str | None = Field(None, description="Base prompt to use for chat")
10190
top_k: int = Field(10, description="Number of results to return")
10291
threshold: float = Field(0.5, description="Threshold for filtering references")
92+
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
10393

10494

10595
class UserCreate(BaseRequest):
@@ -161,6 +151,7 @@ class MemoryCreateRequest(BaseRequest):
161151
mem_cube_id: str | None = Field(None, description="Cube ID")
162152
source: str | None = Field(None, description="Source of the memory")
163153
user_profile: bool = Field(False, description="User profile memory")
154+
session_id: str | None = Field(None, description="Session id")
164155

165156

166157
class SearchRequest(BaseRequest):
@@ -170,6 +161,7 @@ class SearchRequest(BaseRequest):
170161
query: str = Field(..., description="Search query")
171162
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
172163
top_k: int = Field(10, description="Number of results to return")
164+
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
173165

174166

175167
class SuggestionRequest(BaseRequest):

src/memos/api/routers/product_router.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def create_memory(memory_req: MemoryCreateRequest):
204204
mem_cube_id=memory_req.mem_cube_id,
205205
source=memory_req.source,
206206
user_profile=memory_req.user_profile,
207+
session_id=memory_req.session_id,
207208
)
208209
return SimpleResponse(message="Memory created successfully")
209210

@@ -224,6 +225,7 @@ def search_memories(search_req: SearchRequest):
224225
user_id=search_req.user_id,
225226
install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None,
226227
top_k=search_req.top_k,
228+
session_id=search_req.session_id,
227229
)
228230
return SearchResponse(message="Search completed successfully", data=result)
229231

@@ -251,6 +253,7 @@ def generate_chat_response():
251253
history=chat_req.history,
252254
internet_search=chat_req.internet_search,
253255
moscube=chat_req.moscube,
256+
session_id=chat_req.session_id,
254257
)
255258

256259
except Exception as e:
@@ -295,6 +298,7 @@ def chat_complete(chat_req: ChatCompleteRequest):
295298
base_prompt=chat_req.base_prompt,
296299
top_k=chat_req.top_k,
297300
threshold=chat_req.threshold,
301+
session_id=chat_req.session_id,
298302
)
299303

300304
# Return the complete response

src/memos/graph_dbs/nebular.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,7 @@ def search_by_embedding(
977977
scope: str | None = None,
978978
status: str | None = None,
979979
threshold: float | None = None,
980+
search_filter: dict | None = None,
980981
**kwargs,
981982
) -> list[dict]:
982983
"""
@@ -989,6 +990,8 @@ def search_by_embedding(
989990
status (str, optional): Node status filter (e.g., 'active', 'archived').
990991
If provided, restricts results to nodes with matching status.
991992
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
993+
search_filter (dict, optional): Additional metadata filters for search results.
994+
Keys should match node properties, values are the expected values.
992995
993996
Returns:
994997
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -998,6 +1001,7 @@ def search_by_embedding(
9981001
- If scope is provided, it restricts results to nodes with matching memory_type.
9991002
- If 'status' is provided, only nodes with the matching status will be returned.
10001003
- If threshold is provided, only results with score >= threshold will be returned.
1004+
- If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
10011005
- Typical use case: restrict to 'status = activated' to avoid
10021006
matching archived or merged nodes.
10031007
"""
@@ -1017,6 +1021,14 @@ def search_by_embedding(
10171021
else:
10181022
where_clauses.append(f'n.user_name = "{self.config.user_name}"')
10191023

1024+
# Add search_filter conditions
1025+
if search_filter:
1026+
for key, value in search_filter.items():
1027+
if isinstance(value, str):
1028+
where_clauses.append(f'n.{key} = "{value}"')
1029+
else:
1030+
where_clauses.append(f"n.{key} = {value}")
1031+
10201032
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
10211033

10221034
gql = f"""

src/memos/graph_dbs/neo4j.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def search_by_embedding(
606606
scope: str | None = None,
607607
status: str | None = None,
608608
threshold: float | None = None,
609+
search_filter: dict | None = None,
609610
**kwargs,
610611
) -> list[dict]:
611612
"""
@@ -618,6 +619,8 @@ def search_by_embedding(
618619
status (str, optional): Node status filter (e.g., 'active', 'archived').
619620
If provided, restricts results to nodes with matching status.
620621
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
622+
search_filter (dict, optional): Additional metadata filters for search results.
623+
Keys should match node properties, values are the expected values.
621624
622625
Returns:
623626
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -627,6 +630,7 @@ def search_by_embedding(
627630
- If scope is provided, it restricts results to nodes with matching memory_type.
628631
- If 'status' is provided, only nodes with the matching status will be returned.
629632
- If threshold is provided, only results with score >= threshold will be returned.
633+
- If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
630634
- Typical use case: restrict to 'status = activated' to avoid
631635
matching archived or merged nodes.
632636
"""
@@ -639,6 +643,12 @@ def search_by_embedding(
639643
if not self.config.use_multi_db and self.config.user_name:
640644
where_clauses.append("node.user_name = $user_name")
641645

646+
# Add search_filter conditions
647+
if search_filter:
648+
for key, _ in search_filter.items():
649+
param_name = f"filter_{key}"
650+
where_clauses.append(f"node.{key} = ${param_name}")
651+
642652
where_clause = ""
643653
if where_clauses:
644654
where_clause = "WHERE " + " AND ".join(where_clauses)
@@ -650,7 +660,8 @@ def search_by_embedding(
650660
RETURN node.id AS id, score
651661
"""
652662

653-
parameters = {"embedding": vector, "k": top_k, "scope": scope}
663+
parameters = {"embedding": vector, "k": top_k}
664+
654665
if scope:
655666
parameters["scope"] = scope
656667
if status:
@@ -661,6 +672,12 @@ def search_by_embedding(
661672
else:
662673
parameters["user_name"] = self.config.user_name
663674

675+
# Add search_filter parameters
676+
if search_filter:
677+
for key, value in search_filter.items():
678+
param_name = f"filter_{key}"
679+
parameters[param_name] = value
680+
664681
with self.driver.session(database=self.db_name) as session:
665682
result = session.run(query, parameters)
666683
records = [{"id": record["id"], "score": record["score"]} for record in result]

src/memos/graph_dbs/neo4j_community.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def search_by_embedding(
129129
scope: str | None = None,
130130
status: str | None = None,
131131
threshold: float | None = None,
132+
search_filter: dict | None = None,
132133
**kwargs,
133134
) -> list[dict]:
134135
"""
@@ -140,6 +141,7 @@ def search_by_embedding(
140141
scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
141142
status (str, optional): Node status filter (e.g., 'activated', 'archived').
142143
threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
144+
search_filter (dict, optional): Additional metadata filters to apply.
143145
144146
Returns:
145147
list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
@@ -149,6 +151,7 @@ def search_by_embedding(
149151
- If 'scope' is provided, it restricts results to nodes with matching memory_type.
150152
- If 'status' is provided, it further filters nodes by status.
151153
- If 'threshold' is provided, only results with score >= threshold will be returned.
154+
- If 'search_filter' is provided, it applies additional metadata-based filtering.
152155
- The returned IDs can be used to fetch full node data from Neo4j if needed.
153156
"""
154157
# Build VecDB filter
@@ -163,6 +166,10 @@ def search_by_embedding(
163166
else:
164167
vec_filter["user_name"] = self.config.user_name
165168

169+
# Add search_filter conditions
170+
if search_filter:
171+
vec_filter.update(search_filter)
172+
166173
# Perform vector search
167174
results = self.vec_db.search(query_vector=vector, top_k=top_k, filter=vec_filter)
168175

src/memos/mem_os/core.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def search(
547547
mode: Literal["fast", "fine"] = "fast",
548548
internet_search: bool = False,
549549
moscube: bool = False,
550+
session_id: str | None = None,
550551
**kwargs,
551552
) -> MOSSearchResult:
552553
"""
@@ -563,6 +564,7 @@ def search(
563564
MOSSearchResult: A dictionary containing the search results.
564565
"""
565566
target_user_id = user_id if user_id is not None else self.user_id
567+
566568
self._validate_user_exists(target_user_id)
567569
# Get all cubes accessible by the target user
568570
accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
@@ -575,6 +577,11 @@ def search(
575577
self._register_chat_history(target_user_id)
576578
chat_history = self.chat_history_manager[target_user_id]
577579

580+
# Create search filter if session_id is provided
581+
search_filter = None
582+
if session_id is not None:
583+
search_filter = {"session_id": session_id}
584+
578585
result: MOSSearchResult = {
579586
"text_mem": [],
580587
"act_mem": [],
@@ -602,10 +609,11 @@ def search(
602609
manual_close_internet=not internet_search,
603610
info={
604611
"user_id": target_user_id,
605-
"session_id": self.session_id,
612+
"session_id": session_id if session_id is not None else self.session_id,
606613
"chat_history": chat_history.chat_history,
607614
},
608615
moscube=moscube,
616+
search_filter=search_filter,
609617
)
610618
result["text_mem"].append({"cube_id": mem_cube_id, "memories": memories})
611619
logger.info(
@@ -624,6 +632,8 @@ def add(
624632
doc_path: str | None = None,
625633
mem_cube_id: str | None = None,
626634
user_id: str | None = None,
635+
session_id: str | None = None,
636+
**kwargs,
627637
) -> None:
628638
"""
629639
Add textual memories to a MemCube.
@@ -636,11 +646,13 @@ def add(
636646
If None, the default MemCube for the user is used.
637647
user_id (str, optional): The identifier of the user to add the memories to.
638648
If None, the default user is used.
649+
session_id (str, optional): Session identifier; assigned to `self.session_id` (None when omitted).
639650
"""
640651
# user input messages
641652
assert (messages is not None) or (memory_content is not None) or (doc_path is not None), (
642653
"messages_or_doc_path or memory_content or doc_path must be provided."
643654
)
655+
self.session_id = session_id
644656
target_user_id = user_id if user_id is not None else self.user_id
645657
if mem_cube_id is None:
646658
# Try to find a default cube for the user

src/memos/mem_os/product.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,7 @@ def chat(
928928
moscube: bool = False,
929929
top_k: int = 10,
930930
threshold: float = 0.5,
931+
session_id: str | None = None,
931932
) -> str:
932933
"""
933934
Chat with LLM with memory references and complete response.
@@ -942,6 +943,7 @@ def chat(
942943
mode="fine",
943944
internet_search=internet_search,
944945
moscube=moscube,
946+
session_id=session_id,
945947
)["text_mem"]
946948

947949
memories_list = []
@@ -986,6 +988,7 @@ def chat_with_references(
986988
top_k: int = 20,
987989
internet_search: bool = False,
988990
moscube: bool = False,
991+
session_id: str | None = None,
989992
) -> Generator[str, None, None]:
990993
"""
991994
Chat with LLM with memory references and streaming output.
@@ -1012,6 +1015,7 @@ def chat_with_references(
10121015
mode="fine",
10131016
internet_search=internet_search,
10141017
moscube=moscube,
1018+
session_id=session_id,
10151019
)["text_mem"]
10161020

10171021
yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
@@ -1300,6 +1304,7 @@ def search(
13001304
install_cube_ids: list[str] | None = None,
13011305
top_k: int = 10,
13021306
mode: Literal["fast", "fine"] = "fast",
1307+
session_id: str | None = None,
13031308
):
13041309
"""Search memories for a specific user."""
13051310

@@ -1310,7 +1315,9 @@ def search(
13101315
logger.info(
13111316
f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}"
13121317
)
1313-
search_result = super().search(query, user_id, install_cube_ids, top_k, mode=mode)
1318+
search_result = super().search(
1319+
query, user_id, install_cube_ids, top_k, mode=mode, session_id=session_id
1320+
)
13141321
search_time_end = time.time()
13151322
logger.info(
13161323
f"time search: search text_mem time user_id: {user_id} time is: {search_time_end - load_user_cubes_time_end}"
@@ -1346,13 +1353,16 @@ def add(
13461353
mem_cube_id: str | None = None,
13471354
source: str | None = None,
13481355
user_profile: bool = False,
1356+
session_id: str | None = None,
13491357
):
13501358
"""Add memory for a specific user."""
13511359

13521360
# Load user cubes if not already loaded
13531361
self._load_user_cubes(user_id, self.default_cube_config)
13541362

1355-
result = super().add(messages, memory_content, doc_path, mem_cube_id, user_id)
1363+
result = super().add(
1364+
messages, memory_content, doc_path, mem_cube_id, user_id, session_id=session_id
1365+
)
13561366
if user_profile:
13571367
try:
13581368
user_interests = memory_content.split("'userInterests': '")[1].split("', '")[0]

0 commit comments

Comments
 (0)