1515import contextlib
1616import contextvars
1717import functools
18+ import inspect
1819import io
1920import json
2021import logging
@@ -78,6 +79,20 @@ def _generate_execution_id():
7879 )
7980
8081
82+ def _extract_context_from_headers (headers ):
83+ """Extract execution context from request headers."""
84+ execution_id = headers .get (EXECUTION_ID_REQUEST_HEADER )
85+
86+ # Try to get span ID from trace context header
87+ trace_context = re .match (
88+ _TRACE_CONTEXT_REGEX_PATTERN ,
89+ headers .get (TRACE_CONTEXT_REQUEST_HEADER , "" ),
90+ )
91+ span_id = trace_context .group ("span_id" ) if trace_context else None
92+
93+ return ExecutionContext (execution_id , span_id )
94+
95+
8196# Middleware to add execution id to request header if one does not already exist
8297class WsgiMiddleware :
8398 def __init__ (self , wsgi_app ):
@@ -147,13 +162,8 @@ def set_execution_context(request, enable_id_logging=False):
147162 def decorator (view_function ):
148163 @functools .wraps (view_function )
149164 def wrapper (* args , ** kwargs ):
150- trace_context = re .match (
151- _TRACE_CONTEXT_REGEX_PATTERN ,
152- request .headers .get (TRACE_CONTEXT_REQUEST_HEADER , "" ),
153- )
154- execution_id = request .headers .get (EXECUTION_ID_REQUEST_HEADER )
155- span_id = trace_context .group ("span_id" ) if trace_context else None
156- _set_current_context (ExecutionContext (execution_id , span_id ))
165+ context = _extract_context_from_headers (request .headers )
166+ _set_current_context (context )
157167
158168 with stderr_redirect , stdout_redirect :
159169 return view_function (* args , ** kwargs )
@@ -179,21 +189,15 @@ def set_execution_context_async(enable_id_logging=False):
179189 def decorator (view_function ):
180190 @functools .wraps (view_function )
181191 async def async_wrapper (request , * args , ** kwargs ):
182- # Extract execution ID and span ID from Starlette request
183- trace_context = re .match (
184- _TRACE_CONTEXT_REGEX_PATTERN ,
185- request .headers .get (TRACE_CONTEXT_REQUEST_HEADER , "" ),
186- )
187- execution_id = request .headers .get (EXECUTION_ID_REQUEST_HEADER )
188- span_id = trace_context .group ("span_id" ) if trace_context else None
192+ # Extract execution context from headers
193+ context = _extract_context_from_headers (request .headers )
189194
190195 # Set context using contextvars
191- token = execution_context_var .set (ExecutionContext ( execution_id , span_id ) )
196+ token = execution_context_var .set (context )
192197
193198 try :
194199 with stderr_redirect , stdout_redirect :
195200 # Handle both sync and async functions
196- import inspect
197201 if inspect .iscoroutinefunction (view_function ):
198202 return await view_function (request , * args , ** kwargs )
199203 else :
@@ -205,15 +209,10 @@ async def async_wrapper(request, *args, **kwargs):
205209 @functools .wraps (view_function )
206210 def sync_wrapper (request , * args , ** kwargs ): # pragma: no cover
207211 # For sync functions, we still need to set up the context
208- trace_context = re .match (
209- _TRACE_CONTEXT_REGEX_PATTERN ,
210- request .headers .get (TRACE_CONTEXT_REQUEST_HEADER , "" ),
211- )
212- execution_id = request .headers .get (EXECUTION_ID_REQUEST_HEADER )
213- span_id = trace_context .group ("span_id" ) if trace_context else None
212+ context = _extract_context_from_headers (request .headers )
214213
215214 # Set context using contextvars
216- token = execution_context_var .set (ExecutionContext ( execution_id , span_id ) )
215+ token = execution_context_var .set (context )
217216
218217 try :
219218 with stderr_redirect , stdout_redirect :
@@ -223,7 +222,6 @@ def sync_wrapper(request, *args, **kwargs): # pragma: no cover
223222 execution_context_var .reset (token )
224223
225224 # Return appropriate wrapper based on whether the function is async
226- import inspect
227225 if inspect .iscoroutinefunction (view_function ):
228226 return async_wrapper
229227 else :
0 commit comments