Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#3719](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3719))
- `opentelemetry-instrumentation-httpx`: fix missing metric response attributes when tracing is disabled
([#3615](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3615))
- `opentelemetry-instrumentation-fastapi`: Don't pass bounded server_request_hook when using `FastAPIInstrumentor.instrument()`
([#3701](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3701))

### Added

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def client_response_hook(span: Span, scope: dict[str, Any], message: dict[str, A
import functools
import logging
import types
from typing import Collection, Literal
from typing import Any, Collection, Literal
from weakref import WeakSet as _WeakSet

import fastapi
Expand Down Expand Up @@ -426,30 +426,9 @@ def uninstrument_app(app: fastapi.FastAPI):
def instrumentation_dependencies(self) -> Collection[str]:
return _instruments

def _instrument(self, **kwargs):
def _instrument(self, **kwargs: Any):
self._original_fastapi = fastapi.FastAPI
_InstrumentedFastAPI._tracer_provider = kwargs.get("tracer_provider")
_InstrumentedFastAPI._server_request_hook = kwargs.get(
"server_request_hook"
)
_InstrumentedFastAPI._client_request_hook = kwargs.get(
"client_request_hook"
)
_InstrumentedFastAPI._client_response_hook = kwargs.get(
"client_response_hook"
)
_InstrumentedFastAPI._http_capture_headers_server_request = kwargs.get(
"http_capture_headers_server_request"
)
_InstrumentedFastAPI._http_capture_headers_server_response = (
kwargs.get("http_capture_headers_server_response")
)
_InstrumentedFastAPI._http_capture_headers_sanitize_fields = (
kwargs.get("http_capture_headers_sanitize_fields")
)
_InstrumentedFastAPI._excluded_urls = kwargs.get("excluded_urls")
_InstrumentedFastAPI._meter_provider = kwargs.get("meter_provider")
_InstrumentedFastAPI._exclude_spans = kwargs.get("exclude_spans")
_InstrumentedFastAPI._instrument_kwargs = kwargs
fastapi.FastAPI = _InstrumentedFastAPI

def _uninstrument(self, **kwargs):
Expand All @@ -464,35 +443,16 @@ def _uninstrument(self, **kwargs):


class _InstrumentedFastAPI(fastapi.FastAPI):
_tracer_provider = None
_meter_provider = None
_excluded_urls = None
_server_request_hook: ServerRequestHook = None
_client_request_hook: ClientRequestHook = None
_client_response_hook: ClientResponseHook = None
_http_capture_headers_server_request: list[str] | None = None
_http_capture_headers_server_response: list[str] | None = None
_http_capture_headers_sanitize_fields: list[str] | None = None
_exclude_spans: list[Literal["receive", "send"]] | None = None
_instrument_kwargs: dict[str, Any] = {}

# Track instrumented app instances using weak references to avoid GC leaks
_instrumented_fastapi_apps = _WeakSet()
_instrumented_fastapi_apps: _WeakSet[fastapi.FastAPI] = _WeakSet()
_sem_conv_opt_in_mode = _StabilityMode.DEFAULT

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
FastAPIInstrumentor.instrument_app(
self,
server_request_hook=self._server_request_hook,
client_request_hook=self._client_request_hook,
client_response_hook=self._client_response_hook,
tracer_provider=self._tracer_provider,
meter_provider=self._meter_provider,
excluded_urls=self._excluded_urls,
http_capture_headers_server_request=self._http_capture_headers_server_request,
http_capture_headers_server_response=self._http_capture_headers_server_response,
http_capture_headers_sanitize_fields=self._http_capture_headers_sanitize_fields,
exclude_spans=self._exclude_spans,
self, **_InstrumentedFastAPI._instrument_kwargs
)
_InstrumentedFastAPI._instrumented_fastapi_apps.add(self)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import weakref as _weakref
from contextlib import ExitStack
from timeit import default_timer
from typing import Any, cast
from unittest.mock import Mock, call, patch

import fastapi
Expand Down Expand Up @@ -48,6 +49,7 @@
NumberDataPoint,
)
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.semconv._incubating.attributes.http_attributes import (
HTTP_FLAVOR,
HTTP_HOST,
Expand Down Expand Up @@ -1045,44 +1047,45 @@ async def _():
return app


class TestFastAPIManualInstrumentationHooks(TestBaseManualFastAPI):
_server_request_hook = None
_client_request_hook = None
_client_response_hook = None

def server_request_hook(self, span, scope):
if self._server_request_hook is not None:
self._server_request_hook(span, scope)

def client_request_hook(self, receive_span, scope, message):
if self._client_request_hook is not None:
self._client_request_hook(receive_span, scope, message)

def client_response_hook(self, send_span, scope, message):
if self._client_response_hook is not None:
self._client_response_hook(send_span, scope, message)

def test_hooks(self):
def server_request_hook(span, scope):
class TestFastAPIManualInstrumentationHooks(TestBaseFastAPI):
def _create_app(self):
def server_request_hook(span: trace.Span, scope: dict[str, Any]):
span.update_name("name from server hook")

def client_request_hook(receive_span, scope, message):
def client_request_hook(
receive_span: trace.Span,
scope: dict[str, Any],
message: dict[str, Any],
):
receive_span.update_name("name from client hook")
receive_span.set_attribute("attr-from-request-hook", "set")

def client_response_hook(send_span, scope, message):
def client_response_hook(
send_span: trace.Span,
scope: dict[str, Any],
message: dict[str, Any],
):
send_span.update_name("name from response hook")
send_span.set_attribute("attr-from-response-hook", "value")

self._server_request_hook = server_request_hook
self._client_request_hook = client_request_hook
self._client_response_hook = client_response_hook
self._instrumentor.instrument(
server_request_hook=server_request_hook,
client_request_hook=client_request_hook,
client_response_hook=client_response_hook,
)

app = self._create_fastapi_app()

return app

def test_hooks(self):
self._client.get("/foobar")
spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
self.assertEqual(
len(spans), 3
) # 1 server span and 2 response spans (response start and body)

spans = cast(
list[ReadableSpan],
self.sorted_spans(self.memory_exporter.get_finished_spans()),
)
self.assertEqual(len(spans), 3)

server_span = spans[2]
self.assertEqual(server_span.name, "name from server hook")
Expand Down