Skip to content

Commit 9d426bb

Browse files
whipser030, 黑布林, CaralHsi, fridayL
authored
fix: embedding fail need a safety way (#704)
* update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue * develop feedback process * feedback handler configuration * upgrade feedback using * add threshold * update prompt * update prompt * fix handler * add feedback scheduler * add handler change node update * add handler change node update * add handler change node update * add handler change node update * fix interface input * add chunk and ratio filter * update stopwords * fix messages queue * add seach_by_keywords_LIKE * add doc filter * add retrieve query * add retrieve queies * patch info filter * add log and make embedding safety net * add log and make embedding safety net --------- Co-authored-by: 黑布林 <[email protected]> Co-authored-by: CaralHsi <[email protected]> Co-authored-by: chunyu li <[email protected]>
1 parent 87160f3 commit 9d426bb

File tree

3 files changed

+86
-51
lines changed

3 files changed

+86
-51
lines changed

src/memos/mem_feedback/feedback.py

Lines changed: 80 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from datetime import datetime
66
from typing import TYPE_CHECKING, Any
77

8-
from tenacity import retry, stop_after_attempt, wait_exponential
8+
from tenacity import retry, stop_after_attempt, wait_random_exponential
99

10-
from memos import log
1110
from memos.configs.memory import MemFeedbackConfig
1211
from memos.context.context import ContextThreadPoolExecutor
1312
from memos.dependency import require_python_package
1413
from memos.embedders.factory import EmbedderFactory, OllamaEmbedder
1514
from memos.graph_dbs.factory import GraphStoreFactory, PolarDBGraphDB
1615
from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM
16+
from memos.log import get_logger
1717
from memos.mem_feedback.base import BaseMemFeedback
1818
from memos.mem_feedback.utils import make_mem_item, should_keep_update, split_into_chunks
1919
from memos.mem_reader.factory import MemReaderFactory
@@ -48,7 +48,7 @@
4848
"generation": {"en": FEEDBACK_ANSWER_PROMPT, "zh": FEEDBACK_ANSWER_PROMPT_ZH},
4949
}
5050

51-
logger = log.get_logger(__name__)
51+
logger = get_logger(__name__)
5252

5353

5454
class MemFeedback(BaseMemFeedback):
@@ -83,19 +83,47 @@ def __init__(self, config: MemFeedbackConfig):
8383
self.reranker = None
8484
self.DB_IDX_READY = False
8585

86+
@require_python_package(
87+
import_name="jieba",
88+
install_command="pip install jieba",
89+
install_link="https://github.com/fxsjy/jieba",
90+
)
91+
def _tokenize_chinese(self, text):
92+
"""split zh jieba"""
93+
import jieba
94+
95+
tokens = jieba.lcut(text)
96+
tokens = [token.strip() for token in tokens if token.strip()]
97+
return self.stopword_manager.filter_words(tokens)
98+
99+
@retry(stop=stop_after_attempt(4), wait=wait_random_exponential(multiplier=1, max=10))
100+
def _embed_once(self, texts):
101+
return self.embedder.embed(texts)
102+
103+
@retry(stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, min=4, max=10))
104+
def _retry_db_operation(self, operation):
105+
try:
106+
return operation()
107+
except Exception as e:
108+
logger.error(
109+
f"[Feedback Core: _retry_db_operation] DB operation failed: {e}", exc_info=True
110+
)
111+
raise
112+
86113
def _batch_embed(self, texts: list[str], embed_bs: int = 5):
87-
embed_bs = 5
88-
texts_embeddings = []
114+
results = []
115+
dim = self.embedder.config.embedding_dims
116+
89117
for i in range(0, len(texts), embed_bs):
90118
batch = texts[i : i + embed_bs]
91119
try:
92-
texts_embeddings.extend(self.embedder.embed(batch))
120+
results.extend(self._embed_once(batch))
93121
except Exception as e:
94122
logger.error(
95-
f"[Feedback Core: process_feedback_core] Embedding batch failed: {e}",
96-
exc_info=True,
123+
f"[Feedback Core: process_feedback_core] Embedding batch failed, Cover with all zeros: {len(batch)} entries: {e}"
97124
)
98-
return texts_embeddings
125+
results.extend([[0.0] * dim for _ in range(len(batch))])
126+
return results
99127

