|
11 | 11 | from sqlalchemy.engine import Engine |
12 | 12 |
|
13 | 13 | from memos.configs.mem_scheduler import AuthConfig, BaseSchedulerConfig |
14 | | -from memos.context.context import ContextThread |
| 14 | +from memos.context.context import ( |
| 15 | + ContextThread, |
| 16 | + RequestContext, |
| 17 | + get_current_context, |
| 18 | + set_request_context, |
| 19 | +) |
15 | 20 | from memos.llms.base import BaseLLM |
16 | 21 | from memos.log import get_logger |
17 | 22 | from memos.mem_cube.base import BaseMemCube |
@@ -775,35 +780,46 @@ def _message_consumer(self) -> None: |
775 | 780 | if messages: |
776 | 781 | now = time.time() |
777 | 782 | for msg in messages: |
778 | | - enqueue_ts_obj = getattr(msg, "timestamp", None) |
779 | | - enqueue_epoch = None |
780 | | - if isinstance(enqueue_ts_obj, int | float): |
781 | | - enqueue_epoch = float(enqueue_ts_obj) |
782 | | - elif hasattr(enqueue_ts_obj, "timestamp"): |
783 | | - dt = enqueue_ts_obj |
784 | | - if dt.tzinfo is None: |
785 | | - dt = dt.replace(tzinfo=timezone.utc) |
786 | | - enqueue_epoch = dt.timestamp() |
787 | | - |
788 | | - queue_wait_ms = None |
789 | | - if enqueue_epoch is not None: |
790 | | - queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 |
791 | | - |
792 | | - # Avoid pydantic field enforcement by using object.__setattr__ |
793 | | - object.__setattr__(msg, "_dequeue_ts", now) |
794 | | - emit_monitor_event( |
795 | | - "dequeue", |
796 | | - msg, |
797 | | - { |
798 | | - "enqueue_ts": to_iso(enqueue_ts_obj), |
799 | | - "dequeue_ts": datetime.fromtimestamp( |
800 | | - now, tz=timezone.utc |
801 | | - ).isoformat(), |
802 | | - "queue_wait_ms": queue_wait_ms, |
803 | | - }, |
804 | | - ) |
805 | | - |
806 | | - self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) |
| 783 | + prev_context = get_current_context() |
| 784 | + try: |
| 785 | + # Set context for this message |
| 786 | + msg_context = RequestContext( |
| 787 | + trace_id=msg.trace_id, |
| 788 | + user_name=msg.user_name, |
| 789 | + ) |
| 790 | + set_request_context(msg_context) |
| 791 | + |
| 792 | + enqueue_ts_obj = getattr(msg, "timestamp", None) |
| 793 | + enqueue_epoch = None |
| 794 | + if isinstance(enqueue_ts_obj, int | float): |
| 795 | + enqueue_epoch = float(enqueue_ts_obj) |
| 796 | + elif hasattr(enqueue_ts_obj, "timestamp"): |
| 797 | + dt = enqueue_ts_obj |
| 798 | + if dt.tzinfo is None: |
| 799 | + dt = dt.replace(tzinfo=timezone.utc) |
| 800 | + enqueue_epoch = dt.timestamp() |
| 801 | + |
| 802 | + queue_wait_ms = None |
| 803 | + if enqueue_epoch is not None: |
| 804 | + queue_wait_ms = max(0.0, now - enqueue_epoch) * 1000 |
| 805 | + |
| 806 | + # Avoid pydantic field enforcement by using object.__setattr__ |
| 807 | + object.__setattr__(msg, "_dequeue_ts", now) |
| 808 | + emit_monitor_event( |
| 809 | + "dequeue", |
| 810 | + msg, |
| 811 | + { |
| 812 | + "enqueue_ts": to_iso(enqueue_ts_obj), |
| 813 | + "dequeue_ts": datetime.fromtimestamp( |
| 814 | + now, tz=timezone.utc |
| 815 | + ).isoformat(), |
| 816 | + "queue_wait_ms": queue_wait_ms, |
| 817 | + }, |
| 818 | + ) |
| 819 | + self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label) |
| 820 | + finally: |
| 821 | + # Restore the prior context of the consumer thread |
| 822 | + set_request_context(prev_context) |
807 | 823 | try: |
808 | 824 | import contextlib |
809 | 825 |
|
|
0 commit comments