Skip to content

Commit 491a207

Browse files
feat(scheduler): Implement conditional cloud status updates
Adds functionality to send task status updates (success/failure) to RabbitMQ via `submit_web_logs`, specifically for the cloud service platform. This includes: - Adding a `status` field to `ScheduleLogForWebItem`. - Passing `submit_web_logs` to `SchedulerDispatcher`. - Modifying `wrapped_handler` to conditionally send a `ScheduleLogForWebItem` with `status="completed"` or `status="failed"` (along with `log_content` and, on failure, `exception`) based on the `MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME` environment variable.
1 parent ffb64b7 commit 491a207

File tree

7 files changed

+91
-34
lines changed

7 files changed

+91
-34
lines changed

poetry.lock

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ dependencies = [
4646
"scikit-learn (>=1.7.0,<2.0.0)", # Machine learning
4747
"fastmcp (>=2.10.5,<3.0.0)",
4848
"python-dateutil (>=2.9.0.post0,<3.0.0)",
49-
"prometheus-client (>=0.23.1,<0.24.0)",
5049
]
5150

5251
[project.urls]
@@ -76,6 +75,7 @@ tree-mem = [
7675
mem-scheduler = [
7776
"redis (>=6.2.0,<7.0.0)", # Key-value store
7877
"pika (>=1.3.2,<2.0.0)", # RabbitMQ client
78+
"prometheus-client (>=0.23.1,<0.24.0)", # For metrics
7979
]
8080

8181
# MemUser (MySQL support)

src/memos/api/handlers/scheduler_handler.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@
1414
from fastapi import HTTPException
1515
from fastapi.responses import StreamingResponse
1616

17-
from memos.log import get_logger
18-
1917
# Imports for new implementation
20-
from memos.api.product_models import StatusRequest, StatusResponse, StatusResponseItem
18+
from memos.api.product_models import StatusResponse, StatusResponseItem
19+
from memos.log import get_logger
2120
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
2221

2322

2423
logger = get_logger(__name__)
2524

2625

27-
def handle_scheduler_status(user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None) -> StatusResponse:
26+
def handle_scheduler_status(
27+
user_id: str, status_tracker: TaskStatusTracker, task_id: str | None = None
28+
) -> StatusResponse:
2829
"""
2930
Get scheduler running status for one or all tasks of a user.
3031
@@ -41,20 +42,24 @@ def handle_scheduler_status(user_id: str, status_tracker: TaskStatusTracker, tas
4142
Raises:
4243
HTTPException: If a specific task is not found.
4344
"""
44-
request = StatusRequest(user_id=user_id, task_id=task_id) # Construct StatusRequest internally
4545
response_data: list[StatusResponseItem] = []
4646

4747
try:
4848
if task_id:
4949
task_data = status_tracker.get_task_status(task_id, user_id)
5050
if not task_data:
51-
raise HTTPException(status_code=404, detail=f"Task {task_id} not found for user {user_id}")
51+
raise HTTPException(
52+
status_code=404, detail=f"Task {task_id} not found for user {user_id}"
53+
)
5254
response_data.append(StatusResponseItem(task_id=task_id, status=task_data["status"]))
5355
else:
5456
all_tasks = status_tracker.get_all_tasks_for_user(user_id)
5557
# The plan returns an empty list, which is good.
5658
# No need to check "if not all_tasks" explicitly before the list comprehension
57-
response_data = [StatusResponseItem(task_id=tid, status=t_data["status"]) for tid, t_data in all_tasks.items()]
59+
response_data = [
60+
StatusResponseItem(task_id=tid, status=t_data["status"])
61+
for tid, t_data in all_tasks.items()
62+
]
5863

5964
return StatusResponse(data=response_data)
6065
except HTTPException:
@@ -94,21 +99,19 @@ def handle_scheduler_wait(
9499
while time.time() - start_time < timeout_seconds:
95100
# Directly call the new, reliable status logic
96101
status_response = handle_scheduler_status(
97-
user_id=user_name,
98-
status_tracker=status_tracker
102+
user_id=user_name, status_tracker=status_tracker
99103
)
100104

101105
# System is idle if the data list is empty or no tasks are active
102106
is_idle = not status_response.data or all(
103-
task.status in ["completed", "failed", "cancelled"]
104-
for task in status_response.data
107+
task.status in ["completed", "failed", "cancelled"] for task in status_response.data
105108
)
106109

107110
if is_idle:
108111
return {
109112
"message": "idle",
110113
"data": {
111-
"running_tasks": 0, # Kept for compatibility
114+
"running_tasks": 0, # Kept for compatibility
112115
"waited_seconds": round(time.time() - start_time, 3),
113116
"timed_out": False,
114117
"user_name": user_name,
@@ -124,7 +127,7 @@ def handle_scheduler_wait(
124127
return {
125128
"message": "timeout",
126129
"data": {
127-
"running_tasks": len(active_tasks), # A more accurate count of active tasks
130+
"running_tasks": len(active_tasks), # A more accurate count of active tasks
128131
"waited_seconds": round(time.time() - start_time, 3),
129132
"timed_out": True,
130133
"user_name": user_name,
@@ -134,7 +137,9 @@ def handle_scheduler_wait(
134137
# Re-raise HTTPException directly to preserve its status code
135138
raise
136139
except Exception as err:
137-
logger.error(f"Failed while waiting for scheduler for user {user_name}: {traceback.format_exc()}")
140+
logger.error(
141+
f"Failed while waiting for scheduler for user {user_name}: {traceback.format_exc()}"
142+
)
138143
raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err
139144

140145

@@ -169,8 +174,12 @@ def event_generator():
169174
elapsed = time.time() - start_time
170175
if elapsed > timeout_seconds:
171176
# Send timeout message and break
172-
final_status = handle_scheduler_status(user_id=user_name, status_tracker=status_tracker)
173-
active_tasks = [t for t in final_status.data if t.status in ["waiting", "in_progress"]]
177+
final_status = handle_scheduler_status(
178+
user_id=user_name, status_tracker=status_tracker
179+
)
180+
active_tasks = [
181+
t for t in final_status.data if t.status in ["waiting", "in_progress"]
182+
]
174183
payload = {
175184
"user_name": user_name,
176185
"active_tasks": len(active_tasks),
@@ -184,10 +193,11 @@ def event_generator():
184193

185194
# Get status
186195
status_response = handle_scheduler_status(
187-
user_id=user_name,
188-
status_tracker=status_tracker
196+
user_id=user_name, status_tracker=status_tracker
189197
)
190-
active_tasks = [t for t in status_response.data if t.status in ["waiting", "in_progress"]]
198+
active_tasks = [
199+
t for t in status_response.data if t.status in ["waiting", "in_progress"]
200+
]
191201
num_active = len(active_tasks)
192202

193203
payload = {
@@ -200,7 +210,7 @@ def event_generator():
200210
yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
201211

202212
if num_active == 0:
203-
break # Exit loop if idle
213+
break # Exit loop if idle
204214

205215
time.sleep(poll_interval)
206216

@@ -215,4 +225,3 @@ def event_generator():
215225
yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"
216226

217227
return StreamingResponse(event_generator(), media_type="text/event-stream")
218-

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def __init__(self, config: BaseSchedulerConfig):
141141
enable_parallel_dispatch=self.enable_parallel_dispatch,
142142
status_tracker=self.status_tracker,
143143
metrics=self.metrics,
144+
submit_web_logs=self._submit_web_logs,
144145
)
145146

146147
# other attributes

src/memos/mem_scheduler/schemas/message_schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class ScheduleLogForWebItem(BaseModel, DictConversionMixin):
141141
)
142142
memcube_name: str | None = Field(default=None, description="Display name for memcube")
143143
memory_len: int | None = Field(default=None, description="Count of items involved in the event")
144+
status: str | None = Field(default=None, description="Completion status of the task (e.g., 'completed', 'failed')")
144145

145146
def debug_info(self) -> dict[str, Any]:
146147
"""Return structured debug information for logging purposes."""

src/memos/mem_scheduler/task_schedule_modules/dispatcher.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
1313
from memos.mem_scheduler.general_modules.task_threads import ThreadManager
1414
from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT
15-
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
15+
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem, ScheduleLogForWebItem
1616
from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
1717
from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
1818
from memos.mem_scheduler.utils import metrics
@@ -44,6 +44,7 @@ def __init__(
4444
config=None,
4545
status_tracker: TaskStatusTracker | None = None,
4646
metrics: Any | None = None,
47+
submit_web_logs: Callable | None = None, # ADDED
4748
):
4849
super().__init__()
4950
self.config = config
@@ -95,6 +96,7 @@ def __init__(
9596

9697
self.metrics = metrics
9798
self.status_tracker = status_tracker
99+
self.submit_web_logs = submit_web_logs # ADDED
98100

99101
def on_messages_enqueued(self, msgs: list[ScheduleMessageItem]) -> None:
100102
if not msgs:
@@ -153,6 +155,18 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
153155
self.status_tracker.task_completed(task_id=task_item.item_id, user_id=task_item.user_id)
154156
self.metrics.task_completed(user_id=m.user_id, task_type=m.label)
155157

158+
is_cloud_env = os.getenv('MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME') == 'memos-memory-change'
159+
if self.submit_web_logs and is_cloud_env:
160+
status_log = ScheduleLogForWebItem(
161+
user_id=task_item.user_id,
162+
mem_cube_id=task_item.mem_cube_id,
163+
item_id=task_item.item_id,
164+
label=m.label,
165+
log_content=f"Task {task_item.item_id} completed successfully for user {task_item.user_id}.",
166+
status="completed"
167+
)
168+
self.submit_web_logs([status_log])
169+
156170
# acknowledge redis messages
157171
if self.use_redis_queue and self.memos_message_queue is not None:
158172
for msg in messages:
@@ -188,6 +202,19 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
188202
if len(self._completed_tasks) > self.completed_tasks_max_show_size:
189203
self._completed_tasks.pop(0)
190204
logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
205+
206+
is_cloud_env = os.getenv('MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME') == 'memos-memory-change'
207+
if self.submit_web_logs and is_cloud_env:
208+
status_log = ScheduleLogForWebItem(
209+
user_id=task_item.user_id,
210+
mem_cube_id=task_item.mem_cube_id,
211+
item_id=task_item.item_id,
212+
label=m.label,
213+
log_content=f"Task {task_item.item_id} failed for user {task_item.user_id} with error: {str(e)}.",
214+
status="failed",
215+
exception=str(e)
216+
)
217+
self.submit_web_logs([status_log])
191218
raise
192219

193220
return wrapped_handler

src/memos/mem_scheduler/utils/status_tracker.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
11
# src/memos/mem_scheduler/utils/status_tracker.py
22
import json
3-
from datetime import datetime, timezone
3+
4+
from datetime import datetime, timedelta
5+
from typing import TYPE_CHECKING
46

57
from memos.dependency import require_python_package
68

79

10+
if TYPE_CHECKING:
11+
import redis
12+
13+
814
class TaskStatusTracker:
915
@require_python_package(import_name="redis", install_command="pip install redis")
1016
def __init__(self, redis_client: "redis.Redis"):
@@ -22,42 +28,54 @@ def task_submitted(self, task_id: str, user_id: str, task_type: str, mem_cube_id
2228
"submitted_at": datetime.now(timezone.utc).isoformat(),
2329
}
2430
self.redis.hset(key, task_id, json.dumps(payload))
31+
self.redis.expire(key, timedelta(days=7))
2532

2633
def task_started(self, task_id: str, user_id: str):
2734
key = self._get_key(user_id)
2835
existing_data_json = self.redis.hget(key, task_id)
2936
if not existing_data_json:
30-
# 容错处理:如果任务不存在,也创建一个
31-
payload = {"status": "in_progress", "started_at": datetime.now(timezone.utc).isoformat()}
37+
# 容错处理: 如果任务不存在, 也创建一个
38+
payload = {
39+
"status": "in_progress",
40+
"started_at": datetime.now(timezone.utc).isoformat(),
41+
}
3242
else:
3343
payload = json.loads(existing_data_json)
3444
payload["status"] = "in_progress"
3545
payload["started_at"] = datetime.now(timezone.utc).isoformat()
3646
self.redis.hset(key, task_id, json.dumps(payload))
47+
self.redis.expire(key, timedelta(days=7))
3748

3849
def task_completed(self, task_id: str, user_id: str):
3950
key = self._get_key(user_id)
4051
existing_data_json = self.redis.hget(key, task_id)
41-
if not existing_data_json: return
52+
if not existing_data_json:
53+
return
4254
payload = json.loads(existing_data_json)
4355
payload["status"] = "completed"
4456
payload["completed_at"] = datetime.now(timezone.utc).isoformat()
45-
# 设置该任务条目的过期时间例如 24 小时
46-
# 注意Redis Hash 不能为单个 field 设置 TTL这里我们可以 通过后台任务清理或在获取时判断时间戳
47-
# 简单起见我们暂时依赖一个后台清理任务
57+
# 设置该任务条目的过期时间, 例如 24 小时
58+
# 注意: Redis Hash 不能为单个 field 设置 TTL, 这里我们可以 通过后台任务清理或在获取时判断时间戳
59+
# 简单起见, 我们暂时依赖一个后台清理任务
4860
self.redis.hset(key, task_id, json.dumps(payload))
61+
self.redis.expire(key, timedelta(days=7))
4962

5063
def task_failed(self, task_id: str, user_id: str, error_message: str):
5164
key = self._get_key(user_id)
5265
existing_data_json = self.redis.hget(key, task_id)
5366
if not existing_data_json:
54-
payload = {"status": "failed", "error": error_message, "failed_at": datetime.now(timezone.utc).isoformat()}
67+
payload = {
68+
"status": "failed",
69+
"error": error_message,
70+
"failed_at": datetime.now(timezone.utc).isoformat(),
71+
}
5572
else:
5673
payload = json.loads(existing_data_json)
5774
payload["status"] = "failed"
5875
payload["error"] = error_message
5976
payload["failed_at"] = datetime.now(timezone.utc).isoformat()
6077
self.redis.hset(key, task_id, json.dumps(payload))
78+
self.redis.expire(key, timedelta(days=7))
6179

6280
def get_task_status(self, task_id: str, user_id: str) -> dict | None:
6381
key = self._get_key(user_id)

0 commit comments

Comments
 (0)