Skip to content

Commit b2c3c4e

Browse files
Kludexemdneto
andauthored
Don't pass bounded server_request_hook when using FastAPIInstrumentor.instrument(server_request_hook=...) (#3701)
* Don't pass bounded `server_request_hook` when using `FastAPIInstrumentor.instrument(server_request_hook=...)` * try now * changelog Signed-off-by: emdneto <[email protected]> --------- Signed-off-by: emdneto <[email protected]> Co-authored-by: emdneto <[email protected]>
1 parent edb34e6 commit b2c3c4e

File tree

3 files changed

+40
-75
lines changed

3 files changed

+40
-75
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3131
([#3719](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3719))
3232
- `opentelemetry-instrumentation-httpx`: fix missing metric response attributes when tracing is disabled
3333
([#3615](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3615))
34+
- `opentelemetry-instrumentation-fastapi`: Don't pass bounded server_request_hook when using `FastAPIInstrumentor.instrument()`
35+
([#3701](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3701))
3436

3537
### Added
3638

instrumentation/opentelemetry-instrumentation-fastapi/src/opentelemetry/instrumentation/fastapi/__init__.py

Lines changed: 7 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
185185
import functools
186186
import logging
187187
import types
188-
from typing import Collection, Literal
188+
from typing import Any, Collection, Literal
189189
from weakref import WeakSet as _WeakSet
190190

191191
import fastapi
@@ -426,30 +426,9 @@ def uninstrument_app(app: fastapi.FastAPI):
426426
def instrumentation_dependencies(self) -> Collection[str]:
427427
return _instruments
428428

429-
def _instrument(self, **kwargs):
429+
def _instrument(self, **kwargs: Any):
430430
self._original_fastapi = fastapi.FastAPI
431-
_InstrumentedFastAPI._tracer_provider = kwargs.get("tracer_provider")
432-
_InstrumentedFastAPI._server_request_hook = kwargs.get(
433-
"server_request_hook"
434-
)
435-
_InstrumentedFastAPI._client_request_hook = kwargs.get(
436-
"client_request_hook"
437-
)
438-
_InstrumentedFastAPI._client_response_hook = kwargs.get(
439-
"client_response_hook"
440-
)
441-
_InstrumentedFastAPI._http_capture_headers_server_request = kwargs.get(
442-
"http_capture_headers_server_request"
443-
)
444-
_InstrumentedFastAPI._http_capture_headers_server_response = (
445-
kwargs.get("http_capture_headers_server_response")
446-
)
447-
_InstrumentedFastAPI._http_capture_headers_sanitize_fields = (
448-
kwargs.get("http_capture_headers_sanitize_fields")
449-
)
450-
_InstrumentedFastAPI._excluded_urls = kwargs.get("excluded_urls")
451-
_InstrumentedFastAPI._meter_provider = kwargs.get("meter_provider")
452-
_InstrumentedFastAPI._exclude_spans = kwargs.get("exclude_spans")
431+
_InstrumentedFastAPI._instrument_kwargs = kwargs
453432
fastapi.FastAPI = _InstrumentedFastAPI
454433

455434
def _uninstrument(self, **kwargs):
@@ -464,35 +443,16 @@ def _uninstrument(self, **kwargs):
464443

465444

466445
class _InstrumentedFastAPI(fastapi.FastAPI):
467-
_tracer_provider = None
468-
_meter_provider = None
469-
_excluded_urls = None
470-
_server_request_hook: ServerRequestHook = None
471-
_client_request_hook: ClientRequestHook = None
472-
_client_response_hook: ClientResponseHook = None
473-
_http_capture_headers_server_request: list[str] | None = None
474-
_http_capture_headers_server_response: list[str] | None = None
475-
_http_capture_headers_sanitize_fields: list[str] | None = None
476-
_exclude_spans: list[Literal["receive", "send"]] | None = None
446+
_instrument_kwargs: dict[str, Any] = {}
477447

478448
# Track instrumented app instances using weak references to avoid GC leaks
479-
_instrumented_fastapi_apps = _WeakSet()
449+
_instrumented_fastapi_apps: _WeakSet[fastapi.FastAPI] = _WeakSet()
480450
_sem_conv_opt_in_mode = _StabilityMode.DEFAULT
481451

482-
def __init__(self, *args, **kwargs):
452+
def __init__(self, *args: Any, **kwargs: Any):
483453
super().__init__(*args, **kwargs)
484454
FastAPIInstrumentor.instrument_app(
485-
self,
486-
server_request_hook=self._server_request_hook,
487-
client_request_hook=self._client_request_hook,
488-
client_response_hook=self._client_response_hook,
489-
tracer_provider=self._tracer_provider,
490-
meter_provider=self._meter_provider,
491-
excluded_urls=self._excluded_urls,
492-
http_capture_headers_server_request=self._http_capture_headers_server_request,
493-
http_capture_headers_server_response=self._http_capture_headers_server_response,
494-
http_capture_headers_sanitize_fields=self._http_capture_headers_sanitize_fields,
495-
exclude_spans=self._exclude_spans,
455+
self, **_InstrumentedFastAPI._instrument_kwargs
496456
)
497457
_InstrumentedFastAPI._instrumented_fastapi_apps.add(self)
498458

instrumentation/opentelemetry-instrumentation-fastapi/tests/test_fastapi_instrumentation.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import weakref as _weakref
2121
from contextlib import ExitStack
2222
from timeit import default_timer
23+
from typing import Any, cast
2324
from unittest.mock import Mock, call, patch
2425

2526
import fastapi
@@ -48,6 +49,7 @@
4849
NumberDataPoint,
4950
)
5051
from opentelemetry.sdk.resources import Resource
52+
from opentelemetry.sdk.trace import ReadableSpan
5153
from opentelemetry.semconv._incubating.attributes.http_attributes import (
5254
HTTP_FLAVOR,
5355
HTTP_HOST,
@@ -1045,44 +1047,45 @@ async def _():
10451047
return app
10461048

10471049

1048-
class TestFastAPIManualInstrumentationHooks(TestBaseManualFastAPI):
1049-
_server_request_hook = None
1050-
_client_request_hook = None
1051-
_client_response_hook = None
1052-
1053-
def server_request_hook(self, span, scope):
1054-
if self._server_request_hook is not None:
1055-
self._server_request_hook(span, scope)
1056-
1057-
def client_request_hook(self, receive_span, scope, message):
1058-
if self._client_request_hook is not None:
1059-
self._client_request_hook(receive_span, scope, message)
1060-
1061-
def client_response_hook(self, send_span, scope, message):
1062-
if self._client_response_hook is not None:
1063-
self._client_response_hook(send_span, scope, message)
1064-
1065-
def test_hooks(self):
1066-
def server_request_hook(span, scope):
1050+
class TestFastAPIManualInstrumentationHooks(TestBaseFastAPI):
1051+
def _create_app(self):
1052+
def server_request_hook(span: trace.Span, scope: dict[str, Any]):
10671053
span.update_name("name from server hook")
10681054

1069-
def client_request_hook(receive_span, scope, message):
1055+
def client_request_hook(
1056+
receive_span: trace.Span,
1057+
scope: dict[str, Any],
1058+
message: dict[str, Any],
1059+
):
10701060
receive_span.update_name("name from client hook")
10711061
receive_span.set_attribute("attr-from-request-hook", "set")
10721062

1073-
def client_response_hook(send_span, scope, message):
1063+
def client_response_hook(
1064+
send_span: trace.Span,
1065+
scope: dict[str, Any],
1066+
message: dict[str, Any],
1067+
):
10741068
send_span.update_name("name from response hook")
10751069
send_span.set_attribute("attr-from-response-hook", "value")
10761070

1077-
self._server_request_hook = server_request_hook
1078-
self._client_request_hook = client_request_hook
1079-
self._client_response_hook = client_response_hook
1071+
self._instrumentor.instrument(
1072+
server_request_hook=server_request_hook,
1073+
client_request_hook=client_request_hook,
1074+
client_response_hook=client_response_hook,
1075+
)
10801076

1077+
app = self._create_fastapi_app()
1078+
1079+
return app
1080+
1081+
def test_hooks(self):
10811082
self._client.get("/foobar")
1082-
spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
1083-
self.assertEqual(
1084-
len(spans), 3
1085-
) # 1 server span and 2 response spans (response start and body)
1083+
1084+
spans = cast(
1085+
list[ReadableSpan],
1086+
self.sorted_spans(self.memory_exporter.get_finished_spans()),
1087+
)
1088+
self.assertEqual(len(spans), 3)
10861089

10871090
server_span = spans[2]
10881091
self.assertEqual(server_span.name, "name from server hook")

0 commit comments

Comments
 (0)