Skip to content

Commit 2739fbb

Browse files
committed
move ExceptionHandlingMiddleware as the outermost inner middleware
Also improve code documentation and add another test.
1 parent 8fe25f3 commit 2739fbb

File tree

2 files changed

+110
-34
lines changed

2 files changed

+110
-34
lines changed

instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py

Lines changed: 75 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,77 @@ def instrument_app(
291291
)
292292

293293
def build_middleware_stack(self: Starlette) -> ASGIApp:
294-
inner_server_error_middleware: ASGIApp = ( # type: ignore
294+
# Define an additional middleware for exception handling that gets
295+
# added as a regular user middleware.
296+
# Normally, `opentelemetry.trace.use_span` covers the recording of
297+
# exceptions into the active span, but `OpenTelemetryMiddleware`
298+
# ends the span too early before the exception can be recorded.
299+
class ExceptionHandlerMiddleware:
300+
def __init__(self, app):
301+
self.app = app
302+
303+
async def __call__(
304+
self, scope: Scope, receive: Receive, send: Send
305+
) -> None:
306+
try:
307+
await self.app(scope, receive, send)
308+
except Exception as exc: # pylint: disable=broad-exception-caught
309+
span = get_current_span()
310+
span.record_exception(exc)
311+
span.set_status(
312+
Status(
313+
status_code=StatusCode.ERROR,
314+
description=f"{type(exc).__name__}: {exc}",
315+
)
316+
)
317+
raise
318+
319+
# For every possible use case of error handling, exception
320+
# handling, trace availability in exception handlers and
321+
# automatic exception recording to work, we need to make a
322+
# series of wrapping and re-wrapping middlewares.
323+
324+
# First, grab the original middleware stack from Starlette. It
325+
# comprises a stack of
326+
# `ServerErrorMiddleware` -> [user defined middlewares] -> `ExceptionMiddleware`
327+
inner_server_error_middleware: ServerErrorMiddleware = ( # type: ignore
295328
self._original_build_middleware_stack() # type: ignore
296329
)
297330

331+
if not isinstance(
332+
inner_server_error_middleware, ServerErrorMiddleware
333+
):
334+
# Oops, something changed about how Starlette creates middleware stacks
335+
_logger.error(
336+
"Cannot instrument FastAPI as the expected middleware stack has changed"
337+
)
338+
return inner_server_error_middleware
339+
340+
# We take [user defined middlewares] -> `ExceptionHandlerMiddleware`
341+
# out of the outermost `ServerErrorMiddleware` and instead pass
342+
# it to our own `ExceptionHandlerMiddleware`
343+
exception_middleware = ExceptionHandlerMiddleware(
344+
inner_server_error_middleware.app
345+
)
346+
347+
# Now, we create a new `ServerErrorMiddleware` that wraps
348+
# `ExceptionHandlerMiddleware` but otherwise uses the same
349+
# original `handler` and debug setting. The end result is a
350+
# middleware stack that's identical to the original stack except
351+
# all user middlewares are covered by our
352+
# `ExceptionHandlerMiddleware`.
353+
error_middleware = ServerErrorMiddleware(
354+
app=exception_middleware,
355+
handler=inner_server_error_middleware.handler,
356+
debug=inner_server_error_middleware.debug,
357+
)
358+
359+
# Finally, we wrap the stack above in our actual OTEL
360+
# middleware. As a result, an active tracing context exists for
361+
# every use case of user-defined error and exception handlers as
362+
# well as automatic recording of exceptions in active spans.
298363
otel_middleware = OpenTelemetryMiddleware(
299-
inner_server_error_middleware,
364+
error_middleware,
300365
excluded_urls=excluded_urls,
301366
default_span_details=_get_default_span_details,
302367
server_request_hook=server_request_hook,
@@ -311,10 +376,14 @@ def build_middleware_stack(self: Starlette) -> ASGIApp:
311376
exclude_spans=exclude_spans,
312377
)
313378

314-
# Wrap in an outer layer of ServerErrorMiddleware so that any exceptions raised in OpenTelemetryMiddleware
315-
# are handled.
316-
# This should not happen unless there is a bug in OpenTelemetryMiddleware, but if there is we don't want that
317-
# to impact the user's application just because we wrapped the middlewares in this order.
379+
# Ultimately, wrap everything in another default
380+
# `ServerErrorMiddleware` (w/o user handlers) so that any
381+
# exceptions raised in `OpenTelemetryMiddleware` are handled.
382+
#
383+
# This should not happen unless there is a bug in
384+
# OpenTelemetryMiddleware, but if there is we don't want that to
385+
# impact the user's application just because we wrapped the
386+
# middlewares in this order.
318387
return ServerErrorMiddleware(
319388
app=otel_middleware,
320389
)
@@ -327,33 +396,6 @@ def build_middleware_stack(self: Starlette) -> ASGIApp:
327396
app,
328397
)
329398

