Skip to content

Commit 36d0ba0

Browse files
whipser030黑布林CaralHsifridayL
authored
feat: Feedback Function (#597)
* update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords --------- Co-authored-by: 黑布林 <[email protected]> Co-authored-by: CaralHsi <[email protected]> Co-authored-by: chunyu li <[email protected]>
1 parent 3311832 commit 36d0ba0

File tree

10 files changed

+644
-170
lines changed

10 files changed

+644
-170
lines changed

src/memos/api/handlers/add_handler.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
using dependency injection for better modularity and testability.
66
"""
77

8+
from pydantic import validate_call
9+
810
from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
911
from memos.api.product_models import APIADDRequest, APIFeedbackRequest, MemoryResponse
1012
from memos.memories.textual.item import (
@@ -13,6 +15,7 @@
1315
from memos.multi_mem_cube.composite_cube import CompositeCubeView
1416
from memos.multi_mem_cube.single_cube import SingleCubeView
1517
from memos.multi_mem_cube.views import MemCubeView
18+
from memos.types import MessageList
1619

1720

1821
class AddHandler(BaseHandler):
@@ -60,38 +63,45 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
6063

6164
cube_view = self._build_cube_view(add_req)
6265

66+
@validate_call
67+
def _check_messages(messages: MessageList) -> None:
68+
pass
69+
6370
if add_req.is_feedback:
64-
chat_history = add_req.chat_history
65-
messages = add_req.messages
66-
if chat_history is None:
67-
chat_history = []
68-
if messages is None:
69-
messages = []
70-
concatenate_chat = chat_history + messages
71-
72-
last_user_index = max(i for i, d in enumerate(concatenate_chat) if d["role"] == "user")
73-
feedback_content = concatenate_chat[last_user_index]["content"]
74-
feedback_history = concatenate_chat[:last_user_index]
75-
76-
feedback_req = APIFeedbackRequest(
77-
user_id=add_req.user_id,
78-
session_id=add_req.session_id,
79-
task_id=add_req.task_id,
80-
history=feedback_history,
81-
feedback_content=feedback_content,
82-
writable_cube_ids=add_req.writable_cube_ids,
83-
async_mode=add_req.async_mode,
84-
)
85-
process_record = cube_view.feedback_memories(feedback_req)
71+
try:
72+
messages = add_req.messages
73+
_check_messages(messages)
8674

87-
self.logger.info(
88-
f"[FeedbackHandler] Final feedback results count={len(process_record)}"
89-
)
75+
chat_history = add_req.chat_history if add_req.chat_history else []
76+
concatenate_chat = chat_history + messages
9077

91-
return MemoryResponse(
92-
message="Memory feedback successfully",
93-
data=[process_record],
94-
)
78+
last_user_index = max(
79+
i for i, d in enumerate(concatenate_chat) if d["role"] == "user"
80+
)
81+
feedback_content = concatenate_chat[last_user_index]["content"]
82+
feedback_history = concatenate_chat[:last_user_index]
83+
84+
feedback_req = APIFeedbackRequest(
85+
user_id=add_req.user_id,
86+
session_id=add_req.session_id,
87+
task_id=add_req.task_id,
88+
history=feedback_history,
89+
feedback_content=feedback_content,
90+
writable_cube_ids=add_req.writable_cube_ids,
91+
async_mode=add_req.async_mode,
92+
)
93+
process_record = cube_view.feedback_memories(feedback_req)
94+
95+
self.logger.info(
96+
f"[ADDFeedbackHandler] Final feedback results count={len(process_record)}"
97+
)
98+
99+
return MemoryResponse(
100+
message="Memory feedback successfully",
101+
data=[process_record],
102+
)
103+
except Exception as e:
104+
self.logger.warning(f"[ADDFeedbackHandler] Running error: {e}")
95105

96106
results = cube_view.add_memories(add_req)
97107

src/memos/api/product_models.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -642,15 +642,14 @@ class APIFeedbackRequest(BaseRequest):
642642
)
643643
feedback_content: str | None = Field(..., description="Feedback content to process")
644644
feedback_time: str | None = Field(None, description="Feedback time")
645-
# ==== Multi-cube writing ====
646645
writable_cube_ids: list[str] | None = Field(
647646
None, description="List of cube IDs user can write for multi-cube add"
648647
)
649648
async_mode: Literal["sync", "async"] = Field(
650649
"async", description="feedback mode: sync or async"
651650
)
652651
corrected_answer: bool = Field(False, description="Whether need return corrected answer")
653-
# ==== Backward compatibility ====
652+
# ==== mem_cube_id is NOT enabled====
654653
mem_cube_id: str | None = Field(
655654
None,
656655
description=(

src/memos/graph_dbs/polardb.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,6 +1455,98 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
14551455
"""Get the ordered context chain starting from a node."""
14561456
raise NotImplementedError
14571457

1458+
@timed
1459+
def seach_by_keywords(
1460+
self,
1461+
query_words: list[str],
1462+
scope: str | None = None,
1463+
status: str | None = None,
1464+
search_filter: dict | None = None,
1465+
user_name: str | None = None,
1466+
filter: dict | None = None,
1467+
knowledgebase_ids: list[str] | None = None,
1468+
tsvector_field: str = "properties_tsvector_zh",
1469+
tsquery_config: str = "jiebaqry",
1470+
**kwargs,
1471+
) -> list[dict]:
1472+
where_clauses = []
1473+
1474+
if scope:
1475+
where_clauses.append(
1476+
f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype"
1477+
)
1478+
if status:
1479+
where_clauses.append(
1480+
f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype"
1481+
)
1482+
else:
1483+
where_clauses.append(
1484+
"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype"
1485+
)
1486+
1487+
# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
1488+
user_name_conditions = self._build_user_name_and_kb_ids_conditions_sql(
1489+
user_name=user_name,
1490+
knowledgebase_ids=knowledgebase_ids,
1491+
default_user_name=self.config.user_name,
1492+
)
1493+
1494+
# Add OR condition if we have any user_name conditions
1495+
if user_name_conditions:
1496+
if len(user_name_conditions) == 1:
1497+
where_clauses.append(user_name_conditions[0])
1498+
else:
1499+
where_clauses.append(f"({' OR '.join(user_name_conditions)})")
1500+
1501+
# Add search_filter conditions
1502+
if search_filter:
1503+
for key, value in search_filter.items():
1504+
if isinstance(value, str):
1505+
where_clauses.append(
1506+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype"
1507+
)
1508+
else:
1509+
where_clauses.append(
1510+
f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype"
1511+
)
1512+
1513+
# Build filter conditions using common method
1514+
filter_conditions = self._build_filter_conditions_sql(filter)
1515+
where_clauses.extend(filter_conditions)
1516+
# Add fulltext search condition
1517+
# Convert query_text to OR query format: "word1 | word2 | word3"
1518+
tsquery_string = " | ".join(query_words)
1519+
1520+
where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)")
1521+
1522+
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
1523+
1524+
# Build fulltext search query
1525+
query = f"""
1526+
SELECT
1527+
ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
1528+
agtype_object_field_text(properties, 'memory') as memory_text
1529+
FROM "{self.db_name}_graph"."Memory"
1530+
{where_clause}
1531+
"""
1532+
1533+
params = (tsquery_string,)
1534+
logger.info(f"[search_by_fulltext] query: {query}, params: {params}")
1535+
conn = self._get_connection()
1536+
try:
1537+
with conn.cursor() as cursor:
1538+
cursor.execute(query, params)
1539+
results = cursor.fetchall()
1540+
output = []
1541+
for row in results:
1542+
oldid = row[0]
1543+
id_val = str(oldid)
1544+
output.append({"id": id_val})
1545+
1546+
return output
1547+
finally:
1548+
self._return_connection(conn)
1549+
14581550
@timed
14591551
def search_by_fulltext(
14601552
self,

0 commit comments

Comments
 (0)