4141logger = logging .getLogger (__name__ )
4242
4343# Context variable for async execution context
44- execution_context_var = contextvars .ContextVar (' execution_context' , default = None )
44+ execution_context_var = contextvars .ContextVar (" execution_context" , default = None )
4545
4646
4747class ExecutionContext :
@@ -82,14 +82,14 @@ def _generate_execution_id():
8282def _extract_context_from_headers (headers ):
8383 """Extract execution context from request headers."""
8484 execution_id = headers .get (EXECUTION_ID_REQUEST_HEADER )
85-
85+
8686 # Try to get span ID from trace context header
8787 trace_context = re .match (
8888 _TRACE_CONTEXT_REGEX_PATTERN ,
8989 headers .get (TRACE_CONTEXT_REQUEST_HEADER , "" ),
9090 )
9191 span_id = trace_context .group ("span_id" ) if trace_context else None
92-
92+
9393 return ExecutionContext (execution_id , span_id )
9494
9595
@@ -118,31 +118,33 @@ async def __call__(self, scope, receive, send):
118118 trace_context_header = b"x-cloud-trace-context"
119119 execution_id = None
120120 trace_context = None
121-
121+
122122 for name , value in scope .get ("headers" , []):
123123 if name .lower () == execution_id_header :
124124 execution_id = value .decode ("latin-1" )
125125 elif name .lower () == trace_context_header :
126126 trace_context = value .decode ("latin-1" )
127-
127+
128128 if not execution_id :
129129 execution_id = _generate_execution_id ()
130130 # Add the execution ID to headers
131131 new_headers = list (scope .get ("headers" , []))
132- new_headers .append ((execution_id_header , execution_id .encode ("latin-1" )))
132+ new_headers .append (
133+ (execution_id_header , execution_id .encode ("latin-1" ))
134+ )
133135 scope ["headers" ] = new_headers
134-
136+
135137 # Store execution context in ASGI scope for recovery in case of context loss
136138 # Parse trace context to extract span ID
137139 span_id = None
138140 if trace_context :
139141 trace_match = re .match (_TRACE_CONTEXT_REGEX_PATTERN , trace_context )
140142 if trace_match :
141143 span_id = trace_match .group ("span_id" )
142-
144+
143145 # Store in scope for potential recovery
144146 scope ["execution_context" ] = ExecutionContext (execution_id , span_id )
145-
147+
146148 await self .app (scope , receive , send ) # pragma: no cover
147149
148150
@@ -167,7 +169,7 @@ def wrapper(*args, **kwargs):
167169
168170 with stderr_redirect , stdout_redirect :
169171 result = view_function (* args , ** kwargs )
170-
172+
171173 # Context cleanup happens automatically via Flask's request context
172174 # No need to manually clean up flask.g
173175 return result
@@ -195,38 +197,38 @@ def decorator(view_function):
195197 async def async_wrapper (request , * args , ** kwargs ):
196198 # Extract execution context from headers
197199 context = _extract_context_from_headers (request .headers )
198-
200+
199201 # Set context using contextvars
200202 token = execution_context_var .set (context )
201-
203+
202204 with stderr_redirect , stdout_redirect :
203205 # Handle both sync and async functions
204206 if inspect .iscoroutinefunction (view_function ):
205207 result = await view_function (request , * args , ** kwargs )
206208 else :
207209 result = view_function (request , * args , ** kwargs ) # pragma: no cover
208-
210+
209211 # Only reset context on successful completion
210212 # On exception, leave context available for exception handlers
211213 execution_context_var .reset (token )
212214 return result
213-
215+
214216 @functools .wraps (view_function )
215217 def sync_wrapper (request , * args , ** kwargs ): # pragma: no cover
216218 # For sync functions, we still need to set up the context
217219 context = _extract_context_from_headers (request .headers )
218-
220+
219221 # Set context using contextvars
220222 token = execution_context_var .set (context )
221-
223+
222224 with stderr_redirect , stdout_redirect :
223225 result = view_function (request , * args , ** kwargs )
224-
226+
225227 # Only reset context on successful completion
226228 # On exception, leave context available for exception handlers
227229 execution_context_var .reset (token )
228230 return result
229-
231+
230232 # Return appropriate wrapper based on whether the function is async
231233 if inspect .iscoroutinefunction (view_function ):
232234 return async_wrapper
0 commit comments