
Commit fcdb21c (1 parent: 391b422)

fix bugs in text mem with neo4j backend, and set huggingface backend to be singleton
8 files changed, 45 insertions(+), 23 deletions(-)

evaluation/scripts/temporal_locomo/locomo_processor.py

Lines changed: 22 additions & 15 deletions
@@ -100,7 +100,8 @@ def process_user(self, conv_id, locomo_df, frame, version, top_k=20):
 
         oai_client = OpenAI(api_key=self.openai_api_key, base_url=self.openai_base_url)
 
-        self.pre_context_cache[conv_id] = None
+        with self.stats_lock:
+            self.pre_context_cache[conv_id] = None
 
         def process_qa(qa):
             try:
@@ -119,12 +120,14 @@ def process_qa(qa):
                 context = ""
 
                 # ==== Context Answerability Analysis (for memos_scheduler only) ====
-
+                can_answer = False
+                can_answer_duration_ms = 0.0
                 if self.pre_context_cache[conv_id] is not None:
+                    can_answer_start = time()
                     can_answer = self.analyze_context_answerability(
                         self.pre_context_cache[conv_id], query, oai_client
                     )
-
+                    can_answer_duration_ms = (time() - can_answer_start) * 1000
                 # Update statistics
                 with self.stats_lock:
                     self.stats[self.frame][self.version]["memory_stats"]["total_queries"] += 1
@@ -151,8 +154,8 @@ def process_qa(qa):
                         hit_rate
                     )
                 self.save_stats()
-
-                self.pre_context_cache[conv_id] = context
+                with self.stats_lock:
+                    self.pre_context_cache[conv_id] = context
 
                 self.print_eval_info()
 
@@ -174,6 +177,7 @@ def process_qa(qa):
                     "search_context": context,
                     "response_duration_ms": response_duration_ms,
                     "search_duration_ms": search_duration_ms,
+                    "can_answer_duration_ms": can_answer_duration_ms,
                     "can_answer": can_answer if frame == "memos_scheduler" else None,
                 }
             except Exception as e:
@@ -191,16 +195,19 @@ def process_qa(qa):
                     if result["search_context"]
                     else "No context"
                 )
-                print(
-                    {
-                        "question": result["question"][:100],
-                        "answer": result["answer"][:100],
-                        "category": result["category"],
-                        "golden_answer": result["golden_answer"],
-                        "search_context": context_preview[:100],
-                        "search_duration_ms": result["search_duration_ms"],
-                    }
-                )
+                if "can_answer" in result:
+                    print("Print can_answer examples")
+                    print(
+                        {
+                            "question": result["question"][:100],
+                            "pre context can answer": result["can_answer"],
+                            "answer": result["answer"][:100],
+                            "category": result["category"],
+                            "golden_answer": result["golden_answer"],
+                            "search_context": context_preview[:100],
+                            "search_duration_ms": result["search_duration_ms"],
+                        }
+                    )
 
                 search_results[conv_id].append(
                     {
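Note: the locomo_processor.py changes do two things. First, reads and writes of the shared `pre_context_cache` now happen under `self.stats_lock`, presumably because `process_qa` is dispatched concurrently for many questions (the existing `stats_lock` suggests a thread pool), so an unguarded dict update could race with the statistics bookkeeping. Second, the answerability check is timed and reported as `can_answer_duration_ms`. Below is a minimal standalone sketch of that locking pattern only; it is not the project's code, it assumes a thread-pool dispatch, and it uses hypothetical module-level names instead of the class attributes.

```python
import threading
from concurrent.futures import ThreadPoolExecutor
from time import time

stats_lock = threading.Lock()
pre_context_cache: dict[str, str | None] = {"conv_1": None}

def process_qa(conv_id: str, context: str) -> float:
    start = time()
    with stats_lock:                      # read the shared cache under the lock, as in the diff
        cached = pre_context_cache[conv_id]
    can_answer = cached is not None       # stand-in for the real answerability check
    with stats_lock:                      # write it back under the same lock
        pre_context_cache[conv_id] = context
    return (time() - start) * 1000        # duration in ms, mirroring can_answer_duration_ms

with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(process_qa, "conv_1", f"ctx {i}") for i in range(8)]
    durations_ms = [f.result() for f in futures]
```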

evaluation/scripts/temporal_locomo/modules/client_manager.py

Lines changed: 0 additions & 2 deletions
@@ -153,8 +153,6 @@ def get_client_from_storage(
         # Replace the original scheduler
         mos.mem_scheduler = scheduler_for_eval
 
-        # stop mem_scheduler thread
-        mos.mem_scheduler.stop()
         return mos
 
     def locomo_response(self, frame, llm_client, context: str, question: str) -> str:

examples/data/config/mem_scheduler/mem_cube_config.yaml

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ act_mem:
   config:
     memory_filename: "activation_memory.pickle"
     extractor_llm:
-      backend: "huggingface"
+      backend: "huggingface_singleton"
       config:
         model_name_or_path: "Qwen/Qwen3-1.7B"
        temperature: 0.8
@@ -48,7 +48,7 @@ para_mem:
   config:
     memory_filename: "parametric_memory.adapter"
     extractor_llm:
-      backend: "huggingface"
+      backend: "huggingface_singleton"
       config:
         model_name_or_path: "Qwen/Qwen3-1.7B"
        temperature: 0.8
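Note: this config and the three memos_config_* files below switch the chat and extractor LLM backends from `huggingface` to `huggingface_singleton`. Per the commit message, the intent is to make the HuggingFace backend a singleton, so components that name the same model (here `Qwen/Qwen3-1.7B`) share one loaded instance instead of loading the weights once per config block. The sketch below illustrates such a singleton-style factory; the class and method names are illustrative, not the MemOS API.

```python
import threading

class HFSingletonFactory:
    """Illustrative only: caches one backend instance per model path."""
    _instances: dict[str, object] = {}
    _lock = threading.Lock()

    @classmethod
    def get(cls, model_name_or_path: str) -> object:
        with cls._lock:
            if model_name_or_path not in cls._instances:
                # Placeholder for the expensive load, e.g. a
                # transformers AutoModelForCausalLM.from_pretrained(...) call.
                cls._instances[model_name_or_path] = object()
            return cls._instances[model_name_or_path]

# act_mem, para_mem and chat_model configs that all name "Qwen/Qwen3-1.7B"
# would then resolve to the same loaded model:
a = HFSingletonFactory.get("Qwen/Qwen3-1.7B")
b = HFSingletonFactory.get("Qwen/Qwen3-1.7B")
assert a is b
```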

examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler_and_openai.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 user_id: "root"
 chat_model:
-  backend: "huggingface"
+  backend: "huggingface_singleton"
   config:
     model_name_or_path: "Qwen/Qwen3-1.7B"
     temperature: 0.1

examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 user_id: "root"
 chat_model:
-  backend: "huggingface"
+  backend: "huggingface_singleton"
   config:
     model_name_or_path: "Qwen/Qwen3-1.7B"
     temperature: 0.1

examples/data/config/mem_scheduler/memos_config_w_scheduler_and_openai.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 user_id: "root"
 chat_model:
-  backend: "huggingface"
+  backend: "huggingface_singleton"
   config:
     model_name_or_path: "Qwen/Qwen3-1.7B"
     temperature: 0.1

examples/mem_scheduler/memos_w_scheduler_for_test.py

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ def init_task():
     )
 
     # Initialization
-    print("🔧 Initializing MOS with Enhanced Scheduler...")
+    print("🔧 Initializing MOS with Scheduler...")
    mos = MOSForTestScheduler(mos_config)
 
    user_id = "user_1"

src/memos/graph_dbs/neo4j.py

Lines changed: 17 additions & 0 deletions
@@ -1,3 +1,4 @@
+import json
 import time
 
 from datetime import datetime
@@ -174,6 +175,12 @@ def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
             n.updated_at = datetime($updated_at),
             n += $metadata
         """
+
+        # serialization
+        if metadata["sources"]:
+            for idx in range(len(metadata["sources"])):
+                metadata["sources"][idx] = json.dumps(metadata["sources"][idx])
+
         with self.driver.session(database=self.db_name) as session:
             session.run(
                 query,
@@ -1128,4 +1135,14 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
                 node[time_field] = node[time_field].isoformat()
         node.pop("user_name", None)
 
+        # serialization
+        if node["sources"]:
+            for idx in range(len(node["sources"])):
+                if not (
+                    isinstance(node["sources"][idx], str)
+                    and node["sources"][idx][0] == "{"
+                    and node["sources"][idx][0] == "}"
+                ):
+                    break
+                node["sources"][idx] = json.loads(node["sources"][idx])
         return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}
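Note on the neo4j.py changes: Neo4j node properties can hold primitives and arrays of primitives but not maps, so `add_node` now JSON-encodes each entry of `metadata["sources"]` before writing, and `_parse_node` attempts to decode them on read. As committed, however, the guard in `_parse_node` compares `node["sources"][idx][0]` against both `"{"` and `"}"`, a condition that can never hold, so the loop breaks on the first element and stored strings come back undecoded; the second comparison presumably meant to test the last character. Below is a hedged sketch of the intended round-trip; the function names are illustrative and the endswith check is an assumption about intent, not the committed code.

```python
import json

def serialize_sources(sources: list) -> list[str]:
    # Neo4j node properties accept primitives and arrays of primitives, not maps,
    # so dict-valued sources are stored as JSON strings (what add_node now does).
    return [json.dumps(s) for s in sources]

def deserialize_sources(sources: list) -> list:
    # Inverse used on read: decode entries that look like JSON objects,
    # leave plain-string sources untouched.
    out = []
    for s in sources:
        if isinstance(s, str) and s.startswith("{") and s.endswith("}"):
            out.append(json.loads(s))
        else:
            out.append(s)
    return out

stored = serialize_sources([{"type": "doc", "id": "d1"}, {"type": "chat", "id": "m7"}])
assert deserialize_sources(stored) == [{"type": "doc", "id": "d1"}, {"type": "chat", "id": "m7"}]
```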
