Skip to content

Commit f641b70

Browse files
author
harvey_xiang
committed
feat: add arms
2 parents 185ed93 + 7fc8c05 commit f641b70

File tree

4 files changed

+140
-23
lines changed

4 files changed

+140
-23
lines changed

src/memos/api/middleware/request_context.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
4545
# Extract or generate trace_id
4646
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()
4747

48+
env = request.headers.get("x-env")
49+
user_type = request.headers.get("x-user-type")
50+
user_name = request.headers.get("x-user-name")
4851
start_time = time.time()
4952

5053
# Create and set request context
51-
context = RequestContext(trace_id=trace_id, api_path=request.url.path)
54+
context = RequestContext(
55+
trace_id=trace_id,
56+
api_path=request.url.path,
57+
env=env,
58+
user_type=user_type,
59+
user_name=user_name,
60+
)
5261
set_request_context(context)
5362

5463
# Log request start with parameters
@@ -64,9 +73,6 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
6473
try:
6574
response = await call_next(request)
6675
end_time = time.time()
67-
logger.info(f"response is: {response.body}")
68-
69-
# 记录请求状态
7076
if response.status_code == 200:
7177
logger.info(
7278
f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"

src/memos/context/context.py

Lines changed: 89 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,19 @@ class RequestContext:
2929
This provides a Flask g-like object for FastAPI applications.
3030
"""
3131

32-
def __init__(self, trace_id: str | None = None, api_path: str | None = None):
32+
def __init__(
33+
self,
34+
trace_id: str | None = None,
35+
api_path: str | None = None,
36+
env: str | None = None,
37+
user_type: str | None = None,
38+
user_name: str | None = None,
39+
):
3340
self.trace_id = trace_id or "trace-id"
3441
self.api_path = api_path
42+
self.env = env
43+
self.user_type = user_type
44+
self.user_name = user_name
3545
self._data: dict[str, Any] = {}
3646

3747
def set(self, key: str, value: Any) -> None:
@@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any:
4353
return self._data.get(key, default)
4454

4555
def __setattr__(self, name: str, value: Any) -> None:
46-
if name.startswith("_") or name in ("trace_id", "api_path"):
56+
if name.startswith("_") or name in (
57+
"trace_id",
58+
"api_path",
59+
"env",
60+
"user_type",
61+
"user_name",
62+
):
4763
super().__setattr__(name, value)
4864
else:
4965
if not hasattr(self, "_data"):
@@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any:
5874

5975
def to_dict(self) -> dict[str, Any]:
6076
"""Convert context to dictionary."""
61-
return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()}
77+
return {
78+
"trace_id": self.trace_id,
79+
"api_path": self.api_path,
80+
"env": self.env,
81+
"user_type": self.user_type,
82+
"user_name": self.user_name,
83+
"data": self._data.copy(),
84+
}
6285

6386

6487
def set_request_context(context: RequestContext) -> None:
@@ -93,6 +116,36 @@ def get_current_api_path() -> str | None:
93116
return None
94117

95118

119+
def get_current_env() -> str | None:
120+
"""
121+
Get the current request's env.
122+
"""
123+
context = _request_context.get()
124+
if context:
125+
return context.get("env")
126+
return "prod"
127+
128+
129+
def get_current_user_type() -> str | None:
130+
"""
131+
Get the current request's user type.
132+
"""
133+
context = _request_context.get()
134+
if context:
135+
return context.get("user_type")
136+
return "opensource"
137+
138+
139+
def get_current_user_name() -> str | None:
140+
"""
141+
Get the current request's user name.
142+
"""
143+
context = _request_context.get()
144+
if context:
145+
return context.get("user_name")
146+
return "memos"
147+
148+
96149
def get_current_context() -> RequestContext | None:
97150
"""
98151
Get the current request context.
@@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None:
103156
context_dict = _request_context.get()
104157
if context_dict:
105158
ctx = RequestContext(
106-
trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path")
159+
trace_id=context_dict.get("trace_id"),
160+
api_path=context_dict.get("api_path"),
161+
env=context_dict.get("env"),
162+
user_type=context_dict.get("user_type"),
163+
user_name=context_dict.get("user_name"),
107164
)
108165
ctx._data = context_dict.get("data", {}).copy()
109166
return ctx
@@ -141,14 +198,21 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs):
141198

142199
self.main_trace_id = get_current_trace_id()
143200
self.main_api_path = get_current_api_path()
201+
self.main_env = get_current_env()
202+
self.main_user_type = get_current_user_type()
203+
self.main_user_name = get_current_user_name()
144204
self.main_context = get_current_context()
145205

146206
def run(self):
147207
# Create a new RequestContext with the main thread's trace_id
148208
if self.main_context:
149209
# Copy the context data
150210
child_context = RequestContext(
151-
trace_id=self.main_trace_id, api_path=self.main_context.api_path
211+
trace_id=self.main_trace_id,
212+
api_path=self.main_api_path,
213+
env=self.main_env,
214+
user_type=self.main_user_type,
215+
user_name=self.main_user_name,
152216
)
153217
child_context._data = self.main_context._data.copy()
154218

