Skip to content

Commit b946377

Browse files
authored
feat: add context for playground and for fastapi object (#151)
* feat: update config * fix:dim * change dim * fix:change default db * fix:delay * fix:len * fix:change recently mem size * fix:dup node error * fix: remove mock_data * fix: change config * feat: reorganize code * add: add json parse for en * fix:change user_id * fix: logger info * fix: remove unsed change * feat: add topk for api * feat: add logger * fix:fix scheduler logs and * add: fix max_user instances * fix: logger for config and qa * feat: update add context * fix:mv env to docker * fix: rm ref id for response for scheduler * add: status for product
1 parent 5e60d89 commit b946377

File tree

6 files changed

+312
-29
lines changed

6 files changed

+312
-29
lines changed
File renamed without changes.

src/memos/api/context/context.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""
2+
Global request context management for trace_id and request-scoped data.
3+
4+
This module provides optional trace_id functionality that can be enabled
5+
when using the API components. It uses ContextVar to ensure thread safety
6+
and request isolation.
7+
"""
8+
9+
import uuid
10+
11+
from collections.abc import Callable
12+
from contextvars import ContextVar
13+
from typing import Any
14+
15+
16+
# Global context variable for request-scoped data
17+
_request_context: ContextVar[dict[str, Any] | None] = ContextVar("request_context", default=None)
18+
19+
20+
class RequestContext:
21+
"""
22+
Request-scoped context object that holds trace_id and other request data.
23+
24+
This provides a Flask g-like object for FastAPI applications.
25+
"""
26+
27+
def __init__(self, trace_id: str | None = None):
28+
self.trace_id = trace_id or str(uuid.uuid4())
29+
self._data: dict[str, Any] = {}
30+
31+
def set(self, key: str, value: Any) -> None:
32+
"""Set a value in the context."""
33+
self._data[key] = value
34+
35+
def get(self, key: str, default: Any | None = None) -> Any:
36+
"""Get a value from the context."""
37+
return self._data.get(key, default)
38+
39+
def __setattr__(self, name: str, value: Any) -> None:
40+
if name.startswith("_") or name == "trace_id":
41+
super().__setattr__(name, value)
42+
else:
43+
if not hasattr(self, "_data"):
44+
super().__setattr__(name, value)
45+
else:
46+
self._data[name] = value
47+
48+
def __getattr__(self, name: str) -> Any:
49+
if hasattr(self, "_data") and name in self._data:
50+
return self._data[name]
51+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
52+
53+
def to_dict(self) -> dict[str, Any]:
54+
"""Convert context to dictionary."""
55+
return {"trace_id": self.trace_id, "data": self._data.copy()}
56+
57+
58+
def set_request_context(context: RequestContext) -> None:
59+
"""
60+
Set the current request context.
61+
62+
This is typically called by the API dependency injection system.
63+
"""
64+
_request_context.set(context.to_dict())
65+
66+
67+
def get_current_trace_id() -> str | None:
68+
"""
69+
Get the current request's trace_id.
70+
71+
Returns:
72+
The trace_id if available, None otherwise.
73+
"""
74+
context = _request_context.get()
75+
if context:
76+
return context.get("trace_id")
77+
return None
78+
79+
80+
def get_current_context() -> RequestContext | None:
81+
"""
82+
Get the current request context.
83+
84+
Returns:
85+
The current RequestContext if available, None otherwise.
86+
"""
87+
context_dict = _request_context.get()
88+
if context_dict:
89+
ctx = RequestContext(trace_id=context_dict.get("trace_id"))
90+
ctx._data = context_dict.get("data", {}).copy()
91+
return ctx
92+
return None
93+
94+
95+
def require_context() -> RequestContext:
96+
"""
97+
Get the current request context, raising an error if not available.
98+
99+
Returns:
100+
The current RequestContext.
101+
102+
Raises:
103+
RuntimeError: If called outside of a request context.
104+
"""
105+
context = get_current_context()
106+
if context is None:
107+
raise RuntimeError(
108+
"No request context available. This function must be called within a request handler."
109+
)
110+
return context
111+
112+
113+
# Type for trace_id getter function
114+
TraceIdGetter = Callable[[], str | None]
115+
116+
# Global variable to hold the trace_id getter function
117+
_trace_id_getter: TraceIdGetter | None = None
118+
119+
120+
def set_trace_id_getter(getter: TraceIdGetter) -> None:
121+
"""
122+
Set a custom trace_id getter function.
123+
124+
This allows the logging system to retrieve trace_id without importing
125+
API-specific modules.
126+
"""
127+
global _trace_id_getter
128+
_trace_id_getter = getter
129+
130+
131+
def get_trace_id_for_logging() -> str | None:
132+
"""
133+
Get trace_id for logging purposes.
134+
135+
This function is used by the logging system and will use either
136+
the custom getter function or fall back to the default context.
137+
"""
138+
if _trace_id_getter:
139+
try:
140+
return _trace_id_getter()
141+
except Exception:
142+
pass
143+
return get_current_trace_id()
144+
145+
146+
# Initialize the default trace_id getter
147+
set_trace_id_getter(get_current_trace_id)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import logging
2+
3+
from fastapi import Depends, Header, Request
4+
5+
from memos.api.context.context import RequestContext, set_request_context
6+
7+
8+
logger = logging.getLogger(__name__)
9+
10+
# Type alias for the RequestContext from context module
11+
G = RequestContext
12+
13+
14+
def get_trace_id_from_header(
15+
trace_id: str | None = Header(None, alias="trace-id"),
16+
x_trace_id: str | None = Header(None, alias="x-trace-id"),
17+
g_trace_id: str | None = Header(None, alias="g-trace-id"),
18+
) -> str | None:
19+
"""
20+
Extract trace_id from various possible headers.
21+
22+
Priority: g-trace-id > x-trace-id > trace-id
23+
"""
24+
return g_trace_id or x_trace_id or trace_id
25+
26+
27+
def get_request_context(
28+
request: Request, trace_id: str | None = Depends(get_trace_id_from_header)
29+
) -> RequestContext:
30+
"""
31+
Get request context object with trace_id and request metadata.
32+
33+
This function creates a RequestContext and automatically sets it
34+
in the global context for use throughout the request lifecycle.
35+
"""
36+
# Create context object
37+
ctx = RequestContext(trace_id=trace_id)
38+
39+
# Set the context globally for this request
40+
set_request_context(ctx)
41+
42+
# Log request start
43+
logger.info(f"Request started with trace_id: {ctx.trace_id}")
44+
45+
# Add request metadata to context
46+
ctx.set("method", request.method)
47+
ctx.set("path", request.url.path)
48+
ctx.set("client_ip", request.client.host if request.client else None)
49+
50+
return ctx
51+
52+
53+
def get_g_object(trace_id: str | None = Depends(get_trace_id_from_header)) -> G:
54+
"""
55+
Get Flask g-like object for the current request.
56+
57+
This creates a RequestContext and sets it globally for access
58+
throughout the request lifecycle.
59+
"""
60+
g = RequestContext(trace_id=trace_id)
61+
set_request_context(g)
62+
logger.info(f"Request g object created with trace_id: {g.trace_id}")
63+
return g
64+
65+
66+
def get_current_g() -> G | None:
67+
"""
68+
Get the current request's g object from anywhere in the application.
69+
70+
Returns:
71+
The current request's g object if available, None otherwise.
72+
"""
73+
from memos.context import get_current_context
74+
75+
return get_current_context()
76+
77+
78+
def require_g() -> G:
79+
"""
80+
Get the current request's g object, raising an error if not available.
81+
82+
Returns:
83+
The current request's g object.
84+
85+
Raises:
86+
RuntimeError: If called outside of a request context.
87+
"""
88+
from memos.context import require_context
89+
90+
return require_context()

src/memos/api/routers/product_router.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
import logging
33
import traceback
44

5-
from fastapi import APIRouter, HTTPException
5+
from datetime import datetime
6+
from typing import Annotated
7+
8+
from fastapi import APIRouter, Depends, HTTPException
69
from fastapi.responses import StreamingResponse
710

811
from memos.api.config import APIConfig
12+
from memos.api.context.dependencies import G, get_g_object
913
from memos.api.product_models import (
1014
BaseResponse,
1115
ChatRequest,
@@ -37,14 +41,14 @@ def get_mos_product_instance():
3741
global MOS_PRODUCT_INSTANCE
3842
if MOS_PRODUCT_INSTANCE is None:
3943
default_config = APIConfig.get_product_default_config()
40-
print(default_config)
44+
logger.info(f"*********init_default_mos_config********* {default_config}")
4145
from memos.configs.mem_os import MOSConfig
4246

4347
mos_config = MOSConfig(**default_config)
4448

4549
# Get default cube config from APIConfig (may be None if disabled)
4650
default_cube_config = APIConfig.get_default_cube_config()
47-
print("*********default_cube_config*********", default_cube_config)
51+
logger.info(f"*********initdefault_cube_config******** {default_cube_config}")
4852
MOS_PRODUCT_INSTANCE = MOSProduct(
4953
default_config=mos_config, default_cube_config=default_cube_config
5054
)
@@ -64,9 +68,18 @@ async def set_config(config):
6468

6569

6670
@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse)
67-
async def register_user(user_req: UserRegisterRequest):
71+
async def register_user(user_req: UserRegisterRequest, g: Annotated[G, Depends(get_g_object)]):
6872
"""Register a new user with configuration and default cube."""
6973
try:
74+
# Set request-related information in g object
75+
g.user_id = user_req.user_id
76+
g.action = "user_register"
77+
g.timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
78+
79+
logger.info(f"Starting user registration for user_id: {user_req.user_id}")
80+
logger.info(f"Request trace_id: {g.trace_id}")
81+
logger.info(f"Request timestamp: {g.timestamp}")
82+
7083
# Get configuration for the user
7184
user_config, default_mem_cube = APIConfig.create_user_config(
7285
user_name=user_req.user_id, user_id=user_req.user_id

0 commit comments

Comments
 (0)