100128
def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, info: dict):
101129
"""
@@ -108,7 +136,7 @@ def _pure_add(self, user_name: str, feedback_content: str, feedback_time: str, i
108136
lambda: self.memory_manager.add(to_add_memories, user_name=user_name)
109137
)
110138
logger.info(
111-
f"[Feedback Core: _pure_add] Added {len(added_ids)} memories for user {user_name}."
139+
f"[Feedback Core: _pure_add] Pure added {len(added_ids)} memories for user {user_name}."
112140
)
113141
return {
114142
"record": {
@@ -199,7 +227,7 @@ def _single_add_operation(
199227
lambda: self.memory_manager.add([to_add_memory], user_name=user_name, mode=async_mode)
200228
)
201229

202-
logger.info(f"[Memory Feedback ADD] {added_ids[0]}")
230+
logger.info(f"[Memory Feedback ADD] memory id: {added_ids[0]}")
203231
return {"id": added_ids[0], "text": to_add_memory.memory}
204232

205233
def _single_update_operation(
@@ -305,17 +333,22 @@ def semantics_feedback(
305333

306334
if not current_memories:
307335
operations = [{"operation": "ADD"}]
336+
logger.warning(
337+
"[Feedback Core]: There was no recall of the relevant memory, so it was added directly."
338+
)
308339
else:
309340
memory_chunks = split_into_chunks(current_memories, max_tokens_per_chunk=500)
310341

311342
all_operations = []
343+
now_time = datetime.now().isoformat()
312344
with ContextThreadPoolExecutor(max_workers=10) as executor:
313345
future_to_chunk_idx = {}
314346
for chunk in memory_chunks:
315347
current_memories_str = "\n".join(
316348
[f"{item.id}: {item.memory}" for item in chunk]
317349
)
318350
prompt = template.format(
351+
now_time=now_time,
319352
current_memories=current_memories_str,
320353
new_facts=memory_item.memory,
321354
chat_history=history_str,
@@ -337,7 +370,7 @@ def semantics_feedback(
337370

338371
operations = self.standard_operations(all_operations, current_memories)
339372

340-
logger.info(f"[Feedback memory operations]: {operations!s}")
373+
logger.info(f"[Feedback Core Operations]: {operations!s}")
341374

342375
if not operations:
343376
return {"record": {"add": [], "update": []}}
@@ -453,6 +486,7 @@ def _feedback_memory(
453486
}
454487

455488
def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys: list) -> bool:
489+
"""Filter the relevant memory items based on info"""
456490
if not _info and not memory.metadata.info:
457491
return True
458492

@@ -463,10 +497,10 @@ def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys:
463497
record.append(info_v == mem_v)
464498
return all(record)
465499

466-
def _retrieve(self, query: str, info=None, user_name=None):
500+
def _retrieve(self, query: str, info=None, top_k=100, user_name=None):
467501
"""Retrieve memory items"""
468502
retrieved_mems = self.searcher.search(
469-
query, info=info, user_name=user_name, topk=50, full_recall=True
503+
query, info=info, user_name=user_name, top_k=top_k, full_recall=True
470504
)
471505
retrieved_mems = [item[0] for item in retrieved_mems]
472506
return retrieved_mems
@@ -524,11 +558,19 @@ def _get_llm_response(self, prompt: str, dsl: bool = True) -> dict:
524558
else:
525559
return response_text
526560
except Exception as e:
527-
logger.error(f"[Feedback Core LLM] Exception during chat generation: {e}")
561+
logger.error(
562+
f"[Feedback Core LLM Error] Exception during chat generation: {e} | response_text: {response_text}"
563+
)
528564
response_json = None
529565
return response_json
530566

531567
def standard_operations(self, operations, current_memories):
568+
"""
569+
Regularize the operation design
570+
1. Map the id to the correct original memory id
571+
2. If there is an update, skip the memory object of add
572+
3. If the modified text is too long, skip the update
573+
"""
532574
right_ids = [item.id for item in current_memories]
533575
right_lower_map = {x.lower(): x for x in right_ids}
534576

@@ -582,9 +624,16 @@ def correct_item(data):
582624
has_update = any(item.get("operation").lower() == "update" for item in llm_operations)
583625
if has_update:
584626
filtered_items = [
627+
item for item in llm_operations if item.get("operation").lower() == "add"
628+
]
629+
update_items = [
585630
item for item in llm_operations if item.get("operation").lower() != "add"
586631
]
587-
return filtered_items
632+
if filtered_items:
633+
logger.info(
634+
f"[Feedback Core: semantics_feedback] Due to have update objects, skip add: {filtered_items}"
635+
)
636+
return update_items
588637
else:
589638
return llm_operations
590639

@@ -683,6 +732,10 @@ def process_keyword_replace(
683732
if doc_scope != "NONE":
684733
retrieved_memories = self._doc_filter(doc_scope, retrieved_memories)
685734

735+
logger.info(
736+
f"[Feedback Core: process_keyword_replace] Keywords recalled memory for user {user_name}: {len(retrieved_ids)} memories | After filtering: {len(retrieved_memories)} memories."
737+
)
738+
686739
if not retrieved_memories:
687740
return {"record": {"add": [], "update": []}}
688741

@@ -693,14 +746,14 @@ def process_keyword_replace(
693746
if original_word in old_mem.memory:
694747
mem = old_mem.model_copy(deep=True)
695748
mem.memory = mem.memory.replace(original_word, target_word)
749+
if original_word in mem.metadata.tags:
750+
mem.metadata.tags.remove(original_word)
696751
if target_word not in mem.metadata.tags:
697752
mem.metadata.tags.append(target_word)
698753
pick_index.append(i)
699754
update_memories.append(mem)
755+
update_memories_embed = self._batch_embed([mem.memory for mem in update_memories])
700756

701-
update_memories_embed = self._retry_db_operation(
702-
lambda: self._batch_embed([mem.memory for mem in update_memories])
703-
)
704757
for _i, embed in zip(range(len(update_memories)), update_memories_embed, strict=False):
705758
update_memories[_i].metadata.embedding = embed
706759

@@ -805,9 +858,7 @@ def check_validity(item):
805858
feedback_memories = []
806859

807860
corrected_infos = [item["corrected_info"] for item in valid_feedback]
808-
feedback_memories_embeddings = self._retry_db_operation(
809-
lambda: self._batch_embed(corrected_infos)
810-
)
861+
feedback_memories_embeddings = self._batch_embed(corrected_infos)
811862

812863
for item, embedding in zip(
813864
valid_feedback, feedback_memories_embeddings, strict=False
@@ -845,8 +896,10 @@ def check_validity(item):
845896
info=info,
846897
**kwargs,
847898
)
899+
add_memories = mem_record["record"]["add"]
900+
update_memories = mem_record["record"]["update"]
848901
logger.info(
849-
f"[Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback memories for user {user_name}."
902+
f"[Feedback Core: process_feedback_core] Processed {len(feedback_memories)} feedback | add {len(add_memories)} memories | update {len(update_memories)} memories for user {user_name}."
850903
)
851904
return mem_record
852905

@@ -902,42 +955,19 @@ def process_feedback(
902955
task_id = kwargs.get("task_id", "default")
903956

904957
logger.info(
905-
f"[MemFeedback process] Feedback Completed : user {user_name} | task_id {task_id} | record {record}."
958+
f"[Feedback Core MemFeedback process] Feedback Completed : user {user_name} | task_id {task_id} | record {record}."
906959
)
907960

908961
return {"answer": answer, "record": record["record"]}
909962
except concurrent.futures.TimeoutError:
910963
logger.error(
911-
f"[MemFeedback process] Timeout in sync mode for {user_name}", exc_info=True
964+
f"[Feedback Core MemFeedback process] Timeout in sync mode for {user_name}",
965+
exc_info=True,
912966
)
913967
return {"answer": "", "record": {"add": [], "update": []}}
914968
except Exception as e:
915969
logger.error(
916-
f"[MemFeedback process] Error in concurrent tasks for {user_name}: {e}",
970+
f"[Feedback Core MemFeedback process] Error in concurrent tasks for {user_name}: {e}",
917971
exc_info=True,
918972
)
919973
return {"answer": "", "record": {"add": [], "update": []}}
920-
921-
# Helper for DB operations with retry
922-
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
923-
def _retry_db_operation(self, operation):
924-
try:
925-
return operation()
926-
except Exception as e:
927-
logger.error(
928-
f"[MemFeedback: _retry_db_operation] DB operation failed: {e}", exc_info=True
929-
)
930-
raise
931-
932-
@require_python_package(
933-
import_name="jieba",
934-
install_command="pip install jieba",
935-
install_link="https://github.com/fxsjy/jieba",
936-
)
937-
def _tokenize_chinese(self, text):
938-
"""split zh jieba"""
939-
import jieba
940-
941-
tokens = jieba.lcut(text)
942-
tokens = [token.strip() for token in tokens if token.strip()]
943-
return self.stopword_manager.filter_words(tokens)

src/memos/multi_mem_cube/single_cube.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]:
185185
task_id=feedback_req.task_id,
186186
info=feedback_req.info,
187187
)
188-
self.logger.info(f"Feedback memories result: {feedback_result}")
188+
self.logger.info(f"[Feedback memories result:] {feedback_result}")
189189
return feedback_result
190190

191191
def _get_search_mode(self, mode: str) -> str:

src/memos/templates/mem_feedback_prompts.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,8 @@
441441
]
442442
}}
443443
444+
**Current time**
445+
{now_time}
444446
445447
**Current Memories**
446448
{current_memories}
@@ -581,6 +583,9 @@
581583
]
582584
}}
583585
586+
**当前时间:**
587+
{now_time}
588+
584589
**当前记忆:**
585590
{current_memories}
586591

0 commit comments

Comments (0)