@@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
171235
"""
172236
main_trace_id = get_current_trace_id()
173237
main_api_path = get_current_api_path()
238+
main_env = get_current_env()
239+
main_user_type = get_current_user_type()
240+
main_user_name = get_current_user_name()
174241
main_context = get_current_context()
175242

176243
@functools.wraps(fn)
177244
def wrapper(*args: Any, **kwargs: Any) -> Any:
178245
if main_context:
179246
# Create and set new context in worker thread
180-
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
247+
child_context = RequestContext(
248+
trace_id=main_trace_id,
249+
api_path=main_api_path,
250+
env=main_env,
251+
user_type=main_user_type,
252+
user_name=main_user_name,
253+
)
181254
child_context._data = main_context._data.copy()
182255
set_request_context(child_context)
183256

@@ -198,13 +271,22 @@ def map(
198271
"""
199272
main_trace_id = get_current_trace_id()
200273
main_api_path = get_current_api_path()
274+
main_env = get_current_env()
275+
main_user_type = get_current_user_type()
276+
main_user_name = get_current_user_name()
201277
main_context = get_current_context()
202278

203279
@functools.wraps(fn)
204280
def wrapper(*args: Any, **kwargs: Any) -> Any:
205281
if main_context:
206282
# Create and set new context in worker thread
207-
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
283+
child_context = RequestContext(
284+
trace_id=main_trace_id,
285+
api_path=main_api_path,
286+
env=main_env,
287+
user_type=main_user_type,
288+
user_name=main_user_name,
289+
)
208290
child_context._data = main_context._data.copy()
209291
set_request_context(child_context)
210292

src/memos/log.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,13 @@
1414
from dotenv import load_dotenv
1515

1616
from memos import settings
17-
from memos.context.context import get_current_api_path, get_current_trace_id
17+
from memos.context.context import (
18+
get_current_api_path,
19+
get_current_env,
20+
get_current_trace_id,
21+
get_current_user_name,
22+
get_current_user_type,
23+
)
1824

1925

2026
# Load environment variables
@@ -34,15 +40,22 @@ def _setup_logfile() -> Path:
3440
return logfile
3541

3642

37-
class TraceIDFilter(logging.Filter):
38-
"""add trace_id to the log record"""
43+
class ContextFilter(logging.Filter):
44+
"""add context to the log record"""
3945

4046
def filter(self, record):
4147
try:
4248
trace_id = get_current_trace_id()
4349
record.trace_id = trace_id if trace_id else "trace-id"
50+
record.env = get_current_env()
51+
record.user_type = get_current_user_type()
52+
record.user_name = get_current_user_name()
53+
record.api_path = get_current_api_path()
4454
except Exception:
4555
record.trace_id = "trace-id"
56+
record.env = "prod"
57+
record.user_type = "normal"
58+
record.user_name = "unknown"
4659
return True
4760

4861

@@ -86,13 +99,24 @@ def emit(self, record):
8699
try:
87100
trace_id = get_current_trace_id() or "trace-id"
88101
api_path = get_current_api_path()
102+
env = get_current_env()
103+
user_type = get_current_user_type()
104+
user_name = get_current_user_name()
89105
if api_path is not None:
90-
self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path)
106+
self._executor.submit(
107+
self._send_log_sync,
108+
record.getMessage(),
109+
trace_id,
110+
api_path,
111+
env,
112+
user_type,
113+
user_name,
114+
)
91115
except Exception as e:
92116
if not self._is_shutting_down.is_set():
93117
print(f"Error sending log: {e}")
94118

95-
def _send_log_sync(self, message, trace_id, api_path):
119+
def _send_log_sync(self, message, trace_id, api_path, env, user_type, user_name):
96120
"""Send log message synchronously in a separate thread"""
97121
try:
98122
logger_url = os.getenv("CUSTOM_LOGGER_URL")
@@ -104,6 +128,9 @@ def _send_log_sync(self, message, trace_id, api_path):
104128
"trace_id": trace_id,
105129
"action": api_path,
106130
"current_time": round(time.time(), 3),
131+
"env": env,
132+
"user_type": user_type,
133+
"user_name": user_name,
107134
}
108135

109136
# Add auth token if exists
@@ -145,26 +172,26 @@ def close(self):
145172
"disable_existing_loggers": False,
146173
"formatters": {
147174
"standard": {
148-
"format": "%(asctime)s [%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
175+
"format": "%(asctime)s | %(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
149176
},
150177
"no_datetime": {
151-
"format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
178+
"format": "%(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
152179
},
153180
"simplified": {
154-
"format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s"
181+
"format": "%(asctime)s | %(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | % %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s"
155182
},
156183
},
157184
"filters": {
158185
"package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX},
159-
"trace_id_filter": {"()": "memos.log.TraceIDFilter"},
186+
"context_filter": {"()": "memos.log.ContextFilter"},
160187
},
161188
"handlers": {
162189
"console": {
163190
"level": "DEBUG",
164191
"class": "logging.StreamHandler",
165192
"stream": stdout,
166193
"formatter": "no_datetime",
167-
"filters": ["package_tree_filter", "trace_id_filter"],
194+
"filters": ["package_tree_filter", "context_filter"],
168195
},
169196
"file": {
170197
"level": "DEBUG",
@@ -173,7 +200,7 @@ def close(self):
173200
"maxBytes": 1024**2 * 10,
174201
"backupCount": 10,
175202
"formatter": "standard",
176-
"filters": ["trace_id_filter"],
203+
"filters": ["context_filter"],
177204
},
178205
"custom_logger": {
179206
"level": "INFO",

src/memos/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import time
22

3+
from memos import settings
34
from memos.log import get_logger
45

56

@@ -13,7 +14,8 @@ def wrapper(*args, **kwargs):
1314
start = time.perf_counter()
1415
result = func(*args, **kwargs)
1516
elapsed = time.perf_counter() - start
16-
logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
17+
if settings.DEBUG:
18+
logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s")
1719
return result
1820

1921
return wrapper

0 commit comments

Comments
 (0)