Skip to content

Commit 668a8d7

Browse files
committed
feat: update add context
1 parent 81ae2ba commit 668a8d7

File tree

5 files changed

+270
-7
lines changed

5 files changed

+270
-7
lines changed

src/memos/api/context/context.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""
2+
Global request context management for trace_id and request-scoped data.
3+
4+
This module provides optional trace_id functionality that can be enabled
5+
when using the API components. It uses ContextVar to ensure thread safety
6+
and request isolation.
7+
"""
8+
9+
import uuid
10+
11+
from collections.abc import Callable
12+
from contextvars import ContextVar
13+
from typing import Any
14+
15+
16+
# Global context variable for request-scoped data
17+
# Stores the active request's context as a plain dict (the shape produced by
# RequestContext.to_dict()); stays None until set_request_context() is called.
_request_context: ContextVar[dict[str, Any] | None] = ContextVar("request_context", default=None)
18+
19+
20+
class RequestContext:
21+
"""
22+
Request-scoped context object that holds trace_id and other request data.
23+
24+
This provides a Flask g-like object for FastAPI applications.
25+
"""
26+
27+
def __init__(self, trace_id: str | None = None):
28+
self.trace_id = trace_id or str(uuid.uuid4())
29+
self._data: dict[str, Any] = {}
30+
31+
def set(self, key: str, value: Any) -> None:
32+
"""Set a value in the context."""
33+
self._data[key] = value
34+
35+
def get(self, key: str, default: Any | None = None) -> Any:
36+
"""Get a value from the context."""
37+
return self._data.get(key, default)
38+
39+
def __setattr__(self, name: str, value: Any) -> None:
40+
if name.startswith("_") or name == "trace_id":
41+
super().__setattr__(name, value)
42+
else:
43+
if not hasattr(self, "_data"):
44+
super().__setattr__(name, value)
45+
else:
46+
self._data[name] = value
47+
48+
def __getattr__(self, name: str) -> Any:
49+
if hasattr(self, "_data") and name in self._data:
50+
return self._data[name]
51+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
52+
53+
def to_dict(self) -> dict[str, Any]:
54+
"""Convert context to dictionary."""
55+
return {"trace_id": self.trace_id, "data": self._data.copy()}
56+
57+
58+
def set_request_context(context: RequestContext) -> None:
    """
    Register ``context`` as the current request context.

    What gets stored is the dictionary returned by ``context.to_dict()``,
    i.e. a snapshot taken at call time; the RequestContext object itself is
    not kept. This is typically called by the API dependency injection system.
    """
    snapshot = context.to_dict()
    _request_context.set(snapshot)
65+
66+
67+
def get_current_trace_id() -> str | None:
    """
    Look up the trace_id of the request currently being handled.

    Returns:
        The stored trace_id, or None when no (non-empty) request context
        has been set.
    """
    stored = _request_context.get()
    return stored.get("trace_id") if stored else None
78+
79+
80+
def get_current_context() -> RequestContext | None:
    """
    Rebuild a RequestContext from the stored request context dictionary.

    Returns:
        A new RequestContext carrying the stored trace_id and a copy of the
        stored data, or None when nothing has been stored.
    """
    snapshot = _request_context.get()
    if not snapshot:
        return None

    rebuilt = RequestContext(trace_id=snapshot.get("trace_id"))
    # Copy so that changes to the returned object do not alter the stored dict.
    rebuilt._data = snapshot.get("data", {}).copy()
    return rebuilt
93+
94+
95+
def require_context() -> RequestContext:
    """
    Return the current request context or fail loudly.

    Returns:
        The current RequestContext.

    Raises:
        RuntimeError: If no request context is available.
    """
    current = get_current_context()
    if current is not None:
        return current
    raise RuntimeError(
        "No request context available. This function must be called within a request handler."
    )
111+
112+
113+
# Type for trace_id getter function
114+
# A zero-argument callable that returns a trace_id string, or None when there is none.
TraceIdGetter = Callable[[], str | None]
115+
116+
# Global variable to hold the trace_id getter function
117+
# Replaced via set_trace_id_getter(); consulted first by get_trace_id_for_logging().
_trace_id_getter: TraceIdGetter | None = None
118+
119+
120+
def set_trace_id_getter(getter: TraceIdGetter) -> None:
    """
    Install ``getter`` as the module-wide trace_id getter.

    The getter is what get_trace_id_for_logging() calls first, which lets
    the logging system obtain a trace_id without importing API-specific
    modules.
    """
    global _trace_id_getter

    _trace_id_getter = getter
129+
130+
131+
def get_trace_id_for_logging() -> str | None:
    """
    Resolve the trace_id to attach to log output.

    The installed getter is tried first. If none is installed, or the getter
    raises, the value from the default request context is returned instead.
    """
    getter = _trace_id_getter
    if getter:
        try:
            return getter()
        except Exception:
            # Best effort: a failing custom getter must not break logging,
            # so fall through to the default lookup below.
            pass
    return get_current_trace_id()
144+
145+
146+
# Initialize the default trace_id getter
147+
# Runs at import time, so get_trace_id_for_logging() starts out reading from the request context.
set_trace_id_getter(get_current_trace_id)
src/memos/api/context/dependencies.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import logging
2+
3+
from fastapi import Depends, Header, Request
4+
5+
from memos.api.context.context import RequestContext, set_request_context
6+
7+
8+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
9+
10+
# Type alias for the RequestContext from context module
11+
# Short name for use in route signatures, e.g. Annotated[G, Depends(get_g_object)].
G = RequestContext
12+
13+
14+
def get_trace_id_from_header(
    trace_id: str | None = Header(None, alias="trace-id"),
    x_trace_id: str | None = Header(None, alias="x-trace-id"),
    g_trace_id: str | None = Header(None, alias="g-trace-id"),
) -> str | None:
    """
    Pick the trace_id out of the supported request headers.

    Priority: g-trace-id > x-trace-id > trace-id
    """
    # The two higher-priority headers win as soon as one is non-empty.
    for preferred in (g_trace_id, x_trace_id):
        if preferred:
            return preferred
    # Lowest priority: returned as-is, even when it is missing or empty.
    return trace_id
25+
26+
27+
def get_request_context(
    request: Request, trace_id: str | None = Depends(get_trace_id_from_header)
) -> RequestContext:
    """
    Get request context object with trace_id and request metadata.

    This function creates a RequestContext, fills in the request's method,
    path and client IP, and then sets it in the global context for use
    throughout the request lifecycle.

    Args:
        request: The incoming request.
        trace_id: Trace id resolved from the request headers, if any.

    Returns:
        The RequestContext created for this request.
    """
    # Create context object
    ctx = RequestContext(trace_id=trace_id)

    # Add request metadata BEFORE registering the context:
    # set_request_context() stores the result of ctx.to_dict(), which copies
    # the data, so anything added after that call would be missing from the
    # globally visible context.
    ctx.set("method", request.method)
    ctx.set("path", request.url.path)
    ctx.set("client_ip", request.client.host if request.client else None)

    # Set the context globally for this request
    set_request_context(ctx)

    # Log request start
    logger.info(f"Request started with trace_id: {ctx.trace_id}")

    return ctx
51+
52+
53+
def get_g_object(trace_id: str | None = Depends(get_trace_id_from_header)) -> G:
    """
    Build the Flask g-like object for the current request.

    A fresh RequestContext is created with the resolved trace_id, registered
    through set_request_context(), logged, and returned to the caller.

    Note: set_request_context() stores a copy of the context's data, so
    attributes assigned on the returned object afterwards are not part of
    what get_current_context() later returns.
    """
    g = RequestContext(trace_id=trace_id)

    # Register first, then announce the new object.
    set_request_context(g)
    logger.info(f"Request g object created with trace_id: {g.trace_id}")

    return g
64+
65+
66+
def get_current_g() -> G | None:
    """
    Get the current request's g object from anywhere in the application.

    Returns:
        The current request's g object if available, None otherwise.
    """
    # Import from the module that defines get_current_context, the same
    # module this file already imports RequestContext from.
    from memos.api.context.context import get_current_context

    return get_current_context()
76+
77+
78+
def require_g() -> G:
    """
    Get the current request's g object, raising an error if not available.

    Returns:
        The current request's g object.

    Raises:
        RuntimeError: If called outside of a request context.
    """
    # Import from the module that defines require_context, the same module
    # this file already imports RequestContext from.
    from memos.api.context.context import require_context

    return require_context()

src/memos/api/routers/product_router.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
import logging
33
import traceback
44

5-
from fastapi import APIRouter, HTTPException
5+
from datetime import datetime
6+
from typing import Annotated
7+
8+
from fastapi import APIRouter, Depends, HTTPException
69
from fastapi.responses import StreamingResponse
710

811
from memos.api.config import APIConfig
12+
from memos.api.context.dependencies import G, get_g_object
913
from memos.api.product_models import (
1014
BaseResponse,
1115
ChatRequest,
@@ -64,9 +68,18 @@ async def set_config(config):
6468

6569

6670
@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse)
67-
async def register_user(user_req: UserRegisterRequest):
71+
async def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_object)]):
6872
"""Register a new user with configuration and default cube."""
6973
try:
74+
# Set request-related information in g object
75+
g.user_id = user_req.user_id
76+
g.action = "user_register"
77+
g.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
78+
79+
logger.info(f"Starting user registration for user_id: {user_req.user_id}")
80+
logger.info(f"Request trace_id: {g.trace_id}")
81+
logger.info(f"Request timestamp: {g.timestamp}")
82+
7083
# Get configuration for the user
7184
user_config, default_mem_cube = APIConfig.create_user_config(
7285
user_name=user_req.user_id, user_id=user_req.user_id

src/memos/mem_os/product.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,14 @@ def _send_message_to_scheduler(
498498
)
499499
self.mem_scheduler.submit_messages(messages=[message_item])
500500

501+
def _filter_memories_by_threshold(
    self, memories: list[TextualMemoryItem], threshold: float = 0.20
) -> list[TextualMemoryItem]:
    """
    Keep only the memories whose ``metadata.relativity`` reaches the threshold.

    Args:
        memories: Candidate memory items, in the order they should be kept.
        threshold: Minimum ``metadata.relativity`` (inclusive) for an item
            to be retained.

    Returns:
        A new list with the retained items in their original order.
    """
    retained = []
    for item in memories:
        # NOTE(review): assumes relativity is always a number; a None value
        # would raise TypeError on this comparison - confirm against callers.
        if item.metadata.relativity >= threshold:
            retained.append(item)
    return retained
508+
501509
def register_mem_cube(
502510
self,
503511
mem_cube_name_or_path_or_object: str | GeneralMemCube,
@@ -689,6 +697,7 @@ def chat_with_references(
689697
user_id: str,
690698
cube_id: str | None = None,
691699
history: MessageList | None = None,
700+
top_k: int = 10,
692701
) -> Generator[str, None, None]:
693702
"""
694703
Chat with LLM with memory references and streaming output.
@@ -710,10 +719,11 @@ def chat_with_references(
710719
time_start = time.time()
711720
memories_list = []
712721
memories_result = super().search(
713-
query, user_id, install_cube_ids=[cube_id] if cube_id else None, top_k=10
722+
query, user_id, install_cube_ids=[cube_id] if cube_id else None, top_k=top_k
714723
)["text_mem"]
715724
if memories_result:
716725
memories_list = memories_result[0]["memories"]
726+
memories_list = self._filter_memories_by_threshold(memories_list)
717727

718728
# Build custom system prompt with relevant memories
719729
system_prompt = self._build_system_prompt(user_id, memories_list)
@@ -766,7 +776,7 @@ def chat_with_references(
766776
# Initialize buffer for streaming
767777
buffer = ""
768778
full_response = ""
769-
779+
token_count = 0
770780
# Use tiktoken for proper token-based chunking
771781
if self.config.chat_model.backend not in ["huggingface", "vllm"]:
772782
# For non-huggingface backends, we need to collect the full response first
@@ -779,6 +789,7 @@ def chat_with_references(
779789
for chunk in response_stream:
780790
if chunk in ["<think>", "</think>"]:
781791
continue
792+
token_count += 1
782793
buffer += chunk
783794
full_response += chunk
784795

@@ -809,7 +820,8 @@ def chat_with_references(
809820

810821
yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
811822
total_time = round(float(time_end - time_start), 1)
812-
yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': '23%'}})}\n\n"
823+
speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1)
824+
yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n"
813825
logger.info(f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}")
814826
logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}")
815827
self._send_message_to_scheduler(
@@ -935,6 +947,7 @@ def get_subgraph(
935947
user_id: str,
936948
query: str,
937949
mem_cube_ids: list[str] | None = None,
950+
top_k: int = 20,
938951
) -> list[dict[str, Any]]:
939952
"""Get all memory items for a user.
940953
@@ -950,7 +963,7 @@ def get_subgraph(
950963
# Load user cubes if not already loaded
951964
self._load_user_cubes(user_id, self.default_cube_config)
952965
memory_list = self._get_subgraph(
953-
query=query, mem_cube_id=mem_cube_ids[0], user_id=user_id, top_k=20
966+
query=query, mem_cube_id=mem_cube_ids[0], user_id=user_id, top_k=top_k
954967
)["text_mem"]
955968
reformat_memory_list = []
956969
for memory in memory_list:

src/memos/mem_os/utils/format_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def convert_graph_to_tree_forworkmem(
533533
node_name = extract_node_name(memory)
534534
memory_key = node.get("metadata", {}).get("key", node_name)
535535
usage = node.get("metadata", {}).get("usage", [])
536-
frequency = len(usage)
536+
frequency = len(usage) if len(usage) < 100 else 100
537537
node_map[node["id"]] = {
538538
"id": node["id"],
539539
"value": memory,

0 commit comments

Comments
 (0)