Skip to content

Commit 2cdb0f9

Browse files
committed
Merge branch 'feat/reranker' of github.com:CaralHsi/MemOSRealPublic into feat/reranker
2 parents 19e72e4 + 6ea943d commit 2cdb0f9

File tree

4 files changed

+252
-57
lines changed

4 files changed

+252
-57
lines changed

src/memos/llms/vllm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def _generate_with_api_client(self, messages: list[MessageDict]) -> str:
105105
"temperature": float(getattr(self.config, "temperature", 0.8)),
106106
"max_tokens": int(getattr(self.config, "max_tokens", 1024)),
107107
"top_p": float(getattr(self.config, "top_p", 0.9)),
108+
"extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
108109
}
109110

110111
response = self.client.chat.completions.create(**completion_kwargs)
@@ -142,6 +143,7 @@ def generate_stream(self, messages: list[MessageDict]):
142143
"max_tokens": int(getattr(self.config, "max_tokens", 1024)),
143144
"top_p": float(getattr(self.config, "top_p", 0.9)),
144145
"stream": True, # Enable streaming
146+
"extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
145147
}
146148

147149
stream = self.client.chat.completions.create(**completion_kwargs)

src/memos/mem_os/core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,30 @@ def dump(
960960
self.mem_cubes[mem_cube_id].dump(dump_dir)
961961
logger.info(f"MemCube {mem_cube_id} dumped to {dump_dir}")
962962

963+
def load(
964+
self,
965+
load_dir: str,
966+
user_id: str | None = None,
967+
mem_cube_id: str | None = None,
968+
memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None,
969+
) -> None:
970+
"""Dump the MemCube to a dictionary.
971+
Args:
972+
load_dir (str): The directory to load the MemCube from.
973+
user_id (str, optional): The identifier of the user to load the MemCube from.
974+
If None, the default user is used.
975+
mem_cube_id (str, optional): The identifier of the MemCube to load.
976+
If None, the default MemCube for the user is used.
977+
"""
978+
target_user_id = user_id if user_id is not None else self.user_id
979+
accessible_cubes = self.user_manager.get_user_cubes(target_user_id)
980+
if not mem_cube_id:
981+
mem_cube_id = accessible_cubes[0].cube_id
982+
if mem_cube_id not in self.mem_cubes:
983+
raise ValueError(f"MemCube with ID {mem_cube_id} does not exist. please regiester")
984+
self.mem_cubes[mem_cube_id].load(load_dir, memory_types=memory_types)
985+
logger.info(f"MemCube {mem_cube_id} loaded from {load_dir}")
986+
963987
def get_user_info(self) -> dict[str, Any]:
964988
"""Get current user information including accessible cubes.
965989

src/memos/mem_os/product.py

Lines changed: 180 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import json
23
import os
34
import random
5+
import threading
46
import time
57

68
from collections.abc import Generator
@@ -522,6 +524,174 @@ def _send_message_to_scheduler(
522524
)
523525
self.mem_scheduler.submit_messages(messages=[message_item])
524526

527+
async def _post_chat_processing(
528+
self,
529+
user_id: str,
530+
cube_id: str,
531+
query: str,
532+
full_response: str,
533+
system_prompt: str,
534+
time_start: float,
535+
time_end: float,
536+
speed_improvement: float,
537+
current_messages: list,
538+
) -> None:
539+
"""
540+
Asynchronous processing of logs, notifications and memory additions
541+
"""
542+
try:
543+
logger.info(
544+
f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}"
545+
)
546+
logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}")
547+
548+
clean_response, extracted_references = self._extract_references_from_response(
549+
full_response
550+
)
551+
logger.info(f"Extracted {len(extracted_references)} references from response")
552+
553+
# Send chat report notifications asynchronously
554+
if self.online_bot:
555+
try:
556+
from memos.memos_tools.notification_utils import (
557+
send_online_bot_notification_async,
558+
)
559+
560+
# 准备通知数据
561+
chat_data = {
562+
"query": query,
563+
"user_id": user_id,
564+
"cube_id": cube_id,
565+
"system_prompt": system_prompt,
566+
"full_response": full_response,
567+
}
568+
569+
system_data = {
570+
"references": extracted_references,
571+
"time_start": time_start,
572+
"time_end": time_end,
573+
"speed_improvement": speed_improvement,
574+
}
575+
576+
emoji_config = {"chat": "💬", "system_info": "📊"}
577+
578+
await send_online_bot_notification_async(
579+
online_bot=self.online_bot,
580+
header_name="MemOS Chat Report",
581+
sub_title_name="chat_with_references",
582+
title_color="#00956D",
583+
other_data1=chat_data,
584+
other_data2=system_data,
585+
emoji=emoji_config,
586+
)
587+
except Exception as e:
588+
logger.warning(f"Failed to send chat notification (async): {e}")
589+
590+
self._send_message_to_scheduler(
591+
user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL
592+
)
593+
594+
self.add(
595+
user_id=user_id,
596+
messages=[
597+
{
598+
"role": "user",
599+
"content": query,
600+
"chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
601+
},
602+
{
603+
"role": "assistant",
604+
"content": clean_response, # Store clean text without reference markers
605+
"chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
606+
},
607+
],
608+
mem_cube_id=cube_id,
609+
)
610+
611+
logger.info(f"Post-chat processing completed for user {user_id}")
612+
613+
except Exception as e:
614+
logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True)
615+
616+
def _start_post_chat_processing(
617+
self,
618+
user_id: str,
619+
cube_id: str,
620+
query: str,
621+
full_response: str,
622+
system_prompt: str,
623+
time_start: float,
624+
time_end: float,
625+
speed_improvement: float,
626+
current_messages: list,
627+
) -> None:
628+
"""
629+
Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments
630+
"""
631+
632+
def run_async_in_thread():
633+
"""Running asynchronous tasks in a new thread"""
634+
try:
635+
loop = asyncio.new_event_loop()
636+
asyncio.set_event_loop(loop)
637+
try:
638+
loop.run_until_complete(
639+
self._post_chat_processing(
640+
user_id=user_id,
641+
cube_id=cube_id,
642+
query=query,
643+
full_response=full_response,
644+
system_prompt=system_prompt,
645+
time_start=time_start,
646+
time_end=time_end,
647+
speed_improvement=speed_improvement,
648+
current_messages=current_messages,
649+
)
650+
)
651+
finally:
652+
loop.close()
653+
except Exception as e:
654+
logger.error(
655+
f"Error in thread-based post-chat processing for user {user_id}: {e}",
656+
exc_info=True,
657+
)
658+
659+
try:
660+
# Try to get the current event loop
661+
asyncio.get_running_loop()
662+
# Create task and store reference to prevent garbage collection
663+
task = asyncio.create_task(
664+
self._post_chat_processing(
665+
user_id=user_id,
666+
cube_id=cube_id,
667+
query=query,
668+
full_response=full_response,
669+
system_prompt=system_prompt,
670+
time_start=time_start,
671+
time_end=time_end,
672+
speed_improvement=speed_improvement,
673+
current_messages=current_messages,
674+
)
675+
)
676+
# Add exception handling for the background task
677+
task.add_done_callback(
678+
lambda t: logger.error(
679+
f"Error in background post-chat processing for user {user_id}: {t.exception()}",
680+
exc_info=True,
681+
)
682+
if t.exception()
683+
else None
684+
)
685+
except RuntimeError:
686+
# No event loop, run in a new thread
687+
thread = threading.Thread(
688+
target=run_async_in_thread,
689+
name=f"PostChatProcessing-{user_id}",
690+
# Set as a daemon thread to avoid blocking program exit
691+
daemon=True,
692+
)
693+
thread.start()
694+
525695
def _filter_memories_by_threshold(
526696
self, memories: list[TextualMemoryItem], threshold: float = 0.50, min_num: int = 3
527697
) -> list[TextualMemoryItem]:
@@ -895,64 +1065,17 @@ def chat_with_references(
8951065
yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n"
8961066
yield f"data: {json.dumps({'type': 'end'})}\n\n"
8971067

898-
logger.info(f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}")
899-
logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}")
900-
901-
clean_response, extracted_references = self._extract_references_from_response(full_response)
902-
logger.info(f"Extracted {len(extracted_references)} references from response")
903-
904-
# Send chat report if online_bot is available
905-
try:
906-
from memos.memos_tools.notification_utils import send_online_bot_notification
907-
908-
# Prepare data for online_bot
909-
chat_data = {
910-
"query": query,
911-
"user_id": user_id,
912-
"cube_id": cube_id,
913-
"system_prompt": system_prompt,
914-
"full_response": full_response,
915-
}
916-
917-
system_data = {
918-
"references": extracted_references,
919-
"time_start": time_start,
920-
"time_end": time_end,
921-
"speed_improvement": speed_improvement,
922-
}
923-
924-
emoji_config = {"chat": "💬", "system_info": "📊"}
925-
926-
send_online_bot_notification(
927-
online_bot=self.online_bot,
928-
header_name="MemOS Chat Report",
929-
sub_title_name="chat_with_references",
930-
title_color="#00956D",
931-
other_data1=chat_data,
932-
other_data2=system_data,
933-
emoji=emoji_config,
934-
)
935-
except Exception as e:
936-
logger.warning(f"Failed to send chat notification: {e}")
937-
938-
self._send_message_to_scheduler(
939-
user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL
940-
)
941-
self.add(
1068+
# Asynchronous processing of logs, notifications and memory additions
1069+
self._start_post_chat_processing(
9421070
user_id=user_id,
943-
messages=[
944-
{
945-
"role": "user",
946-
"content": query,
947-
"chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
948-
},
949-
{
950-
"role": "assistant",
951-
"content": clean_response, # Store clean text without reference markers
952-
"chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
953-
},
954-
],
955-
mem_cube_id=cube_id,
1071+
cube_id=cube_id,
1072+
query=query,
1073+
full_response=full_response,
1074+
system_prompt=system_prompt,
1075+
time_start=time_start,
1076+
time_end=time_end,
1077+
speed_improvement=speed_improvement,
1078+
current_messages=current_messages,
9561079
)
9571080

9581081
def get_all(

src/memos/memos_tools/notification_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Notification utilities for MemOS product.
33
"""
44

5+
import asyncio
56
import logging
67

78
from collections.abc import Callable
@@ -51,6 +52,51 @@ def send_online_bot_notification(
5152
logger.warning(f"Failed to send online bot notification: {e}")
5253

5354

55+
async def send_online_bot_notification_async(
56+
online_bot: Callable | None,
57+
header_name: str,
58+
sub_title_name: str,
59+
title_color: str,
60+
other_data1: dict[str, Any],
61+
other_data2: dict[str, Any],
62+
emoji: dict[str, str],
63+
) -> None:
64+
"""
65+
Send notification via online_bot asynchronously if available.
66+
67+
Args:
68+
online_bot: The online_bot function or None
69+
header_name: Header name for the report
70+
sub_title_name: Subtitle for the report
71+
title_color: Title color
72+
other_data1: First data dict
73+
other_data2: Second data dict
74+
emoji: Emoji configuration dict
75+
"""
76+
if online_bot is None:
77+
return
78+
79+
try:
80+
# Run the potentially blocking notification in a thread pool
81+
loop = asyncio.get_event_loop()
82+
await loop.run_in_executor(
83+
None,
84+
lambda: online_bot(
85+
header_name=header_name,
86+
sub_title_name=sub_title_name,
87+
title_color=title_color,
88+
other_data1=other_data1,
89+
other_data2=other_data2,
90+
emoji=emoji,
91+
),
92+
)
93+
94+
logger.info(f"Online bot notification sent successfully (async): {header_name}")
95+
96+
except Exception as e:
97+
logger.warning(f"Failed to send online bot notification (async): {e}")
98+
99+
54100
def send_error_bot_notification(
55101
error_bot: Callable | None,
56102
err: str,

0 commit comments

Comments
 (0)