Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,10 @@ def response_hook(span: Span, params: typing.Union[
---
"""

import inspect
import types
import typing
from collections.abc import Awaitable
from timeit import default_timer
from typing import Collection
from urllib.parse import urlparse
Expand Down Expand Up @@ -139,7 +141,10 @@ def response_hook(span: Span, params: typing.Union[

_UrlFilterT = typing.Optional[typing.Callable[[yarl.URL], str]]
_RequestHookT = typing.Optional[
typing.Callable[[Span, aiohttp.TraceRequestStartParams], None]
typing.Callable[
[Span, aiohttp.TraceRequestStartParams],
typing.Union[None, Awaitable[None]],
]
]
_ResponseHookT = typing.Optional[
typing.Callable[
Expand All @@ -150,7 +155,7 @@ def response_hook(span: Span, params: typing.Union[
aiohttp.TraceRequestExceptionParams,
],
],
None,
typing.Union[None, Awaitable[None]],
]
]

Expand Down Expand Up @@ -367,7 +372,9 @@ async def on_request_start(
)

if callable(request_hook):
request_hook(trace_config_ctx.span, params)
call = request_hook(trace_config_ctx.span, params)
if inspect.isawaitable(call):
await call

trace_config_ctx.token = context_api.attach(
trace.set_span_in_context(trace_config_ctx.span)
Expand All @@ -384,7 +391,10 @@ async def on_request_end(
return

if callable(response_hook):
response_hook(trace_config_ctx.span, params)
call = response_hook(trace_config_ctx.span, params)
if inspect.isawaitable(call):
await call

_set_http_status_code_attribute(
trace_config_ctx.span,
params.response.status,
Expand Down Expand Up @@ -414,7 +424,9 @@ async def on_request_exception(
trace_config_ctx.span.record_exception(params.exception)

if callable(response_hook):
response_hook(trace_config_ctx.span, params)
call = response_hook(trace_config_ctx.span, params)
if inspect.isawaitable(call):
await call

_end_trace(trace_config_ctx)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import unittest
import urllib.parse
from http import HTTPStatus
from unittest import mock
from unittest import mock, IsolatedAsyncioTestCase

import aiohttp
import aiohttp.test_utils
Expand Down Expand Up @@ -87,7 +87,35 @@ async def do_request():
return loop.run_until_complete(do_request())


class TestAioHttpIntegration(TestBase):
class HttpRequestMixin:
@staticmethod
def _http_request(
trace_config,
url: str,
method: str = "GET",
status_code: int = HTTPStatus.OK,
request_handler: typing.Callable = None,
**kwargs,
) -> typing.Tuple[str, int]:
"""Helper to start an aiohttp test server and send an actual HTTP request to it."""

async def default_handler(request):
assert "traceparent" in request.headers
return aiohttp.web.Response(status=int(status_code))

async def client_request(server: aiohttp.test_utils.TestServer):
async with aiohttp.test_utils.TestClient(
server, trace_configs=[trace_config]
) as client:
await client.request(
method, url, trace_request_ctx={}, **kwargs
)

handler = request_handler or default_handler
return run_with_test_server(client_request, url, handler)


class TestAioHttpIntegration(TestBase, HttpRequestMixin):
_test_status_codes = (
(HTTPStatus.OK, StatusCode.UNSET),
(HTTPStatus.TEMPORARY_REDIRECT, StatusCode.UNSET),
Expand Down Expand Up @@ -121,32 +149,6 @@ def _assert_metrics(self, num_metrics: int = 1):
self.assertEqual(len(metrics), num_metrics)
return metrics

@staticmethod
def _http_request(
trace_config,
url: str,
method: str = "GET",
status_code: int = HTTPStatus.OK,
request_handler: typing.Callable = None,
**kwargs,
) -> typing.Tuple[str, int]:
"""Helper to start an aiohttp test server and send an actual HTTP request to it."""

async def default_handler(request):
assert "traceparent" in request.headers
return aiohttp.web.Response(status=int(status_code))

async def client_request(server: aiohttp.test_utils.TestServer):
async with aiohttp.test_utils.TestClient(
server, trace_configs=[trace_config]
) as client:
await client.request(
method, url, trace_request_ctx={}, **kwargs
)

handler = request_handler or default_handler
return run_with_test_server(client_request, url, handler)

def test_status_codes(self):
index = 0
for status_code, span_status in self._test_status_codes:
Expand Down Expand Up @@ -804,6 +806,51 @@ async def do_request(url):
self.memory_exporter.clear()


class TestAioHttpIntegrationAsync(TestBase, IsolatedAsyncioTestCase, HttpRequestMixin):
def test_async_hooks(self):
method = "PATCH"
path = "/some/path"
expected = "PATCH - /some/path"

async def request_hook(span: Span, params: aiohttp.TraceRequestStartParams):
span.update_name(f"{params.method} - {params.url.path}")

async def response_hook(
span: Span,
params: typing.Union[
aiohttp.TraceRequestEndParams,
aiohttp.TraceRequestExceptionParams,
],
):
span.set_attribute("response_hook_attr", "value")

host, port = self._http_request(
trace_config=aiohttp_client.create_trace_config(
request_hook=request_hook,
response_hook=response_hook,
),
method=method,
url=path,
status_code=HTTPStatus.OK,
)

for span in self.memory_exporter.get_finished_spans():
self.assertEqual(span.name, expected)
self.assertEqual(
(span.status.status_code, span.status.description),
(StatusCode.UNSET, None),
)
self.assertEqual(span.attributes[HTTP_METHOD], method)
self.assertEqual(
span.attributes[HTTP_URL],
f"http://{host}:{port}{path}",
)
self.assertEqual(span.attributes[HTTP_STATUS_CODE], HTTPStatus.OK)
self.assertIn("response_hook_attr", span.attributes)
self.assertEqual(span.attributes["response_hook_attr"], "value")
self.memory_exporter.clear()


class TestAioHttpClientInstrumentor(TestBase):
URL = "/test-path"

Expand Down