diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f252e8290..51c963bd9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Fixed + +- `opentelemetry-instrumentation-fastapi`: Implemented failsafe middleware stack. + ([#3664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3664)) + ## Version 1.36.0/0.57b0 (2025-07-29) ### Fixed diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 7e72dbf11f..e6eb9c8dbb 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -268,7 +268,7 @@ def client_response_hook(span: Span, scope: Scope, message: dict[str, Any]): HTTP_SERVER_REQUEST_DURATION, ) from opentelemetry.semconv.trace import SpanAttributes -from opentelemetry.trace import set_span_in_context +from opentelemetry.trace import Span, set_span_in_context from opentelemetry.util.http import ( OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS, OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SERVER_REQUEST, @@ -646,9 +646,21 @@ def __init__( self.default_span_details = ( default_span_details or get_default_span_details ) - self.server_request_hook = server_request_hook - self.client_request_hook = client_request_hook - self.client_response_hook = client_response_hook + + def failsafe(func): + @wraps(func) + def wrapper(span: Span, *args, **kwargs): + if func is not None: + try: + func(span, *args, **kwargs) + except Exception as exc: # pylint: disable=broad-exception-caught + span.record_exception(exc) + + return wrapper + + self.server_request_hook = failsafe(server_request_hook) + self.client_request_hook = failsafe(client_request_hook) + self.client_response_hook = failsafe(client_response_hook) self.content_length_header = None self._sem_conv_opt_in_mode = sem_conv_opt_in_mode diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py index 8ba83985c6..9ebe2b27c2 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py @@ -191,7 +191,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A from starlette.applications import Starlette from starlette.middleware.errors import ServerErrorMiddleware from starlette.routing import Match -from starlette.types import ASGIApp +from starlette.types import ASGIApp, Receive, Scope, Send from opentelemetry.instrumentation._semconv import ( _get_schema_url, @@ -210,7 +210,8 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A from opentelemetry.instrumentation.instrumentor import BaseInstrumentor from opentelemetry.metrics import MeterProvider, get_meter from opentelemetry.semconv.attributes.http_attributes import HTTP_ROUTE -from opentelemetry.trace import TracerProvider, get_tracer +from opentelemetry.trace import TracerProvider, get_current_span, get_tracer +from opentelemetry.trace.status import Status, StatusCode from opentelemetry.util.http import ( get_excluded_urls, parse_excluded_urls, @@ -242,7 +243,7 @@ def instrument_app( http_capture_headers_server_response: list[str] | None = None, http_capture_headers_sanitize_fields: list[str] | None = None, exclude_spans: list[Literal["receive", "send"]] | None = None, - ): + ): # pylint: disable=too-many-locals """Instrument an uninstrumented FastAPI application. Args: @@ -289,15 +290,16 @@ def instrument_app( schema_url=_get_schema_url(sem_conv_opt_in_mode), ) - # Instead of using `app.add_middleware` we monkey patch `build_middleware_stack` to insert our middleware - # as the outermost middleware. - # Otherwise `OpenTelemetryMiddleware` would have unhandled exceptions tearing through it and would not be able - # to faithfully record what is returned to the client since it technically cannot know what `ServerErrorMiddleware` is going to do. - + # In order to make traces available at any stage of the request + # processing - including exception handling - we wrap ourselves as + # the new, outermost middleware. However in order to prevent + # exceptions from user-provided hooks of tearing through, we wrap + # them to return without failure unconditionally. def build_middleware_stack(self: Starlette) -> ASGIApp: inner_server_error_middleware: ASGIApp = ( # type: ignore self._original_build_middleware_stack() # type: ignore ) + otel_middleware = OpenTelemetryMiddleware( inner_server_error_middleware, excluded_urls=excluded_urls, @@ -313,6 +315,7 @@ def build_middleware_stack(self: Starlette) -> ASGIApp: http_capture_headers_sanitize_fields=http_capture_headers_sanitize_fields, exclude_spans=exclude_spans, ) + # Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware # are handled. # This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that @@ -339,6 +342,28 @@ def build_middleware_stack(self: Starlette) -> ASGIApp: app, ) + class ExceptionHandlerMiddleware: + def __init__(self, app): + self.app = app + + async def __call__( + self, scope: Scope, receive: Receive, send: Send + ) -> None: + try: + await self.app(scope, receive, send) + except Exception as exc: # pylint: disable=broad-exception-caught + span = get_current_span() + span.record_exception(exc) + span.set_status( + Status( + status_code=StatusCode.ERROR, + description=f"{type(exc).__name__}: {exc}", + ) + ) + raise + + app.add_middleware(ExceptionHandlerMiddleware) + app._is_instrumented_by_opentelemetry = True if app not in _InstrumentedFastAPI._instrumented_fastapi_apps: _InstrumentedFastAPI._instrumented_fastapi_apps.add(app) diff --git a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py index 523c165f85..0dec9c61fe 100644 --- a/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py +++ b/instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py @@ -59,6 +59,9 @@ from opentelemetry.semconv._incubating.attributes.net_attributes import ( NET_HOST_PORT, ) +from opentelemetry.semconv.attributes.exception_attributes import ( + EXCEPTION_TYPE, +) from opentelemetry.semconv.attributes.http_attributes import ( HTTP_REQUEST_METHOD, HTTP_RESPONSE_STATUS_CODE, @@ -70,6 +73,7 @@ from opentelemetry.semconv.attributes.url_attributes import URL_SCHEME from opentelemetry.test.globals_test import reset_trace_globals from opentelemetry.test.test_base import TestBase +from opentelemetry.trace.status import StatusCode from opentelemetry.util._importlib_metadata import entry_points from opentelemetry.util.http import ( OTEL_INSTRUMENTATION_HTTP_CAPTURE_HEADERS_SANITIZE_FIELDS, @@ -1877,3 +1881,154 @@ def test_custom_header_not_present_in_non_recording_span(self): self.assertEqual(200, resp.status_code) span_list = self.memory_exporter.get_finished_spans() self.assertEqual(len(span_list), 0) + + +class TestTraceableExceptionHandling(TestBase): + """Tests to ensure FastAPI exception handlers are only executed once and with a valid context""" + + def setUp(self): + super().setUp() + + self.app = fastapi.FastAPI() + + otel_fastapi.FastAPIInstrumentor().instrument_app(self.app) + self.client = TestClient(self.app) + self.tracer = self.tracer_provider.get_tracer(__name__) + self.executed = 0 + self.request_trace_id = None + self.error_trace_id = None + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) + + def test_error_handler_context(self): + """OTEL tracing contexts must be available during error handler execution""" + + @self.app.exception_handler(Exception) + async def _(*_): + self.error_trace_id = ( + trace.get_current_span().get_span_context().trace_id + ) + + @self.app.get("/foobar") + async def _(): + self.request_trace_id = ( + trace.get_current_span().get_span_context().trace_id + ) + raise UnhandledException("Test Exception") + + try: + self.client.get( + "/foobar", + ) + except Exception: # pylint: disable=W0718 + pass + + self.assertIsNotNone(self.request_trace_id) + self.assertEqual(self.request_trace_id, self.error_trace_id) + + def test_error_handler_side_effects(self): + """FastAPI default exception handlers (aka error handlers) must be executed exactly once per exception""" + + @self.app.exception_handler(Exception) + async def _(*_): + self.executed += 1 + + @self.app.get("/foobar") + async def _(): + raise UnhandledException("Test Exception") + + try: + self.client.get( + "/foobar", + ) + except Exception: # pylint: disable=W0718 + pass + + self.assertEqual(self.executed, 1) + + def test_exception_span_recording(self): + """Exception are always recorded in the active span""" + + @self.app.get("/foobar") + async def _(): + raise UnhandledException("Test Exception") + + try: + self.client.get( + "/foobar", + ) + except Exception: # pylint: disable=W0718 + pass + + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 3) + span = spans[2] + self.assertEqual(span.status.status_code, StatusCode.ERROR) + self.assertEqual(len(span.events), 1) + event = span.events[0] + self.assertEqual(event.name, "exception") + assert event.attributes is not None + self.assertEqual( + event.attributes.get(EXCEPTION_TYPE), + f"{__name__}.UnhandledException", + ) + + +class TestFailsafeHooks(TestBase): + """Tests to ensure FastAPI instrumentation hooks don't tear through""" + + def setUp(self): + super().setUp() + + self.app = fastapi.FastAPI() + + @self.app.get("/foobar") + async def _(): + return {"message": "Hello World"} + + def failing_hook(*_): + raise UnhandledException("Hook Exception") + + otel_fastapi.FastAPIInstrumentor().instrument_app( + self.app, + server_request_hook=failing_hook, + client_request_hook=failing_hook, + client_response_hook=failing_hook, + ) + self.client = TestClient(self.app) + + def tearDown(self) -> None: + super().tearDown() + with self.disable_logging(): + otel_fastapi.FastAPIInstrumentor().uninstrument_app(self.app) + + def test_failsafe_hooks(self): + """Crashing hooks must not tear through""" + resp = self.client.get( + "/foobar", + ) + + self.assertEqual(200, resp.status_code) + + def test_failsafe_error_recording(self): + """Failing hooks must record the exception on the span""" + self.client.get( + "/foobar", + ) + + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 3) + span = spans[0] + self.assertEqual(len(span.events), 1) + event = span.events[0] + self.assertEqual(event.name, "exception") + assert event.attributes is not None + self.assertEqual( + event.attributes.get(EXCEPTION_TYPE), + f"{__name__}.UnhandledException", + )