|
1 | 1 | import concurrent |
| 2 | +import os |
2 | 3 | import threading |
3 | 4 | import time |
4 | 5 |
|
|
19 | 20 | from memos.mem_scheduler.schemas.general_schemas import ( |
20 | 21 | DEFAULT_STOP_WAIT, |
21 | 22 | ) |
22 | | -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem |
| 23 | +from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem, ScheduleMessageItem |
23 | 24 | from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem |
24 | 25 | from memos.mem_scheduler.task_schedule_modules.orchestrator import SchedulerOrchestrator |
25 | 26 | from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue |
@@ -200,6 +201,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): |
200 | 201 | if self.status_tracker: |
201 | 202 | for msg in messages: |
202 | 203 | self.status_tracker.task_completed(task_id=msg.item_id, user_id=msg.user_id) |
| 204 | + self._maybe_emit_task_completion(messages) |
203 | 205 | self.metrics.task_completed(user_id=m.user_id, task_type=m.label) |
204 | 206 |
|
205 | 207 | emit_monitor_event( |
@@ -237,6 +239,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): |
237 | 239 | self.status_tracker.task_failed( |
238 | 240 | task_id=msg.item_id, user_id=msg.user_id, error_message=str(e) |
239 | 241 | ) |
| 242 | + self._maybe_emit_task_completion(messages, error=e) |
240 | 243 | emit_monitor_event( |
241 | 244 | "finish", |
242 | 245 | m, |
@@ -284,6 +287,85 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): |
284 | 287 |
|
285 | 288 | return wrapped_handler |
286 | 289 |
|
def _maybe_emit_task_completion(
    self, messages: list[ScheduleMessageItem], error: Exception | None = None
) -> None:
    """Emit a web log for each business task whose tracked status is terminal.

    For every distinct ``task_id`` found on ``messages``, the status tracker is
    queried via ``get_task_status_by_business_id``. A reported status of
    ``"completed"`` or ``"failed"`` produces one ``ScheduleLogForWebItem``
    passed to ``self.submit_web_logs``; any other status produces nothing.

    This is a best-effort helper: it never raises. Any failure while reading
    the messages, querying the tracker or submitting the log is logged as a
    warning and swallowed, so it cannot disturb the handler that calls it.

    Args:
        messages: The batch that was just processed. Items without a
            ``task_id`` attribute (or with ``task_id`` of ``None``) are ignored.
        error: Exception raised while handling this batch, if any. When given
            and non-empty, its text is used as the failure reason.

    NOTE(review): whether the tracker's "completed" means *all* items of the
    business task are done, and whether concurrent workers can each observe
    that state (emitting duplicates), depends on the status tracker
    implementation -- confirm there.
    """
    if not self.submit_web_logs or not self.status_tracker:
        return

    # Completion logs are only produced when this specific exchange is configured.
    # Checked before touching the messages so other deployments pay no cost.
    if os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME") != "memos-memory-change":
        return

    # Messages in one batch can belong to different business task_ids; check each.
    task_ids = {getattr(msg, "task_id", None) for msg in messages}
    task_ids.discard(None)
    if not task_ids:
        return

    # Pre-bind so the warning below can always report them, even if reading
    # them from the first message is what failed.
    user_id = None
    mem_cube_id = None

    try:
        # Shared fields are taken from the first message only; the batch is
        # assumed to share one user_id / mem_cube_id -- TODO confirm with caller.
        first = messages[0]
        user_id = first.user_id
        mem_cube_id = first.mem_cube_id

        for task_id in task_ids:
            status_data = self.status_tracker.get_task_status_by_business_id(
                business_task_id=task_id, user_id=user_id
            )
            if not status_data:
                continue

            status = status_data.get("status")

            if status == "completed":
                # The tracker is trusted here: a success log is emitted even if
                # `error` was passed for this batch.
                final_status = "completed"
                log_content = f"Task {task_id} completed"
            elif status == "failed":
                # Prefer the locally caught error, then the tracker's recorded
                # errors, then a generic placeholder.
                error_msg = str(error) if error is not None else ""
                if not error_msg:
                    errors = status_data.get("errors", [])
                    if errors:
                        # Stringify entries so non-str errors cannot break the join.
                        error_msg = "; ".join(str(item) for item in errors)
                    else:
                        error_msg = "Unknown error (check system logs)"
                final_status = "failed"
                log_content = f"Task {task_id} failed: {error_msg}"
            else:
                # Not in a terminal state (or unknown status): nothing to report.
                continue

            event = ScheduleLogForWebItem(
                task_id=task_id,
                user_id=user_id,
                mem_cube_id=mem_cube_id,
                label="taskStatus",
                from_memory_type="status",
                to_memory_type="status",
                log_content=log_content,
                status=final_status,
            )
            self.submit_web_logs(event)
    except Exception:
        logger.warning(
            "Failed to emit task completion log. user_id=%s mem_cube_id=%s task_ids=%s",
            user_id,
            mem_cube_id,
            list(task_ids),
            exc_info=True,
        )
| 368 | + |
287 | 369 | def get_running_tasks( |
288 | 370 | self, filter_func: Callable[[RunningTaskItem], bool] | None = None |
289 | 371 | ) -> dict[str, RunningTaskItem]: |
|
0 commit comments