Skip to content

Commit 768019d

Browse files
TimPansinolrafeei
andauthored
Fix FastAPI Context Propagation (#420)
* Fix context propagation in fastapi. Co-authored-by: Lalleh Rafeei <[email protected]> * Add testing for context propagation errors * Format * Restore disabled test * Clean up context implementation * Format * [Mega-Linter] Apply linters fixes * Bump Tests * Expand setuptools-scm versions * Format * [Mega-Linter] Apply linters fixes * Bump Tests Co-authored-by: Lalleh Rafeei <[email protected]> Co-authored-by: TimPansino <[email protected]>
1 parent a3ef06f commit 768019d

File tree

9 files changed

+258
-280
lines changed

9 files changed

+258
-280
lines changed

newrelic/core/context.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,80 @@
1616
This module implements utilities for context propagation for tracing across threads.
1717
"""
1818

19+
import logging
20+
1921
from newrelic.common.object_wrapper import function_wrapper
2022
from newrelic.core.trace_cache import trace_cache
2123

24+
_logger = logging.getLogger(__name__)
25+
26+
2227
class ContextOf(object):
23-
def __init__(self, trace_cache_id):
28+
def __init__(self, trace=None, request=None, trace_cache_id=None):
29+
self.trace = None
2430
self.trace_cache = trace_cache()
25-
self.trace = self.trace_cache._cache.get(trace_cache_id)
2631
self.thread_id = None
2732
self.restore = None
33+
self.should_restore = False
34+
35+
# Extract trace if possible, else leave as None for safety
36+
if trace is None and request is None and trace_cache_id is None:
37+
_logger.error(
38+
"Runtime instrumentation error. Request context propagation failed. No trace or request provided. Report this issue to New Relic support.",
39+
)
40+
elif trace is not None:
41+
self.trace = trace
42+
elif trace_cache_id is not None:
43+
self.trace = self.trace_cache._cache.get(trace_cache_id, None)
44+
if self.trace is None:
45+
_logger.error(
46+
"Runtime instrumentation error. Request context propagation failed. No trace with id %s. Report this issue to New Relic support.",
47+
trace_cache_id,
48+
)
49+
elif hasattr(request, "_nr_trace") and request._nr_trace is not None:
50+
# Unpack traces from objects patched with them
51+
self.trace = request._nr_trace
52+
else:
53+
_logger.error(
54+
"Runtime instrumentation error. Request context propagation failed. No context attached to request. Report this issue to New Relic support.",
55+
)
2856

2957
def __enter__(self):
3058
if self.trace:
3159
self.thread_id = self.trace_cache.current_thread_id()
32-
self.restore = self.trace_cache._cache.get(self.thread_id)
60+
61+
# Save previous cache contents
62+
self.restore = self.trace_cache._cache.get(self.thread_id, None)
63+
self.should_restore = True
64+
65+
# Set context in trace cache
3366
self.trace_cache._cache[self.thread_id] = self.trace
67+
3468
return self
3569

3670
def __exit__(self, exc, value, tb):
37-
if self.restore:
38-
self.trace_cache._cache[self.thread_id] = self.restore
71+
if self.should_restore:
72+
if self.restore is not None:
73+
# Restore previous contents
74+
self.trace_cache._cache[self.thread_id] = self.restore
75+
else:
76+
# Remove entry from cache
77+
self.trace_cache._cache.pop(self.thread_id)
3978

4079

41-
async def context_wrapper_async(awaitable, trace_cache_id):
42-
with ContextOf(trace_cache_id):
43-
return await awaitable
44-
45-
46-
def context_wrapper(func, trace_cache_id):
80+
def context_wrapper(func, trace=None, request=None, trace_cache_id=None):
4781
@function_wrapper
4882
def _context_wrapper(wrapped, instance, args, kwargs):
49-
with ContextOf(trace_cache_id):
83+
with ContextOf(trace=trace, request=request, trace_cache_id=trace_cache_id):
5084
return wrapped(*args, **kwargs)
5185

5286
return _context_wrapper(func)
5387

5488

89+
async def context_wrapper_async(awaitable, trace=None, request=None, trace_cache_id=None):
90+
with ContextOf(trace=trace, request=request, trace_cache_id=trace_cache_id):
91+
return await awaitable
92+
93+
5594
def current_thread_id():
5695
return trace_cache().current_thread_id()

newrelic/core/trace_cache.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -275,25 +275,6 @@ def save_trace(self, trace):
275275
task = current_task(self.asyncio)
276276
trace._task = task
277277

278-
def thread_start(self, trace):
279-
current_thread_id = self.current_thread_id()
280-
if current_thread_id not in self._cache:
281-
self._cache[current_thread_id] = trace
282-
else:
283-
_logger.error(
284-
"Runtime instrumentation error. An active "
285-
"trace already exists in the cache on thread_id %s. Report "
286-
"this issue to New Relic support.\n ",
287-
current_thread_id,
288-
)
289-
return None
290-
291-
return current_thread_id
292-
293-
def thread_stop(self, thread_id):
294-
if thread_id:
295-
self._cache.pop(thread_id, None)
296-
297278
def pop_current(self, trace):
298279
"""Restore the trace's parent under the thread ID of the current
299280
executing thread."""

newrelic/hooks/adapter_asgiref.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from newrelic.api.time_trace import current_trace
1516
from newrelic.common.object_wrapper import wrap_function_wrapper
16-
from newrelic.core.trace_cache import trace_cache
17-
from newrelic.core.context import context_wrapper_async, ContextOf
17+
from newrelic.core.context import ContextOf, context_wrapper_async
1818

1919

2020
def _bind_thread_handler(loop, source_task, *args, **kwargs):
@@ -23,17 +23,15 @@ def _bind_thread_handler(loop, source_task, *args, **kwargs):
2323

2424
def thread_handler_wrapper(wrapped, instance, args, kwargs):
2525
task = _bind_thread_handler(*args, **kwargs)
26-
with ContextOf(id(task)):
26+
with ContextOf(trace_cache_id=id(task)):
2727
return wrapped(*args, **kwargs)
2828

2929

3030
def main_wrap_wrapper(wrapped, instance, args, kwargs):
3131
awaitable = wrapped(*args, **kwargs)
32-
return context_wrapper_async(awaitable, trace_cache().current_thread_id())
32+
return context_wrapper_async(awaitable, current_trace())
3333

3434

3535
def instrument_asgiref_sync(module):
36-
wrap_function_wrapper(module, 'SyncToAsync.thread_handler',
37-
thread_handler_wrapper)
38-
wrap_function_wrapper(module, 'AsyncToSync.main_wrap',
39-
main_wrap_wrapper)
36+
wrap_function_wrapper(module, "SyncToAsync.thread_handler", thread_handler_wrapper)
37+
wrap_function_wrapper(module, "AsyncToSync.main_wrap", main_wrap_wrapper)

newrelic/hooks/framework_fastapi.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,11 @@
1313
# limitations under the License.
1414

1515
from copy import copy
16-
from newrelic.api.time_trace import current_trace
16+
1717
from newrelic.api.function_trace import FunctionTraceWrapper
18-
from newrelic.common.object_wrapper import wrap_function_wrapper, function_wrapper
18+
from newrelic.api.time_trace import current_trace
1919
from newrelic.common.object_names import callable_name
20-
from newrelic.core.trace_cache import trace_cache
21-
22-
23-
def use_context(trace):
24-
25-
@function_wrapper
26-
def context_wrapper(wrapped, instance, args, kwargs):
27-
cache = trace_cache()
28-
thread_id = cache.thread_start(trace)
29-
try:
30-
return wrapped(*args, **kwargs)
31-
finally:
32-
cache.thread_stop(thread_id)
33-
34-
return context_wrapper
20+
from newrelic.common.object_wrapper import wrap_function_wrapper
3521

3622

3723
def wrap_run_endpoint_function(wrapped, instance, args, kwargs):
@@ -41,12 +27,9 @@ def wrap_run_endpoint_function(wrapped, instance, args, kwargs):
4127
name = callable_name(dependant.call)
4228
trace.transaction.set_transaction_name(name)
4329

44-
if not kwargs["is_coroutine"]:
45-
dependant = kwargs["dependant"] = copy(dependant)
46-
dependant.call = use_context(trace)(FunctionTraceWrapper(dependant.call))
47-
return wrapped(*args, **kwargs)
48-
else:
49-
return FunctionTraceWrapper(wrapped, name=name)(*args, **kwargs)
30+
dependant = kwargs["dependant"] = copy(dependant)
31+
dependant.call = FunctionTraceWrapper(dependant.call)
32+
return wrapped(*args, **kwargs)
5033

5134
return wrapped(*args, **kwargs)
5235

newrelic/hooks/framework_starlette.py

Lines changed: 15 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
wrap_function_wrapper,
2626
)
2727
from newrelic.core.config import should_ignore_error
28-
from newrelic.core.context import context_wrapper, current_thread_id
29-
from newrelic.core.trace_cache import trace_cache
28+
from newrelic.core.context import ContextOf, context_wrapper
3029

3130

3231
def framework_details():
@@ -43,30 +42,10 @@ def bind_exc(request, exc, *args, **kwargs):
4342
return exc
4443

4544

46-
class RequestContext(object):
47-
def __init__(self, request):
48-
self.request = request
49-
self.force_propagate = False
50-
self.thread_id = None
51-
52-
def __enter__(self):
53-
trace = getattr(self.request, "_nr_trace", None)
54-
self.force_propagate = trace and current_trace() is None
55-
56-
# Propagate trace context onto the current task
57-
if self.force_propagate:
58-
self.thread_id = trace_cache().thread_start(trace)
59-
60-
def __exit__(self, exc, value, tb):
61-
# Remove any context from the current thread as it was force propagated above
62-
if self.force_propagate:
63-
trace_cache().thread_stop(self.thread_id)
64-
65-
6645
@function_wrapper
6746
def route_naming_wrapper(wrapped, instance, args, kwargs):
6847

69-
with RequestContext(bind_request(*args, **kwargs)):
48+
with ContextOf(request=bind_request(*args, **kwargs)):
7049
transaction = current_transaction()
7150
if transaction:
7251
transaction.set_transaction_name(callable_name(wrapped), priority=2)
@@ -136,16 +115,14 @@ def wrap_add_middleware(wrapped, instance, args, kwargs):
136115
return wrapped(wrap_middleware(middleware), *args, **kwargs)
137116

138117

139-
def bind_middleware_starlette(
140-
debug=False, routes=None, middleware=None, *args, **kwargs
141-
):
118+
def bind_middleware_starlette(debug=False, routes=None, middleware=None, *args, **kwargs): # pylint: disable=W1113
142119
return middleware
143120

144121

145122
def wrap_starlette(wrapped, instance, args, kwargs):
146123
middlewares = bind_middleware_starlette(*args, **kwargs)
147124
if middlewares:
148-
for middleware in middlewares:
125+
for middleware in middlewares: # pylint: disable=E1133
149126
cls = getattr(middleware, "cls", None)
150127
if cls and not hasattr(cls, "__wrapped__"):
151128
middleware.cls = wrap_middleware(cls)
@@ -171,11 +148,9 @@ async def wrap_exception_handler_async(coro, exc):
171148

172149
def wrap_exception_handler(wrapped, instance, args, kwargs):
173150
if is_coroutine_function(wrapped):
174-
return wrap_exception_handler_async(
175-
FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs)
176-
)
151+
return wrap_exception_handler_async(FunctionTraceWrapper(wrapped)(*args, **kwargs), bind_exc(*args, **kwargs))
177152
else:
178-
with RequestContext(bind_request(*args, **kwargs)):
153+
with ContextOf(request=bind_request(*args, **kwargs)):
179154
response = FunctionTraceWrapper(wrapped)(*args, **kwargs)
180155
record_response_error(response, bind_exc(*args, **kwargs))
181156
return response
@@ -190,9 +165,7 @@ def wrap_server_error_handler(wrapped, instance, args, kwargs):
190165

191166

192167
def wrap_add_exception_handler(wrapped, instance, args, kwargs):
193-
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(
194-
*args, **kwargs
195-
)
168+
exc_class_or_status_code, handler, args, kwargs = bind_add_exception_handler(*args, **kwargs)
196169
handler = FunctionWrapper(handler, wrap_exception_handler)
197170
return wrapped(exc_class_or_status_code, handler, *args, **kwargs)
198171

@@ -217,7 +190,7 @@ async def wrap_run_in_threadpool(wrapped, instance, args, kwargs):
217190
return await wrapped(*args, **kwargs)
218191

219192
func, args, kwargs = bind_run_in_threadpool(*args, **kwargs)
220-
func = context_wrapper(func, current_thread_id())
193+
func = context_wrapper(func, trace)
221194

222195
return await wrapped(func, *args, **kwargs)
223196

@@ -241,35 +214,21 @@ def instrument_starlette_requests(module):
241214

242215

243216
def instrument_starlette_middleware_errors(module):
244-
wrap_function_wrapper(
245-
module, "ServerErrorMiddleware.__call__", error_middleware_wrapper
246-
)
217+
wrap_function_wrapper(module, "ServerErrorMiddleware.__call__", error_middleware_wrapper)
247218

248-
wrap_function_wrapper(
249-
module, "ServerErrorMiddleware.__init__", wrap_server_error_handler
250-
)
219+
wrap_function_wrapper(module, "ServerErrorMiddleware.__init__", wrap_server_error_handler)
251220

252-
wrap_function_wrapper(
253-
module, "ServerErrorMiddleware.error_response", wrap_exception_handler
254-
)
221+
wrap_function_wrapper(module, "ServerErrorMiddleware.error_response", wrap_exception_handler)
255222

256-
wrap_function_wrapper(
257-
module, "ServerErrorMiddleware.debug_response", wrap_exception_handler
258-
)
223+
wrap_function_wrapper(module, "ServerErrorMiddleware.debug_response", wrap_exception_handler)
259224

260225

261226
def instrument_starlette_exceptions(module):
262-
wrap_function_wrapper(
263-
module, "ExceptionMiddleware.__call__", error_middleware_wrapper
264-
)
227+
wrap_function_wrapper(module, "ExceptionMiddleware.__call__", error_middleware_wrapper)
265228

266-
wrap_function_wrapper(
267-
module, "ExceptionMiddleware.http_exception", wrap_exception_handler
268-
)
229+
wrap_function_wrapper(module, "ExceptionMiddleware.http_exception", wrap_exception_handler)
269230

270-
wrap_function_wrapper(
271-
module, "ExceptionMiddleware.add_exception_handler", wrap_add_exception_handler
272-
)
231+
wrap_function_wrapper(module, "ExceptionMiddleware.add_exception_handler", wrap_add_exception_handler)
273232

274233

275234
def instrument_starlette_background_task(module):

0 commit comments

Comments
 (0)