330-
# Define an additional middleware for exception handling that gets
331-
# added as a regular user middleware.
332-
# Normally, `opentelemetry.trace.use_span` covers the recording of
333-
# exceptions into the active span, but `OpenTelemetryMiddleware`
334-
# ends the span too early before the exception can be recorded.
335-
class ExceptionHandlerMiddleware:
336-
def __init__(self, app):
337-
self.app = app
338-
339-
async def __call__(
340-
self, scope: Scope, receive: Receive, send: Send
341-
) -> None:
342-
try:
343-
await self.app(scope, receive, send)
344-
except Exception as exc: # pylint: disable=broad-exception-caught
345-
span = get_current_span()
346-
span.record_exception(exc)
347-
span.set_status(
348-
Status(
349-
status_code=StatusCode.ERROR,
350-
description=f"{type(exc).__name__}: {exc}",
351-
)
352-
)
353-
raise
354-
355-
app.add_middleware(ExceptionHandlerMiddleware)
356-
357399
app._is_instrumented_by_opentelemetry = True
358400
if app not in _InstrumentedFastAPI._instrumented_fastapi_apps:
359401
_InstrumentedFastAPI._instrumented_fastapi_apps.add(app)

instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1965,7 +1965,7 @@ async def _():
19651965
self.assertEqual(self.executed, 1)
19661966

19671967
def test_exception_span_recording(self):
1968-
"""Exception are always recorded in the active span"""
1968+
"""Exceptions are always recorded in the active span"""
19691969

19701970
@self.app.get("/foobar")
19711971
async def _():
@@ -1994,6 +1994,40 @@ async def _():
19941994
f"{__name__}.UnhandledException",
19951995
)
19961996

1997+
def test_middleware_exceptions(self):
1998+
"""Exceptions from user middlewares are recorded in the active span"""
1999+
2000+
@self.app.get("/foobar")
2001+
async def _():
2002+
return PlainTextResponse("Hello World")
2003+
2004+
@self.app.middleware("http")
2005+
async def _(*_):
2006+
raise UnhandledException("Test Exception")
2007+
2008+
try:
2009+
self.client.get(
2010+
"/foobar",
2011+
)
2012+
except Exception: # pylint: disable=W0718
2013+
pass
2014+
2015+
spans = self.memory_exporter.get_finished_spans()
2016+
2017+
self.assertEqual(len(spans), 3)
2018+
span = spans[2]
2019+
self.assertEqual(span.name, "GET /foobar")
2020+
self.assertEqual(span.attributes.get(HTTP_STATUS_CODE), 500)
2021+
self.assertEqual(span.status.status_code, StatusCode.ERROR)
2022+
self.assertEqual(len(span.events), 1)
2023+
event = span.events[0]
2024+
self.assertEqual(event.name, "exception")
2025+
assert event.attributes is not None
2026+
self.assertEqual(
2027+
event.attributes.get(EXCEPTION_TYPE),
2028+
f"{__name__}.UnhandledException",
2029+
)
2030+
19972031

19982032
class TestFailsafeHooks(TestBase):
19992033
"""Tests to ensure FastAPI instrumentation hooks don't tear through"""

0 commit comments

Comments
 (0)