-
Notifications
You must be signed in to change notification settings - Fork 761
Rewrite FastAPI instrumentor middleware stack to be failsafe #3664
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
50120ce
8a1a39e
2150e5c
2faf005
5665a0a
425a7f3
6c48703
55e9c43
a7a4949
5632d1b
0203e4b
4c17f00
56139a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -191,7 +191,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A | |
from starlette.applications import Starlette | ||
from starlette.middleware.errors import ServerErrorMiddleware | ||
from starlette.routing import Match | ||
from starlette.types import ASGIApp | ||
from starlette.types import ASGIApp, Receive, Scope, Send | ||
|
||
from opentelemetry.instrumentation._semconv import ( | ||
_get_schema_url, | ||
|
@@ -210,7 +210,8 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A | |
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor | ||
from opentelemetry.metrics import MeterProvider, get_meter | ||
from opentelemetry.semconv.attributes.http_attributes import HTTP_ROUTE | ||
from opentelemetry.trace import TracerProvider, get_tracer | ||
from opentelemetry.trace import TracerProvider, get_current_span, get_tracer | ||
from opentelemetry.trace.status import Status, StatusCode | ||
from opentelemetry.util.http import ( | ||
get_excluded_urls, | ||
parse_excluded_urls, | ||
|
@@ -242,7 +243,7 @@ def instrument_app( | |
http_capture_headers_server_response: list[str] | None = None, | ||
http_capture_headers_sanitize_fields: list[str] | None = None, | ||
exclude_spans: list[Literal["receive", "send"]] | None = None, | ||
): | ||
): # pylint: disable=too-many-locals | ||
"""Instrument an uninstrumented FastAPI application. | ||
|
||
Args: | ||
|
@@ -289,15 +290,16 @@ def instrument_app( | |
schema_url=_get_schema_url(sem_conv_opt_in_mode), | ||
) | ||
|
||
# Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware | ||
# as the outermost middleware. | ||
# Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able | ||
# to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do. | ||
|
||
# In order to make traces available at any stage of the request | ||
# processing - including exception handling - we wrap ourselves as | ||
# the new, outermost middleware. However in order to prevent | ||
# exceptions from user-provided hooks from tearing through, we wrap | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The wrapping is no longer here. I also don't think this comment is an improvement. It's now less clear than before that this approach ensures the correct http status code. |
||
# them to return without failure unconditionally. | ||
def build_middleware_stack(self: Starlette) -> ASGIApp: | ||
inner_server_error_middleware: ASGIApp = ( # type: ignore | ||
self._original_build_middleware_stack() # type: ignore | ||
) | ||
|
||
otel_middleware = OpenTelemetryMiddleware( | ||
inner_server_error_middleware, | ||
excluded_urls=excluded_urls, | ||
|
@@ -313,6 +315,7 @@ def build_middleware_stack(self: Starlette) -> ASGIApp: | |
http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, | ||
exclude_spans=exclude_spans, | ||
) | ||
|
||
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware | ||
# are handled. | ||
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that | ||
|
@@ -339,6 +342,28 @@ def build_middleware_stack(self: Starlette) -> ASGIApp: | |
app, | ||
) | ||
|
||
# ASGI middleware that forwards each call to the wrapped app and, when an exception escapes it, records that exception on the current span, sets the span status to ERROR and re-raises. | ||
class ExceptionHandlerMiddleware: | ||
# `app` is the wrapped callable that `__call__` awaits with (scope, receive, send). | ||
def __init__(self, app): | ||
self.app = app | ||
|
||
async def __call__( | ||
self, scope: Scope, receive: Receive, send: Send | ||
) -> None: | ||
try: | ||
await self.app(scope, receive, send) | ||
# Broad catch is intentional: the exception is only recorded here and is re-raised unchanged by the bare `raise` below. | ||
except Exception as exc: # pylint: disable=broad-exception-caught | ||
# The span is looked up at the moment the exception is caught, not when the request started. | ||
span = get_current_span() | ||
span.record_exception(exc) | ||
# Status description has the form "<exception class name>: <exception message>". | ||
span.set_status( | ||
Status( | ||
status_code=StatusCode.ERROR, | ||
description=f"{type(exc).__name__}: {exc}", | ||
) | ||
) | ||
raise | ||
|
||
app.add_middleware(ExceptionHandlerMiddleware) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't actually know where this middleware is going to go, it'd be clearer if this was inserted in build_middleware_stack. |
||
|
||
app._is_instrumented_by_opentelemetry = True | ||
if app not in _InstrumentedFastAPI._instrumented_fastapi_apps: | ||
_InstrumentedFastAPI._instrumented_fastapi_apps.add(app) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -59,6 +59,9 @@ | |
from opentelemetry.semconv._incubating.attributes.net_attributes import ( | ||
NET_HOST_PORT, | ||
) | ||
from opentelemetry.semconv.attributes.exception_attributes import ( | ||
EXCEPTION_TYPE, | ||
) | ||
from opentelemetry.semconv.attributes.http_attributes import ( | ||
HTTP_REQUEST_METHOD, | ||
HTTP_RESPONSE_STATUS_CODE, | ||
|
@@ -70,6 +73,7 @@ | |
from opentelemetry.semconv.attributes.url_attributes import URL_SCHEME | ||
from opentelemetry.test.globals_test import reset_trace_globals | ||
from opentelemetry.test.test_base import TestBase | ||
from opentelemetry.trace.status import StatusCode | ||
from opentelemetry.util._importlib_metadata import entry_points | ||
from opentelemetry.util.http import ( | ||
OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS, | ||
|
@@ -1877,3 +1881,154 @@ def test_custom_header_not_present_in_non_recording_span(self): | |
self.assertEqual(200, resp.status_code) | ||
span_list = self.memory_exporter.get_finished_spans() | ||
self.assertEqual(len(span_list), 0) | ||
|
||
|
||
class TestTraceableExceptionHandling(TestBase): | ||
"""Tests to ensure FastAPI exception handlers are only executed once and with a valid context""" | ||
|
||
# Instrument a fresh FastAPI app and reset the values the individual tests record into. | ||
def setUp(self): | ||
super().setUp() | ||
|
||
self.app = fastapi.FastAPI() | ||
|
||
otel_fastapi.FastAPIInstrumentor().instrument_app(self.app) | ||
self.client = TestClient(self.app) | ||
self.tracer = self.tracer_provider.get_tracer(__name__) | ||
self.executed = 0 | ||
self.request_trace_id = None | ||
self.error_trace_id = None | ||
|
||
# Uninstrument the app created in setUp, with logging disabled while doing so. | ||
def tearDown(self) -> None: | ||
super().tearDown() | ||
with self.disable_logging(): | ||
otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) | ||
|
||
def test_error_handler_context(self): | ||
"""OTEL tracing contexts must be available during error handler execution""" | ||
|
||
# Handler registered for Exception: stores the trace id of the span that is current while the handler runs. | ||
@self.app.exception_handler(Exception) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please test if the exception and http status code are recorded on the span in this case. |
||
async def _(*_): | ||
self.error_trace_id = ( | ||
trace.get_current_span().get_span_context().trace_id | ||
) | ||
|
||
# Route: stores the trace id seen inside the request handler, then raises. | ||
@self.app.get("/foobar") | ||
async def _(): | ||
self.request_trace_id = ( | ||
trace.get_current_span().get_span_context().trace_id | ||
) | ||
raise UnhandledException("Test Exception") | ||
|
||
# Whatever exception the test client raises is deliberately ignored; only the recorded trace ids are asserted. | ||
try: | ||
self.client.get( | ||
"/foobar", | ||
) | ||
except Exception: # pylint: disable=W0718 | ||
pass | ||
|
||
self.assertIsNotNone(self.request_trace_id) | ||
self.assertEqual(self.request_trace_id, self.error_trace_id) | ||
|
||
def test_error_handler_side_effects(self): | ||
"""FastAPI default exception handlers (aka error handlers) must be executed exactly once per exception""" | ||
|
||
# Handler registered for Exception: counts how many times it is invoked. | ||
@self.app.exception_handler(Exception) | ||
async def _(*_): | ||
self.executed += 1 | ||
|
||
@self.app.get("/foobar") | ||
async def _(): | ||
raise UnhandledException("Test Exception") | ||
|
||
# Whatever exception the test client raises is deliberately ignored; only the invocation count is asserted. | ||
try: | ||
self.client.get( | ||
"/foobar", | ||
) | ||
except Exception: # pylint: disable=W0718 | ||
pass | ||
|
||
self.assertEqual(self.executed, 1) | ||
|
||
def test_exception_span_recording(self): | ||
"""Exceptions are always recorded in the active span""" | ||
|
||
@self.app.get("/foobar") | ||
async def _(): | ||
raise UnhandledException("Test Exception") | ||
|
||
# Whatever exception the test client raises is deliberately ignored; only the exported spans are asserted. | ||
try: | ||
self.client.get( | ||
"/foobar", | ||
) | ||
except Exception: # pylint: disable=W0718 | ||
pass | ||
|
||
spans = self.memory_exporter.get_finished_spans() | ||
|
||
# Exactly three finished spans are expected; the last one (index 2) must carry the ERROR status and a single "exception" event. | ||
# NOTE(review): index 2 is presumably the server span - confirm against the exporter ordering. | ||
self.assertEqual(len(spans), 3) | ||
span = spans[2] | ||
self.assertEqual(span.status.status_code, StatusCode.ERROR) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please also check the http status code on the span. |
||
self.assertEqual(len(span.events), 1) | ||
event = span.events[0] | ||
self.assertEqual(event.name, "exception") | ||
assert event.attributes is not None | ||
self.assertEqual( | ||
event.attributes.get(EXCEPTION_TYPE), | ||
f"{__name__}.UnhandledException", | ||
) | ||
|
||
|
||
class TestFailsafeHooks(TestBase): | ||
"""Tests to ensure FastAPI instrumentation hooks don't tear through""" | ||
|
||
# Build an app with one successful route and instrument it with hooks that always raise. | ||
def setUp(self): | ||
super().setUp() | ||
|
||
self.app = fastapi.FastAPI() | ||
|
||
@self.app.get("/foobar") | ||
async def _(): | ||
return {"message": "Hello World"} | ||
|
||
# Used for all three hook parameters below; raises regardless of the arguments it receives. | ||
def failing_hook(*_): | ||
raise UnhandledException("Hook Exception") | ||
|
||
otel_fastapi.FastAPIInstrumentor().instrument_app( | ||
self.app, | ||
server_request_hook=failing_hook, | ||
client_request_hook=failing_hook, | ||
client_response_hook=failing_hook, | ||
) | ||
self.client = TestClient(self.app) | ||
|
||
# Uninstrument the app created in setUp, with logging disabled while doing so. | ||
def tearDown(self) -> None: | ||
super().tearDown() | ||
with self.disable_logging(): | ||
otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) | ||
|
||
def test_failsafe_hooks(self): | ||
"""Crashing hooks must not tear through""" | ||
resp = self.client.get( | ||
"/foobar", | ||
) | ||
|
||
# The route itself succeeds, so the raising hooks must not change the 200 response. | ||
self.assertEqual(200, resp.status_code) | ||
|
||
def test_failsafe_error_recording(self): | ||
"""Failing hooks must record the exception on the span""" | ||
self.client.get( | ||
"/foobar", | ||
) | ||
|
||
spans = self.memory_exporter.get_finished_spans() | ||
|
||
# Exactly three finished spans are expected; the first one (index 0) must carry a single "exception" event of type UnhandledException. | ||
self.assertEqual(len(spans), 3) | ||
span = spans[0] | ||
self.assertEqual(len(span.events), 1) | ||
event = span.events[0] | ||
self.assertEqual(event.name, "exception") | ||
assert event.attributes is not None | ||
self.assertEqual( | ||
event.attributes.get(EXCEPTION_TYPE), | ||
f"{__name__}.UnhandledException", | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this should be checked outside
wrapper