@@ -29,9 +29,19 @@ class RequestContext:
2929 This provides a Flask g-like object for FastAPI applications.
3030 """
3131
32- def __init__ (self , trace_id : str | None = None , api_path : str | None = None ):
32+ def __init__ (
33+ self ,
34+ trace_id : str | None = None ,
35+ api_path : str | None = None ,
36+ env : str | None = None ,
37+ user_type : str | None = None ,
38+ user_name : str | None = None ,
39+ ):
3340 self .trace_id = trace_id or "trace-id"
3441 self .api_path = api_path
42+ self .env = env
43+ self .user_type = user_type
44+ self .user_name = user_name
3545 self ._data : dict [str , Any ] = {}
3646
3747 def set (self , key : str , value : Any ) -> None :
@@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any:
4353 return self ._data .get (key , default )
4454
4555 def __setattr__ (self , name : str , value : Any ) -> None :
46- if name .startswith ("_" ) or name in ("trace_id" , "api_path" ):
56+ if name .startswith ("_" ) or name in (
57+ "trace_id" ,
58+ "api_path" ,
59+ "env" ,
60+ "user_type" ,
61+ "user_name" ,
62+ ):
4763 super ().__setattr__ (name , value )
4864 else :
4965 if not hasattr (self , "_data" ):
@@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any:
5874
5975 def to_dict (self ) -> dict [str , Any ]:
6076 """Convert context to dictionary."""
61- return {"trace_id" : self .trace_id , "api_path" : self .api_path , "data" : self ._data .copy ()}
77+ return {
78+ "trace_id" : self .trace_id ,
79+ "api_path" : self .api_path ,
80+ "env" : self .env ,
81+ "user_type" : self .user_type ,
82+ "user_name" : self .user_name ,
83+ "data" : self ._data .copy (),
84+ }
6285
6386
6487def set_request_context (context : RequestContext ) -> None :
@@ -93,6 +116,36 @@ def get_current_api_path() -> str | None:
93116 return None
94117
95118
119+ def get_current_env () -> str | None :
120+ """
121+ Get the current request's env.
122+ """
123+ context = _request_context .get ()
124+ if context :
125+ return context .get ("env" )
126+ return "prod"
127+
128+
129+ def get_current_user_type () -> str | None :
130+ """
131+ Get the current request's user type.
132+ """
133+ context = _request_context .get ()
134+ if context :
135+ return context .get ("user_type" )
136+ return "opensource"
137+
138+
139+ def get_current_user_name () -> str | None :
140+ """
141+ Get the current request's user name.
142+ """
143+ context = _request_context .get ()
144+ if context :
145+ return context .get ("user_name" )
146+ return "memos"
147+
148+
96149def get_current_context () -> RequestContext | None :
97150 """
98151 Get the current request context.
@@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None:
103156 context_dict = _request_context .get ()
104157 if context_dict :
105158 ctx = RequestContext (
106- trace_id = context_dict .get ("trace_id" ), api_path = context_dict .get ("api_path" )
159+ trace_id = context_dict .get ("trace_id" ),
160+ api_path = context_dict .get ("api_path" ),
161+ env = context_dict .get ("env" ),
162+ user_type = context_dict .get ("user_type" ),
163+ user_name = context_dict .get ("user_name" ),
107164 )
108165 ctx ._data = context_dict .get ("data" , {}).copy ()
109166 return ctx
@@ -141,14 +198,21 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs):
141198
142199 self .main_trace_id = get_current_trace_id ()
143200 self .main_api_path = get_current_api_path ()
201+ self .main_env = get_current_env ()
202+ self .main_user_type = get_current_user_type ()
203+ self .main_user_name = get_current_user_name ()
144204 self .main_context = get_current_context ()
145205
146206 def run (self ):
147207 # Create a new RequestContext with the main thread's trace_id
148208 if self .main_context :
149209 # Copy the context data
150210 child_context = RequestContext (
151- trace_id = self .main_trace_id , api_path = self .main_context .api_path
211+ trace_id = self .main_trace_id ,
212+ api_path = self .main_api_path ,
213+ env = self .main_env ,
214+ user_type = self .main_user_type ,
215+ user_name = self .main_user_name ,
152216 )
153217 child_context ._data = self .main_context ._data .copy ()
154218
@@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
171235 """
172236 main_trace_id = get_current_trace_id ()
173237 main_api_path = get_current_api_path ()
238+ main_env = get_current_env ()
239+ main_user_type = get_current_user_type ()
240+ main_user_name = get_current_user_name ()
174241 main_context = get_current_context ()
175242
176243 @functools .wraps (fn )
177244 def wrapper (* args : Any , ** kwargs : Any ) -> Any :
178245 if main_context :
179246 # Create and set new context in worker thread
180- child_context = RequestContext (trace_id = main_trace_id , api_path = main_api_path )
247+ child_context = RequestContext (
248+ trace_id = main_trace_id ,
249+ api_path = main_api_path ,
250+ env = main_env ,
251+ user_type = main_user_type ,
252+ user_name = main_user_name ,
253+ )
181254 child_context ._data = main_context ._data .copy ()
182255 set_request_context (child_context )
183256
@@ -198,13 +271,22 @@ def map(
198271 """
199272 main_trace_id = get_current_trace_id ()
200273 main_api_path = get_current_api_path ()
274+ main_env = get_current_env ()
275+ main_user_type = get_current_user_type ()
276+ main_user_name = get_current_user_name ()
201277 main_context = get_current_context ()
202278
203279 @functools .wraps (fn )
204280 def wrapper (* args : Any , ** kwargs : Any ) -> Any :
205281 if main_context :
206282 # Create and set new context in worker thread
207- child_context = RequestContext (trace_id = main_trace_id , api_path = main_api_path )
283+ child_context = RequestContext (
284+ trace_id = main_trace_id ,
285+ api_path = main_api_path ,
286+ env = main_env ,
287+ user_type = main_user_type ,
288+ user_name = main_user_name ,
289+ )
208290 child_context ._data = main_context ._data .copy ()
209291 set_request_context (child_context )
210292
0 commit comments