diff --git a/packages/service-library/src/servicelib/aiohttp/tracing.py b/packages/service-library/src/servicelib/aiohttp/tracing.py index 1e41aab20f0..d4e2d3610cd 100644 --- a/packages/service-library/src/servicelib/aiohttp/tracing.py +++ b/packages/service-library/src/servicelib/aiohttp/tracing.py @@ -15,9 +15,10 @@ middleware as aiohttp_server_opentelemetry_middleware, # pylint:disable=no-name-in-module ) from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import SpanProcessor, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from servicelib.logging_utils import log_context +from servicelib.tracing import get_trace_id_header from settings_library.tracing import TracingSettings from yarl import URL @@ -51,10 +52,19 @@ HAS_AIO_PIKA = False +def _create_span_processor(tracing_destination: str) -> SpanProcessor: + otlp_exporter = OTLPSpanExporterHTTP( + endpoint=tracing_destination, + ) + span_processor = BatchSpanProcessor(otlp_exporter) + return span_processor + + def _startup( app: web.Application, tracing_settings: TracingSettings, service_name: str, + add_response_trace_id_header: bool = False, ) -> None: """ Sets up this service for a distributed tracing system (opentelemetry) @@ -90,12 +100,8 @@ def _startup( tracing_destination, ) - otlp_exporter = OTLPSpanExporterHTTP( - endpoint=tracing_destination, - ) - # Add the span processor to the tracer provider - tracer_provider.add_span_processor(BatchSpanProcessor(otlp_exporter)) # type: ignore[attr-defined] # https://github.com/open-telemetry/opentelemetry-python/issues/3713 + tracer_provider.add_span_processor(_create_span_processor(tracing_destination)) # type: ignore[attr-defined] # https://github.com/open-telemetry/opentelemetry-python/issues/3713 # Instrument aiohttp server # Explanation for custom middleware call DK 10/2024: # OpenTelemetry Aiohttp autoinstrumentation is meant to be used by only calling `AioHttpServerInstrumentor().instrument()` @@ -106,6 +112,8 @@ def _startup( # # Since the code that is provided (monkeypatched) in the __init__ that the opentelemetry-autoinstrumentation-library provides is only 4 lines, # just adding a middleware, we are free to simply execute this "missed call" [since we can't call the monkeypatch'ed __init__()] in this following line: + if add_response_trace_id_header: + app.middlewares.insert(0, response_trace_id_header_middleware) app.middlewares.insert(0, aiohttp_server_opentelemetry_middleware) # Code of the aiohttp server instrumentation: github.com/open-telemetry/opentelemetry-python-contrib/blob/eccb05c808a7d797ef5b6ecefed3590664426fbf/instrumentation/opentelemetry-instrumentation-aiohttp-server/src/opentelemetry/instrumentation/aiohttp_server/__init__.py#L246 # For reference, the above statement was written for: @@ -146,6 +154,21 @@ def _startup( AioPikaInstrumentor().instrument() +@web.middleware +async def response_trace_id_header_middleware(request: web.Request, handler): + headers = get_trace_id_header() + + try: + response = await handler(request) + except web.HTTPException as exc: + if headers: + exc.headers.update(headers) + raise exc + if headers: + response.headers.update(headers) + return response + + def _shutdown() -> None: """Uninstruments all opentelemetry instrumentors that were instrumented.""" try: @@ -175,9 +198,18 @@ def _shutdown() -> None: def get_tracing_lifespan( - app: web.Application, tracing_settings: TracingSettings, service_name: str + *, + app: web.Application, + tracing_settings: TracingSettings, + service_name: str, + add_response_trace_id_header: bool = False, ) -> Callable[[web.Application], AsyncIterator]: - _startup(app=app, tracing_settings=tracing_settings, service_name=service_name) + _startup( + app=app, + tracing_settings=tracing_settings, + service_name=service_name, + add_response_trace_id_header=add_response_trace_id_header, + ) async def tracing_lifespan(app: web.Application): assert app # nosec diff --git a/packages/service-library/src/servicelib/fastapi/tracing.py b/packages/service-library/src/servicelib/fastapi/tracing.py index 5b2cba5434d..06224618a63 100644 --- a/packages/service-library/src/servicelib/fastapi/tracing.py +++ b/packages/service-library/src/servicelib/fastapi/tracing.py @@ -3,7 +3,7 @@ import logging from collections.abc import AsyncIterator -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi_lifespan_manager import State from httpx import AsyncClient, Client from opentelemetry import trace @@ -13,10 +13,12 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import SpanProcessor, TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from servicelib.logging_utils import log_context +from servicelib.tracing import get_trace_id_header from settings_library.tracing import TracingSettings +from starlette.middleware.base import BaseHTTPMiddleware from yarl import URL _logger = logging.getLogger(__name__) @@ -70,6 +72,11 @@ HAS_AIOPIKA_INSTRUMENTOR = False +def _create_span_processor(tracing_destination: str) -> SpanProcessor: + otlp_exporter = OTLPSpanExporterHTTP(endpoint=tracing_destination) + return BatchSpanProcessor(otlp_exporter) + + def _startup(tracing_settings: TracingSettings, service_name: str) -> None: if ( not tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_ENDPOINT @@ -96,10 +103,10 @@ def _startup(tracing_settings: TracingSettings, service_name: str) -> None: service_name, tracing_destination, ) - # Configure OTLP exporter to send spans to the collector - otlp_exporter = OTLPSpanExporterHTTP(endpoint=tracing_destination) - span_processor = BatchSpanProcessor(otlp_exporter) - global_tracer_provider.add_span_processor(span_processor) + # Add the span processor to the tracer provider + global_tracer_provider.add_span_processor( + _create_span_processor(tracing_destination) + ) if HAS_AIOPG: with log_context( @@ -180,7 +187,11 @@ def _shutdown() -> None: _logger.exception("Failed to uninstrument RequestsInstrumentor") -def initialize_fastapi_app_tracing(app: FastAPI): +def initialize_fastapi_app_tracing( + app: FastAPI, *, add_response_trace_id_header: bool = False +): + if add_response_trace_id_header: + app.add_middleware(ResponseTraceIdHeaderMiddleware) FastAPIInstrumentor.instrument_app(app) @@ -216,3 +227,13 @@ async def tracing_instrumentation_lifespan( _shutdown() return tracing_instrumentation_lifespan + + +class ResponseTraceIdHeaderMiddleware(BaseHTTPMiddleware): + + async def dispatch(self, request: Request, call_next): + response = await call_next(request) + trace_id_header = get_trace_id_header() + if trace_id_header: + response.headers.update(trace_id_header) + return response diff --git a/packages/service-library/src/servicelib/tracing.py b/packages/service-library/src/servicelib/tracing.py index e1b3b348a72..88ed5d4c30c 100644 --- a/packages/service-library/src/servicelib/tracing.py +++ b/packages/service-library/src/servicelib/tracing.py @@ -8,6 +8,8 @@ TracingContext: TypeAlias = otcontext.Context | None +_OSPARC_TRACE_ID_HEADER = "x-osparc-trace-id" + def _is_tracing() -> bool: return trace.get_current_span().is_recording() @@ -34,3 +36,15 @@ def use_tracing_context(context: TracingContext): def setup_log_tracing(tracing_settings: TracingSettings): _ = tracing_settings LoggingInstrumentor().instrument(set_logging_format=False) + + +def get_trace_id_header() -> dict[str, str] | None: + """Generates a dictionary containing the trace ID header if tracing is active.""" + span = trace.get_current_span() + if span.is_recording(): + trace_id = span.get_span_context().trace_id + trace_id_hex = format( + trace_id, "032x" + ) # Convert trace_id to 32-character hex string + return {_OSPARC_TRACE_ID_HEADER: trace_id_hex} + return None diff --git a/packages/service-library/tests/aiohttp/conftest.py b/packages/service-library/tests/aiohttp/conftest.py index 1891ee17d15..f2055a80b0c 100644 --- a/packages/service-library/tests/aiohttp/conftest.py +++ b/packages/service-library/tests/aiohttp/conftest.py @@ -1,2 +1,20 @@ # pylint: disable=redefined-outer-name # pylint: disable=unused-argument + + +from collections.abc import Iterator + +import pytest +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter +from pytest_mock import MockerFixture + + +@pytest.fixture +def mock_otel_collector(mocker: MockerFixture) -> Iterator[InMemorySpanExporter]: + memory_exporter = InMemorySpanExporter() + span_processor = SimpleSpanProcessor(memory_exporter) + mocker.patch( + "servicelib.aiohttp.tracing._create_span_processor", return_value=span_processor + ) + yield memory_exporter diff --git a/packages/service-library/tests/aiohttp/test_tracing.py b/packages/service-library/tests/aiohttp/test_tracing.py index 2621751f344..8e297427923 100644 --- a/packages/service-library/tests/aiohttp/test_tracing.py +++ b/packages/service-library/tests/aiohttp/test_tracing.py @@ -4,14 +4,18 @@ import importlib from collections.abc import Callable, Iterator +from functools import partial from typing import Any import pip import pytest from aiohttp import web from aiohttp.test_utils import TestClient +from opentelemetry import trace +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from pydantic import ValidationError from servicelib.aiohttp.tracing import get_tracing_lifespan +from servicelib.tracing import _OSPARC_TRACE_ID_HEADER from settings_library.tracing import TracingSettings @@ -60,7 +64,7 @@ async def test_valid_tracing_settings( service_name = "simcore_service_webserver" tracing_settings = TracingSettings() async for _ in get_tracing_lifespan( - app, service_name=service_name, tracing_settings=tracing_settings + app=app, service_name=service_name, tracing_settings=tracing_settings )(app): pass @@ -137,14 +141,64 @@ async def test_tracing_setup_package_detection( service_name = "simcore_service_webserver" tracing_settings = TracingSettings() async for _ in get_tracing_lifespan( - app, + app=app, service_name=service_name, tracing_settings=tracing_settings, )(app): # idempotency async for _ in get_tracing_lifespan( - app, + app=app, service_name=service_name, tracing_settings=tracing_settings, )(app): pass + + +@pytest.mark.parametrize( + "tracing_settings_in", + [ + ("http://opentelemetry-collector", 4318), + ], + indirect=True, +) +@pytest.mark.parametrize( + "server_response", [web.Response(text="Hello, world!"), web.HTTPNotFound()] +) +async def test_trace_id_in_response_header( + mock_otel_collector: InMemorySpanExporter, + aiohttp_client: Callable, + set_and_clean_settings_env_vars: Callable, + tracing_settings_in, + uninstrument_opentelemetry: Iterator[None], + server_response: web.Response | web.HTTPException, +) -> None: + app = web.Application() + service_name = "simcore_service_webserver" + tracing_settings = TracingSettings() + + async def handler(handler_data: dict, request: web.Request) -> web.Response: + current_span = trace.get_current_span() + handler_data[_OSPARC_TRACE_ID_HEADER] = format( + current_span.get_span_context().trace_id, "032x" + ) + if isinstance(server_response, web.HTTPException): + raise server_response + return server_response + + handler_data = dict() + app.router.add_get("/", partial(handler, handler_data)) + + async for _ in get_tracing_lifespan( + app=app, + service_name=service_name, + tracing_settings=tracing_settings, + add_response_trace_id_header=True, + )(app): + client = await aiohttp_client(app) + response = await client.get("/") + assert _OSPARC_TRACE_ID_HEADER in response.headers + trace_id = response.headers[_OSPARC_TRACE_ID_HEADER] + assert len(trace_id) == 32 # Ensure trace ID is a 32-character hex string + assert ( + trace_id == handler_data[_OSPARC_TRACE_ID_HEADER] + ) # Ensure trace IDs match diff --git a/packages/service-library/tests/fastapi/conftest.py b/packages/service-library/tests/fastapi/conftest.py index f8811ca04f5..66db13b8664 100644 --- a/packages/service-library/tests/fastapi/conftest.py +++ b/packages/service-library/tests/fastapi/conftest.py @@ -3,7 +3,7 @@ # pylint: disable=unused-variable import socket -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncIterator, Callable, Iterator from typing import cast import arrow @@ -11,7 +11,10 @@ from fastapi import APIRouter, FastAPI from fastapi.params import Query from httpx import ASGITransport, AsyncClient +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from pydantic.types import PositiveFloat +from pytest_mock import MockerFixture @pytest.fixture @@ -55,3 +58,13 @@ def go() -> int: return cast(int, s.getsockname()[1]) return go + + +@pytest.fixture +def mock_otel_collector(mocker: MockerFixture) -> Iterator[InMemorySpanExporter]: + memory_exporter = InMemorySpanExporter() + span_processor = SimpleSpanProcessor(memory_exporter) + mocker.patch( + "servicelib.fastapi.tracing._create_span_processor", return_value=span_processor + ) + yield memory_exporter diff --git a/packages/service-library/tests/fastapi/test_tracing.py b/packages/service-library/tests/fastapi/test_tracing.py index 8e58dfd75dd..16becc3a1b6 100644 --- a/packages/service-library/tests/fastapi/test_tracing.py +++ b/packages/service-library/tests/fastapi/test_tracing.py @@ -5,15 +5,23 @@ import random import string from collections.abc import Callable, Iterator +from functools import partial from typing import Any import pip import pytest from fastapi import FastAPI +from fastapi.exceptions import HTTPException +from fastapi.responses import PlainTextResponse +from fastapi.testclient import TestClient +from opentelemetry import trace +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from pydantic import ValidationError from servicelib.fastapi.tracing import ( get_tracing_instrumentation_lifespan, + initialize_fastapi_app_tracing, ) +from servicelib.tracing import _OSPARC_TRACE_ID_HEADER from settings_library.tracing import TracingSettings @@ -167,3 +175,53 @@ async def test_tracing_setup_package_detection( service_name="Mock-Openetlemetry-Pytest", )(app=mocked_app): pass + + +@pytest.mark.parametrize( + "tracing_settings_in", + [ + ("http://opentelemetry-collector", 4318), + ], + indirect=True, +) +@pytest.mark.parametrize( + "server_response", + [ + PlainTextResponse("ok"), + HTTPException(status_code=400, detail="error"), + ], +) +async def test_trace_id_in_response_header( + mock_otel_collector: InMemorySpanExporter, + mocked_app: FastAPI, + set_and_clean_settings_env_vars: Callable, + tracing_settings_in: Callable, + uninstrument_opentelemetry: Iterator[None], + server_response: PlainTextResponse | HTTPException, +) -> None: + tracing_settings = TracingSettings() + + handler_data = dict() + + async def handler(handler_data: dict): + current_span = trace.get_current_span() + handler_data[_OSPARC_TRACE_ID_HEADER] = format( + current_span.get_span_context().trace_id, "032x" + ) + if isinstance(server_response, HTTPException): + raise server_response + return server_response + + mocked_app.get("/")(partial(handler, handler_data)) + + async for _ in get_tracing_instrumentation_lifespan( + tracing_settings=tracing_settings, + service_name="Mock-OpenTelemetry-Pytest", + )(app=mocked_app): + initialize_fastapi_app_tracing(mocked_app, add_response_trace_id_header=True) + client = TestClient(mocked_app) + response = client.get("/") + assert _OSPARC_TRACE_ID_HEADER in response.headers + trace_id = response.headers[_OSPARC_TRACE_ID_HEADER] + assert len(trace_id) == 32 # Ensure trace ID is a 32-character hex string + assert trace_id == handler_data[_OSPARC_TRACE_ID_HEADER] diff --git a/services/api-server/src/simcore_service_api_server/core/application.py b/services/api-server/src/simcore_service_api_server/core/application.py index 44c5b5fc129..145ff8efb77 100644 --- a/services/api-server/src/simcore_service_api_server/core/application.py +++ b/services/api-server/src/simcore_service_api_server/core/application.py @@ -97,7 +97,7 @@ def init_app(settings: ApplicationSettings | None = None) -> FastAPI: setup_prometheus_instrumentation(app) if settings.API_SERVER_TRACING: - initialize_fastapi_app_tracing(app) + initialize_fastapi_app_tracing(app, add_response_trace_id_header=True) if settings.API_SERVER_WEBSERVER: webserver.setup( diff --git a/services/web/server/src/simcore_service_webserver/tracing.py b/services/web/server/src/simcore_service_webserver/tracing.py index ffbb8f404a0..0c18954dcb9 100644 --- a/services/web/server/src/simcore_service_webserver/tracing.py +++ b/services/web/server/src/simcore_service_webserver/tracing.py @@ -25,8 +25,9 @@ def setup_app_tracing(app: web.Application): tracing_settings: TracingSettings = get_plugin_settings(app) app.cleanup_ctx.append( get_tracing_lifespan( - app, + app=app, tracing_settings=tracing_settings, service_name=APP_NAME, + add_response_trace_id_header=True, ) )