diff --git a/instrumentation/opentelemetry-instrumentation-aiohttp-client/src/opentelemetry/instrumentation/aiohttp_client/__init__.py b/instrumentation/opentelemetry-instrumentation-aiohttp-client/src/opentelemetry/instrumentation/aiohttp_client/__init__.py index c84839deb7..483211959f 100644 --- a/instrumentation/opentelemetry-instrumentation-aiohttp-client/src/opentelemetry/instrumentation/aiohttp_client/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-aiohttp-client/src/opentelemetry/instrumentation/aiohttp_client/__init__.py @@ -88,8 +88,10 @@ def response_hook(span: Span, params: typing.Union[ --- """ +import inspect import types import typing +from collections.abc import Awaitable from timeit import default_timer from typing import Collection from urllib.parse import urlparse @@ -139,7 +141,10 @@ def response_hook(span: Span, params: typing.Union[ _UrlFilterT = typing.Optional[typing.Callable[[yarl.URL], str]] _RequestHookT = typing.Optional[ - typing.Callable[[Span, aiohttp.TraceRequestStartParams], None] + typing.Callable[ + [Span, aiohttp.TraceRequestStartParams], + typing.Union[None, Awaitable[None]], + ] ] _ResponseHookT = typing.Optional[ typing.Callable[ @@ -150,7 +155,7 @@ def response_hook(span: Span, params: typing.Union[ aiohttp.TraceRequestExceptionParams, ], ], - None, + typing.Union[None, Awaitable[None]], ] ] @@ -367,7 +372,9 @@ async def on_request_start( ) if callable(request_hook): - request_hook(trace_config_ctx.span, params) + call = request_hook(trace_config_ctx.span, params) + if inspect.isawaitable(call): + await call trace_config_ctx.token = context_api.attach( trace.set_span_in_context(trace_config_ctx.span) @@ -384,7 +391,10 @@ async def on_request_end( return if callable(response_hook): - response_hook(trace_config_ctx.span, params) + call = response_hook(trace_config_ctx.span, params) + if inspect.isawaitable(call): + await call + _set_http_status_code_attribute( trace_config_ctx.span, params.response.status, @@ -414,7 +424,9 @@ async def on_request_exception( trace_config_ctx.span.record_exception(params.exception) if callable(response_hook): - response_hook(trace_config_ctx.span, params) + call = response_hook(trace_config_ctx.span, params) + if inspect.isawaitable(call): + await call _end_trace(trace_config_ctx) diff --git a/instrumentation/opentelemetry-instrumentation-aiohttp-client/tests/test_aiohttp_client_integration.py b/instrumentation/opentelemetry-instrumentation-aiohttp-client/tests/test_aiohttp_client_integration.py index ec608e6a67..de5ebbc4fd 100644 --- a/instrumentation/opentelemetry-instrumentation-aiohttp-client/tests/test_aiohttp_client_integration.py +++ b/instrumentation/opentelemetry-instrumentation-aiohttp-client/tests/test_aiohttp_client_integration.py @@ -20,7 +20,7 @@ import unittest import urllib.parse from http import HTTPStatus -from unittest import mock +from unittest import mock, IsolatedAsyncioTestCase import aiohttp import aiohttp.test_utils @@ -87,7 +87,35 @@ async def do_request(): return loop.run_until_complete(do_request()) -class TestAioHttpIntegration(TestBase): +class HttpRequestMixin: + @staticmethod + def _http_request( + trace_config, + url: str, + method: str = "GET", + status_code: int = HTTPStatus.OK, + request_handler: typing.Callable = None, + **kwargs, + ) -> typing.Tuple[str, int]: + """Helper to start an aiohttp test server and send an actual HTTP request to it.""" + + async def default_handler(request): + assert "traceparent" in request.headers + return aiohttp.web.Response(status=int(status_code)) + + async def client_request(server: aiohttp.test_utils.TestServer): + async with aiohttp.test_utils.TestClient( + server, trace_configs=[trace_config] + ) as client: + await client.request( + method, url, trace_request_ctx={}, **kwargs + ) + + handler = request_handler or default_handler + return run_with_test_server(client_request, url, handler) + + +class TestAioHttpIntegration(TestBase, HttpRequestMixin): _test_status_codes = ( (HTTPStatus.OK, StatusCode.UNSET), (HTTPStatus.TEMPORARY_REDIRECT, StatusCode.UNSET), @@ -121,32 +149,6 @@ def _assert_metrics(self, num_metrics: int = 1): self.assertEqual(len(metrics), num_metrics) return metrics - @staticmethod - def _http_request( - trace_config, - url: str, - method: str = "GET", - status_code: int = HTTPStatus.OK, - request_handler: typing.Callable = None, - **kwargs, - ) -> typing.Tuple[str, int]: - """Helper to start an aiohttp test server and send an actual HTTP request to it.""" - - async def default_handler(request): - assert "traceparent" in request.headers - return aiohttp.web.Response(status=int(status_code)) - - async def client_request(server: aiohttp.test_utils.TestServer): - async with aiohttp.test_utils.TestClient( - server, trace_configs=[trace_config] - ) as client: - await client.request( - method, url, trace_request_ctx={}, **kwargs - ) - - handler = request_handler or default_handler - return run_with_test_server(client_request, url, handler) - def test_status_codes(self): index = 0 for status_code, span_status in self._test_status_codes: @@ -804,6 +806,51 @@ async def do_request(url): self.memory_exporter.clear() +class TestAioHttpIntegrationAsync(TestBase, IsolatedAsyncioTestCase, HttpRequestMixin): + def test_async_hooks(self): + method = "PATCH" + path = "/some/path" + expected = "PATCH - /some/path" + + async def request_hook(span: Span, params: aiohttp.TraceRequestStartParams): + span.update_name(f"{params.method} - {params.url.path}") + + async def response_hook( + span: Span, + params: typing.Union[ + aiohttp.TraceRequestEndParams, + aiohttp.TraceRequestExceptionParams, + ], + ): + span.set_attribute("response_hook_attr", "value") + + host, port = self._http_request( + trace_config=aiohttp_client.create_trace_config( + request_hook=request_hook, + response_hook=response_hook, + ), + method=method, + url=path, + status_code=HTTPStatus.OK, + ) + + for span in self.memory_exporter.get_finished_spans(): + self.assertEqual(span.name, expected) + self.assertEqual( + (span.status.status_code, span.status.description), + (StatusCode.UNSET, None), + ) + self.assertEqual(span.attributes[HTTP_METHOD], method) + self.assertEqual( + span.attributes[HTTP_URL], + f"http://{host}:{port}{path}", + ) + self.assertEqual(span.attributes[HTTP_STATUS_CODE], HTTPStatus.OK) + self.assertIn("response_hook_attr", span.attributes) + self.assertEqual(span.attributes["response_hook_attr"], "value") + self.memory_exporter.clear() + + class TestAioHttpClientInstrumentor(TestBase): URL = "/test-path"