|
17 | 17 | import unittest |
18 | 18 | from contextlib import ExitStack |
19 | 19 | from timeit import default_timer |
| 20 | +from typing import Any, cast |
20 | 21 | from unittest.mock import Mock, call, patch |
21 | 22 |
|
22 | 23 | import fastapi |
|
46 | 47 | NumberDataPoint, |
47 | 48 | ) |
48 | 49 | from opentelemetry.sdk.resources import Resource |
| 50 | +from opentelemetry.sdk.trace import ReadableSpan |
49 | 51 | from opentelemetry.semconv._incubating.attributes.http_attributes import ( |
50 | 52 | HTTP_FLAVOR, |
51 | 53 | HTTP_HOST, |
@@ -1019,43 +1021,44 @@ async def _(): |
1019 | 1021 |
|
1020 | 1022 |
|
1021 | 1023 | class TestFastAPIManualInstrumentationHooks(TestBaseManualFastAPI): |
1022 | | - _server_request_hook = None |
1023 | | - _client_request_hook = None |
1024 | | - _client_response_hook = None |
1025 | | - |
1026 | | - def server_request_hook(self, span, scope): |
1027 | | - if self._server_request_hook is not None: |
1028 | | - self._server_request_hook(span, scope) |
1029 | | - |
1030 | | - def client_request_hook(self, receive_span, scope, message): |
1031 | | - if self._client_request_hook is not None: |
1032 | | - self._client_request_hook(receive_span, scope, message) |
1033 | | - |
1034 | | - def client_response_hook(self, send_span, scope, message): |
1035 | | - if self._client_response_hook is not None: |
1036 | | - self._client_response_hook(send_span, scope, message) |
1037 | | - |
1038 | | - def test_hooks(self): |
1039 | | - def server_request_hook(span, scope): |
| 1024 | + def _create_app(self): |
| 1025 | + def server_request_hook(span: trace.Span, scope: dict[str, Any]): |
1040 | 1026 | span.update_name("name from server hook") |
1041 | 1027 |
|
1042 | | - def client_request_hook(receive_span, scope, message): |
| 1028 | + def client_request_hook( |
| 1029 | + receive_span: trace.Span, |
| 1030 | + scope: dict[str, Any], |
| 1031 | + message: dict[str, Any], |
| 1032 | + ): |
1043 | 1033 | receive_span.update_name("name from client hook") |
1044 | 1034 | receive_span.set_attribute("attr-from-request-hook", "set") |
1045 | 1035 |
|
1046 | | - def client_response_hook(send_span, scope, message): |
| 1036 | + def client_response_hook( |
| 1037 | + send_span: trace.Span, |
| 1038 | + scope: dict[str, Any], |
| 1039 | + message: dict[str, Any], |
| 1040 | + ): |
1047 | 1041 | send_span.update_name("name from response hook") |
1048 | 1042 | send_span.set_attribute("attr-from-response-hook", "value") |
1049 | 1043 |
|
1050 | | - self._server_request_hook = server_request_hook |
1051 | | - self._client_request_hook = client_request_hook |
1052 | | - self._client_response_hook = client_response_hook |
| 1044 | + self._instrumentor.instrument( |
| 1045 | + server_request_hook=server_request_hook, |
| 1046 | + client_request_hook=client_request_hook, |
| 1047 | + client_response_hook=client_response_hook, |
| 1048 | + ) |
1053 | 1049 |
|
| 1050 | + app = self._create_fastapi_app() |
| 1051 | + |
| 1052 | + return app |
| 1053 | + |
| 1054 | + def test_hooks(self): |
1054 | 1055 | self._client.get("/foobar") |
1055 | | - spans = self.sorted_spans(self.memory_exporter.get_finished_spans()) |
1056 | | - self.assertEqual( |
1057 | | - len(spans), 3 |
1058 | | - ) # 1 server span and 2 response spans (response start and body) |
| 1056 | + |
| 1057 | + spans = cast( |
| 1058 | + list[ReadableSpan], |
| 1059 | + self.sorted_spans(self.memory_exporter.get_finished_spans()), |
| 1060 | + ) |
| 1061 | + self.assertEqual(len(spans), 3) |
1059 | 1062 |
|
1060 | 1063 | server_span = spans[2] |
1061 | 1064 | self.assertEqual(server_span.name, "name from server hook") |
|
0 commit comments