
Commit 6d1030e

Author: yuan.wang
Message: merge dev
Parents: e4eb9db + 9341861

37 files changed (+1828, -2113 lines)

docker/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -158,3 +158,4 @@ watchfiles==1.1.0
 websockets==15.0.1
 xlrd==2.0.2
 xlsxwriter==3.2.5
+prometheus-client==0.23.1

evaluation/scripts/locomo/locomo_eval.py

Lines changed: 42 additions & 15 deletions
@@ -3,6 +3,7 @@
 import json
 import logging
 import os
+import re
 import time

 import nltk
@@ -47,6 +48,29 @@ class LLMGrade(BaseModel):
     llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.")


+def extract_label_json(text: str) -> str | None:
+    """
+    Extracts a JSON object of the form {"label": "VALUE"} from a given text string.
+    This function is designed to handle cases where the LLM response contains
+    natural language alongside a final JSON snippet, ensuring robust parsing.
+
+    Supports both single and double quotes around the label value.
+    Ignores surrounding whitespace and formatting.
+
+    Returns:
+        The full matching JSON string (e.g., '{"label": "CORRECT"}') if found.
+        None if no valid label JSON is found.
+    """
+    # Regex pattern to match: { "label": "value" } with optional whitespace
+    # Matches both single and double quotes, allows spaces around keys and values
+    pattern = r'\{\s*"label"\s*:\s*["\']([^"\']*)["\']\s*\}'
+    match = re.search(pattern, text)
+    if match:
+        # Return the complete matched JSON string for safe json.loads()
+        return match.group(0)
+    return None
+
+
 async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool:
     system_prompt = """
     You are an expert grader that determines if answers to questions match a gold standard answer
@@ -77,20 +101,23 @@ async def locomo_grader(llm_client, question: str, gold_answer: str, response: s

     Just return the label CORRECT or WRONG in a json format with the key as "label".
     """
-
-    response = await llm_client.chat.completions.create(
-        model="gpt-4o-mini",
-        messages=[
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": accuracy_prompt},
-        ],
-        temperature=0,
-    )
-    message_content = response.choices[0].message.content
-    label = json.loads(message_content)["label"]
-    parsed = LLMGrade(llm_judgment=label, llm_reasoning="")
-
-    return parsed.llm_judgment.strip().lower() == "correct"
+    try:
+        response = await llm_client.chat.completions.create(
+            model=os.getenv("EVAL_MODEL", "gpt-4o-mini"),
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": accuracy_prompt},
+            ],
+            temperature=0,
+        )
+        message_content = response.choices[0].message.content
+        message_content = extract_label_json(text=message_content)
+        label = json.loads(message_content)["label"]
+        parsed = LLMGrade(llm_judgment=label, llm_reasoning="")
+        return parsed.llm_judgment.strip().lower() == "correct"
+    except Exception as e:
+        print(f"======== {e}, {response} ===========")
+        exit()


 def calculate_rouge_scores(gold_answer, response):
@@ -284,7 +311,7 @@ async def main(frame, version="default", options=None, num_runs=1, max_workers=4
     with open(response_path) as file:
        locomo_responses = json.load(file)

-    num_users = 10
+    num_users = 2
     all_grades = {}

     total_responses_count = sum(
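
Note: the grader now reads its model name from EVAL_MODEL and routes the raw reply through extract_label_json before json.loads, so verbose LLM output no longer breaks parsing. A quick standalone sketch of the helper (the sample reply below is invented for illustration; the function body mirrors the one added above):

import json
import re

def extract_label_json(text: str) -> str | None:
    # Same regex as the helper added in this commit: grab the final {"label": ...} snippet
    pattern = r'\{\s*"label"\s*:\s*["\']([^"\']*)["\']\s*\}'
    match = re.search(pattern, text)
    return match.group(0) if match else None

# Hypothetical grader reply mixing prose with the JSON verdict
reply = 'The response names the right city, so my verdict is {"label": "CORRECT"}.'
print(json.loads(extract_label_json(reply))["label"])  # -> CORRECT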

evaluation/scripts/utils/client.py

Lines changed: 1 addition & 3 deletions
@@ -189,9 +189,7 @@ def search(self, query, user_id, top_k):
         )
         response = requests.request("POST", url, data=payload, headers=self.headers)
         assert response.status_code == 200, response.text
-        assert json.loads(response.text)["message"] == "Search completed successfully", (
-            response.text
-        )
+        assert json.loads(response.text)["message"] == "Memory searched successfully", response.text
         return json.loads(response.text)["data"]


examples/mem_scheduler/api_w_scheduler.py

Lines changed: 17 additions & 16 deletions
@@ -1,8 +1,10 @@
+from time import sleep
+
 from memos.api.handlers.scheduler_handler import (
     handle_scheduler_status,
     handle_scheduler_wait,
 )
-from memos.api.routers.server_router import mem_scheduler
+from memos.api.routers.server_router import mem_scheduler, status_tracker
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem


@@ -26,26 +28,25 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
     for msg in messages:
         print(f" my_test_handler - {msg.item_id}: {msg.content}")
         user_status_running = handle_scheduler_status(
-            user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
+            user_id=msg.user_id, status_tracker=status_tracker
         )
-        print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running)
+        print("[Monitor] Status after submit:", user_status_running)


 # 2. Register the handler
 TEST_HANDLER_LABEL = "test_handler"
+TEST_USER_ID = "test_user"
 mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})

 # 2.1 Monitor global scheduler status before submitting tasks
-global_status_before = handle_scheduler_status(
-    user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
-)
+global_status_before = handle_scheduler_status(user_id=TEST_USER_ID, status_tracker=status_tracker)
 print("[Monitor] Global status before submit:", global_status_before)

 # 3. Create messages
 messages_to_send = [
     ScheduleMessageItem(
         item_id=f"test_item_{i}",
-        user_id="test_user",
+        user_id=TEST_USER_ID,
         mem_cube_id="test_mem_cube",
         label=TEST_HANDLER_LABEL,
         content=f"This is test message {i}",
@@ -56,28 +57,28 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
 # 5. Submit messages
 for mes in messages_to_send:
     print(f"Submitting message {mes.item_id} to the scheduler...")
-    mem_scheduler.memos_message_queue.submit_messages([mes])
+    mem_scheduler.submit_messages([mes])
+    sleep(1)

 # 5.1 Monitor status for specific mem_cube while running
 USER_MEM_CUBE = "test_mem_cube"

 # 6. Wait for messages to be processed (limited to 100 checks)
-print("Waiting for messages to be consumed (max 100 checks)...")
-mem_scheduler.mem_scheduler_wait()
+
+user_status_running = handle_scheduler_status(user_id=TEST_USER_ID, status_tracker=status_tracker)
+print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running)

 # 6.1 Wait until idle for specific mem_cube via handler
 wait_result = handle_scheduler_wait(
-    user_name=USER_MEM_CUBE,
+    user_name=TEST_USER_ID,
+    status_tracker=status_tracker,
     timeout_seconds=120.0,
-    poll_interval=0.2,
-    mem_scheduler=mem_scheduler,
+    poll_interval=0.5,
 )
 print(f"[Monitor] Wait result for {USER_MEM_CUBE}:", wait_result)

 # 6.2 Monitor global scheduler status after processing
-global_status_after = handle_scheduler_status(
-    user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
-)
+global_status_after = handle_scheduler_status(user_id=TEST_USER_ID, status_tracker=status_tracker)
 print("[Monitor] Global status after processing:", global_status_after)

 # 7. Stop the scheduler

poetry.lock

Lines changed: 17 additions & 2 deletions
(Generated lockfile; diff not rendered.)

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ dependencies = [
     "scikit-learn (>=1.7.0,<2.0.0)", # Machine learning
     "fastmcp (>=2.10.5,<3.0.0)",
     "python-dateutil (>=2.9.0.post0,<3.0.0)",
+    "prometheus-client (>=0.23.1,<0.24.0)",
 ]

 [project.urls]
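
Note: this pins prometheus-client as a project dependency (matching the docker/requirements.txt entry above); the diff does not show where it is used. For reference, a minimal, generic prometheus-client sketch, with an invented metric name:

from prometheus_client import Counter, start_http_server

# Invented metric purely to illustrate the library; the actual MemOS metrics are not in this diff
SEARCH_REQUESTS = Counter("memos_search_requests_total", "Number of search requests handled")

def handle_search() -> None:
    SEARCH_REQUESTS.inc()  # increment on every handled request

if __name__ == "__main__":
    start_http_server(8000)  # serves /metrics for Prometheus to scrape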

src/memos/api/handlers/base_handler.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@

 from memos.log import get_logger
 from memos.mem_scheduler.base_scheduler import BaseScheduler
-from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
+from memos.memories.textual.tree_text_memory.retrieve.advanced_searcher import AdvancedSearcher


 logger = get_logger(__name__)
@@ -132,7 +132,7 @@ def mem_scheduler(self) -> BaseScheduler:
         return self.deps.mem_scheduler

     @property
-    def searcher(self) -> Searcher:
+    def searcher(self) -> AdvancedSearcher:
         """Get scheduler instance."""
         return self.deps.searcher


src/memos/api/handlers/component_init.py

Lines changed: 19 additions & 0 deletions
@@ -129,6 +129,21 @@ def init_server() -> dict[str, Any]:
     """
     logger.info("Initializing MemOS server components...")

+    # Initialize Redis client first as it is a core dependency for features like scheduler status tracking
+    try:
+        from memos.mem_scheduler.orm_modules.api_redis_model import APIRedisDBManager
+
+        redis_client = APIRedisDBManager.load_redis_engine_from_env()
+        if redis_client:
+            logger.info("Redis client initialized successfully.")
+        else:
+            logger.error(
+                "Failed to initialize Redis client. Check REDIS_HOST etc. in environment variables."
+            )
+    except Exception as e:
+        logger.error(f"Failed to initialize Redis client: {e}", exc_info=True)
+        redis_client = None  # Ensure redis_client exists even on failure
+
     # Get default cube configuration
     default_cube_config = APIConfig.get_default_cube_config()

@@ -272,6 +287,8 @@ def init_server() -> dict[str, Any]:
     tree_mem: TreeTextMemory = naive_mem_cube.text_mem
     searcher: Searcher = tree_mem.get_searcher(
         manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
+        moscube=False,
+        process_llm=mem_reader.llm,
     )
     logger.debug("Searcher created")

@@ -286,6 +303,7 @@ def init_server() -> dict[str, Any]:
         process_llm=mem_reader.llm,
         db_engine=BaseDBManager.create_default_sqlite_engine(),
         mem_reader=mem_reader,
+        redis_client=redis_client,
     )
     mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube, searcher=searcher)
     logger.debug("Scheduler initialized")
@@ -335,5 +353,6 @@ def init_server() -> dict[str, Any]:
         "text_mem": text_mem,
         "pref_mem": pref_mem,
         "online_bot": online_bot,
+        "redis_client": redis_client,
         "deepsearch_agent": deepsearch_agent,
     }
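
Note: init_server now brings Redis up first, logs (rather than raises) on failure, and threads redis_client, possibly None, into the scheduler and the returned component dict. The internals of APIRedisDBManager.load_redis_engine_from_env are not shown in this diff; below is a plain redis-py sketch of the same optional-dependency pattern, assuming the usual REDIS_HOST/REDIS_PORT/REDIS_DB variables:

import os

import redis  # third-party redis-py client

def load_optional_redis() -> "redis.Redis | None":
    """Return a connected Redis client, or None so callers can degrade gracefully."""
    try:
        client = redis.Redis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            db=int(os.getenv("REDIS_DB", "0")),
        )
        client.ping()  # fail fast if the server is unreachable
        return client
    except Exception:
        return None    # e.g. missing config or connection error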
