
Commit ccef651

add new feat of thread race, and add a new test case for scheduler dispatcher
1 parent b6834d3 commit ccef651

File tree

9 files changed: +646, -14 lines changed

evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py

Lines changed: 19 additions & 0 deletions
@@ -531,6 +531,25 @@ def process_qa(qa):
             json.dump(dict(search_results), fw, indent=2)
         print(f"Save search results {conv_id}")

+        search_durations = []
+        for result in response_results[conv_id]:
+            if "search_duration_ms" in result:
+                search_durations.append(result["search_duration_ms"])
+
+        if search_durations:
+            avg_search_duration = sum(search_durations) / len(search_durations)
+            with self.stats_lock:
+                if self.stats[self.frame][self.version]["memory_stats"]["avg_search_duration_ms"]:
+                    self.stats[self.frame][self.version]["memory_stats"][
+                        "avg_search_duration_ms"
+                    ] = (
+                        self.stats[self.frame][self.version]["memory_stats"][
+                            "avg_search_duration_ms"
+                        ]
+                        + avg_search_duration
+                    ) / 2
+            print(f"Average search duration: {avg_search_duration:.2f} ms")
+
         # Dump stats after processing each user
         self.save_stats()
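For reference, the stats update above keeps a running estimate by averaging the previously stored value with the new per-conversation average. The same rule, isolated as a standalone helper for readability — a sketch only; the helper name and the seed-on-first-value branch are assumptions, not part of the diff:

def update_avg_search_duration(memory_stats: dict, new_avg_ms: float) -> None:
    """Blend a new per-conversation average into the stored running average."""
    if memory_stats.get("avg_search_duration_ms"):
        # Same rule as the diff: mean of the old estimate and the new batch average.
        memory_stats["avg_search_duration_ms"] = (
            memory_stats["avg_search_duration_ms"] + new_avg_ms
        ) / 2
    else:
        # Assumption: seed the value on first observation (the diff only updates an existing value).
        memory_stats["avg_search_duration_ms"] = new_avg_ms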

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
import random
import threading
import time


class ThreadRace:
    def __init__(self):
        # Variable to store the result
        self.result = None
        # Event to mark if the race is finished
        self.race_finished = threading.Event()
        # Lock to protect the result variable
        self.lock = threading.Lock()
        # Store thread objects for termination
        self.threads = {}
        # Stop flags for each thread
        self.stop_flags = {}

    def task1(self, stop_flag):
        """First task function, can be modified as needed"""
        # Simulate random work time
        sleep_time = random.uniform(0.1, 2.0)

        # Break the sleep into smaller chunks to check stop flag
        chunks = 20
        chunk_time = sleep_time / chunks

        for _ in range(chunks):
            # Check if we should stop
            if stop_flag.is_set():
                return None
            time.sleep(chunk_time)

        return f"Task 1 completed in: {sleep_time:.2f} seconds"

    def task2(self, stop_flag):
        """Second task function, can be modified as needed"""
        # Simulate random work time
        sleep_time = random.uniform(0.1, 2.0)

        # Break the sleep into smaller chunks to check stop flag
        chunks = 20
        chunk_time = sleep_time / chunks

        for _ in range(chunks):
            # Check if we should stop
            if stop_flag.is_set():
                return None
            time.sleep(chunk_time)

        return f"Task 2 completed in: {sleep_time:.2f} seconds"

    def worker(self, task_func, task_name):
        """Worker thread function"""
        # Create a stop flag for this task
        stop_flag = threading.Event()
        self.stop_flags[task_name] = stop_flag

        try:
            # Execute the task with stop flag
            result = task_func(stop_flag)

            # If the race is already finished or we were asked to stop, return immediately
            if self.race_finished.is_set() or stop_flag.is_set():
                return None

            # Try to set the result (if no other thread has set it yet)
            with self.lock:
                if not self.race_finished.is_set():
                    self.result = (task_name, result)
                    # Mark the race as finished
                    self.race_finished.set()
                    print(f"{task_name} won the race!")

                    # Signal other threads to stop
                    for name, flag in self.stop_flags.items():
                        if name != task_name:
                            print(f"Signaling {name} to stop")
                            flag.set()

                    return self.result

        except Exception as e:
            print(f"{task_name} encountered an error: {e}")

        return None

    def run_race(self):
        """Start the competition and return the result of the fastest thread"""
        # Reset state
        self.race_finished.clear()
        self.result = None
        self.threads.clear()
        self.stop_flags.clear()

        # Create threads
        thread1 = threading.Thread(target=self.worker, args=(self.task1, "Thread 1"))
        thread2 = threading.Thread(target=self.worker, args=(self.task2, "Thread 2"))

        # Record thread objects for later joining
        self.threads["Thread 1"] = thread1
        self.threads["Thread 2"] = thread2

        # Start threads
        thread1.start()
        thread2.start()

        # Wait for any thread to complete
        while not self.race_finished.is_set():
            time.sleep(0.01)  # Small delay to avoid high CPU usage

            # If all threads have ended but no result is set, there's a problem
            if (
                not thread1.is_alive()
                and not thread2.is_alive()
                and not self.race_finished.is_set()
            ):
                print("All threads have ended, but there's no winner")
                return None

        # Wait for all threads to end (with timeout to avoid infinite waiting)
        thread1.join(timeout=1.0)
        thread2.join(timeout=1.0)

        # Return the result
        return self.result


# Usage example
if __name__ == "__main__":
    race = ThreadRace()
    result = race.run_race()
    print(f"Winner: {result[0] if result else None}")
    print(f"Result: {result[1] if result else None}")

evaluation/scripts/temporal_locomo/temporal_locomo_eval.py

Lines changed: 29 additions & 7 deletions
@@ -33,7 +33,7 @@ def __init__(self, args):
         self.locomo_evaluator = LocomoEvaluator(args=args)
         self.locomo_metric = LocomoMetric(args=args)

-    def run_eval_pipeline(self, skip_ingestion=True, skip_processing=False):
+    def run_answer_hit_eval_pipeline(self, skip_ingestion=True, skip_processing=False):
         """
         Run the complete evaluation pipeline including dataset conversion,
         data ingestion, and processing.
@@ -99,6 +99,32 @@ def run_eval_pipeline(self, skip_ingestion=True, skip_processing=False):
         print(f" - Statistics: {self.stats_path}")
         print("=" * 80)

+    def run_inference_eval_pipeline(self, skip_ingestion=True, skip_processing=False):
+        """
+        Run the complete evaluation pipeline including dataset conversion,
+        data ingestion, and processing.
+        """
+        print("=" * 80)
+        print("Starting TimeLocomo Evaluation Pipeline")
+        print("=" * 80)
+
+        # Step 1: Check if temporal_locomo dataset exists, if not convert it
+        temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json"
+        if not temporal_locomo_file.exists():
+            print(f"Temporal locomo dataset not found at {temporal_locomo_file}")
+            print("Converting locomo dataset to temporal_locomo format...")
+            self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo")
+            print("Dataset conversion completed.")
+        else:
+            print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.")
+
+        # Step 2: Data ingestion
+        if not skip_ingestion:
+            print("\n" + "=" * 50)
+            print("Step 2: Data Ingestion")
+            print("=" * 50)
+            self.locomo_ingestor.run_ingestion()
+
     def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider):
         """
         Compute can-answer statistics per day for each conversation using the
@@ -120,7 +146,7 @@ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider):
     parser.add_argument(
         "--frame",
         type=str,
-        default="memos_scheduler",
+        default="memos",
         choices=["zep", "memos", "mem0", "mem0_graph", "memos_scheduler"],
         help="Specify the memory framework (zep or memos or mem0 or mem0_graph)",
     )
@@ -152,8 +178,4 @@ def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider):
     args = parser.parse_args()

     evaluator = TemporalLocomoEval(args=args)
-    evaluator.run_eval_pipeline()
-
-    # rule-based baselines
-    evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=float("inf"))
-    evaluator.compute_can_answer_count_by_pre_evidences(rounds_to_consider=1)
+    evaluator.run_answer_hit_eval_pipeline()

src/memos/configs/mem_scheduler.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
     BASE_DIR,
     DEFAULT_ACT_MEM_DUMP_PATH,
     DEFAULT_CONSUME_INTERVAL_SECONDS,
-    DEFAULT_THREAD__POOL_MAX_WORKERS,
+    DEFAULT_THREAD_POOL_MAX_WORKERS,
 )


@@ -25,10 +25,10 @@ class BaseSchedulerConfig(BaseConfig):
         default=True, description="Whether to enable parallel message processing using thread pool"
     )
     thread_pool_max_workers: int = Field(
-        default=DEFAULT_THREAD__POOL_MAX_WORKERS,
+        default=DEFAULT_THREAD_POOL_MAX_WORKERS,
         gt=1,
         lt=20,
-        description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD__POOL_MAX_WORKERS})",
+        description=f"Maximum worker threads in pool (default: {DEFAULT_THREAD_POOL_MAX_WORKERS})",
     )
     consume_interval_seconds: float = Field(
         default=DEFAULT_CONSUME_INTERVAL_SECONDS,

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from memos.mem_scheduler.schemas.general_schemas import (
     DEFAULT_ACT_MEM_DUMP_PATH,
     DEFAULT_CONSUME_INTERVAL_SECONDS,
-    DEFAULT_THREAD__POOL_MAX_WORKERS,
+    DEFAULT_THREAD_POOL_MAX_WORKERS,
     MemCubeID,
     TreeTextMemory_SEARCH_METHOD,
     UserID,
@@ -60,7 +60,7 @@ def __init__(self, config: BaseSchedulerConfig):
         self.search_method = TreeTextMemory_SEARCH_METHOD
         self.enable_parallel_dispatch = self.config.get("enable_parallel_dispatch", False)
         self.thread_pool_max_workers = self.config.get(
-            "thread_pool_max_workers", DEFAULT_THREAD__POOL_MAX_WORKERS
+            "thread_pool_max_workers", DEFAULT_THREAD_POOL_MAX_WORKERS
         )

         self.retriever: SchedulerRetriever | None = None

src/memos/mem_scheduler/general_modules/dispatcher.py

Lines changed: 23 additions & 0 deletions
@@ -1,11 +1,14 @@
 import concurrent
+import threading

 from collections import defaultdict
 from collections.abc import Callable
+from typing import Any

 from memos.context.context import ContextThreadPoolExecutor
 from memos.log import get_logger
 from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
+from memos.mem_scheduler.general_modules.task_threads import ThreadRace
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem


@@ -22,6 +25,7 @@ class SchedulerDispatcher(BaseSchedulerModule):
     - Batch message processing
     - Graceful shutdown
     - Bulk handler registration
+    - Thread race competition for parallel task execution
     """

     def __init__(self, max_workers=30, enable_parallel_dispatch=False):
@@ -49,6 +53,9 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=False):
         # Set to track active futures for monitoring purposes
         self._futures = set()

+        # Thread race module for competitive task execution
+        self.thread_race = ThreadRace()
+
     def register_handler(self, label: str, handler: Callable[[list[ScheduleMessageItem]], None]):
         """
         Register a handler function for a specific message label.
@@ -177,6 +184,22 @@ def join(self, timeout: float | None = None) -> bool:

         return len(not_done) == 0

+    def run_competitive_tasks(
+        self, tasks: dict[str, Callable[[threading.Event], Any]], timeout: float = 10.0
+    ) -> tuple[str, Any] | None:
+        """
+        Run multiple tasks in a competitive race, returning the result of the first task to complete.
+
+        Args:
+            tasks: Dictionary mapping task names to task functions that accept a stop_flag parameter
+            timeout: Maximum time to wait for any task to complete (in seconds)
+
+        Returns:
+            Tuple of (task_name, result) from the winning task, or None if no task completes
+        """
+        logger.info(f"Starting competitive execution of {len(tasks)} tasks")
+        return self.thread_race.run_race(tasks, timeout)
+
     def shutdown(self) -> None:
         """Gracefully shutdown the dispatcher."""
         self._running = False
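The commit message also mentions a new test case for the scheduler dispatcher; that test file is not shown in this excerpt. Going only by the run_competitive_tasks signature and docstring above, a minimal sketch of how the race API might be exercised (the task functions, timings, and expected winner are illustrative assumptions, not code from the commit):

import threading
import time

from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher


def fast_task(stop_flag: threading.Event):
    # Finishes quickly, so it is expected to win the race.
    time.sleep(0.05)
    return "fast result"


def slow_task(stop_flag: threading.Event):
    # Works in small chunks and honors the stop flag once a winner is declared.
    for _ in range(100):
        if stop_flag.is_set():
            return None
        time.sleep(0.05)
    return "slow result"


dispatcher = SchedulerDispatcher(max_workers=4, enable_parallel_dispatch=True)
winner = dispatcher.run_competitive_tasks({"fast": fast_task, "slow": slow_task}, timeout=5.0)
# Per the docstring: a (task_name, result) tuple from the winner, or None if nothing completes.
print(winner)  # expected: ("fast", "fast result")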
