Skip to content

Commit a5e90f2

Browse files
✨ Add x-osparc-trace-id to response headers for requests hitting webserver and api-server (#7796)
1 parent ddc3e74 commit a5e90f2

File tree

9 files changed

+232
-21
lines changed

9 files changed

+232
-21
lines changed

packages/service-library/src/servicelib/aiohttp/tracing.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515
middleware as aiohttp_server_opentelemetry_middleware, # pylint:disable=no-name-in-module
1616
)
1717
from opentelemetry.sdk.resources import Resource
18-
from opentelemetry.sdk.trace import TracerProvider
18+
from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
1919
from opentelemetry.sdk.trace.export import BatchSpanProcessor
2020
from servicelib.logging_utils import log_context
21+
from servicelib.tracing import get_trace_id_header
2122
from settings_library.tracing import TracingSettings
2223
from yarl import URL
2324

@@ -51,10 +52,19 @@
5152
HAS_AIO_PIKA = False
5253

5354

55+
def _create_span_processor(tracing_destination: str) -> SpanProcessor:
56+
otlp_exporter = OTLPSpanExporterHTTP(
57+
endpoint=tracing_destination,
58+
)
59+
span_processor = BatchSpanProcessor(otlp_exporter)
60+
return span_processor
61+
62+
5463
def _startup(
5564
app: web.Application,
5665
tracing_settings: TracingSettings,
5766
service_name: str,
67+
add_response_trace_id_header: bool = False,
5868
) -> None:
5969
"""
6070
Sets up this service for a distributed tracing system (opentelemetry)
@@ -90,12 +100,8 @@ def _startup(
90100
tracing_destination,
91101
)
92102

93-
otlp_exporter = OTLPSpanExporterHTTP(
94-
endpoint=tracing_destination,
95-
)
96-
97103
# Add the span processor to the tracer provider
98-
tracer_provider.add_span_processor(BatchSpanProcessor(otlp_exporter)) # type: ignore[attr-defined] # https://github.com/open-telemetry/opentelemetry-python/issues/3713
104+
tracer_provider.add_span_processor(_create_span_processor(tracing_destination)) # type: ignore[attr-defined] # https://github.com/open-telemetry/opentelemetry-python/issues/3713
99105
# Instrument aiohttp server
100106
# Explanation for custom middleware call DK 10/2024:
101107
# OpenTelemetry Aiohttp autoinstrumentation is meant to be used by only calling `AioHttpServerInstrumentor().instrument()`
@@ -106,6 +112,8 @@ def _startup(
106112
#
107113
# Since the code that is provided (monkeypatched) in the __init__ that the opentelemetry-autoinstrumentation-library provides is only 4 lines,
108114
# just adding a middleware, we are free to simply execute this "missed call" [since we can't call the monkeypatch'ed __init__()] in this following line:
115+
if add_response_trace_id_header:
116+
app.middlewares.insert(0, response_trace_id_header_middleware)
109117
app.middlewares.insert(0, aiohttp_server_opentelemetry_middleware)
110118
# Code of the aiohttp server instrumentation: github.com/open-telemetry/opentelemetry-python-contrib/blob/eccb05c808a7d797ef5b6ecefed3590664426fbf/instrumentation/opentelemetry-instrumentation-aiohttp-server/src/opentelemetry/instrumentation/aiohttp_server/__init__.py#L246
111119
# For reference, the above statement was written for:
@@ -146,6 +154,21 @@ def _startup(
146154
AioPikaInstrumentor().instrument()
147155

148156

157+
@web.middleware
158+
async def response_trace_id_header_middleware(request: web.Request, handler):
159+
headers = get_trace_id_header()
160+
161+
try:
162+
response = await handler(request)
163+
except web.HTTPException as exc:
164+
if headers:
165+
exc.headers.update(headers)
166+
raise exc
167+
if headers:
168+
response.headers.update(headers)
169+
return response
170+
171+
149172
def _shutdown() -> None:
150173
"""Uninstruments all opentelemetry instrumentors that were instrumented."""
151174
try:
@@ -175,9 +198,18 @@ def _shutdown() -> None:
175198

176199

177200
def get_tracing_lifespan(
178-
app: web.Application, tracing_settings: TracingSettings, service_name: str
201+
*,
202+
app: web.Application,
203+
tracing_settings: TracingSettings,
204+
service_name: str,
205+
add_response_trace_id_header: bool = False,
179206
) -> Callable[[web.Application], AsyncIterator]:
180-
_startup(app=app, tracing_settings=tracing_settings, service_name=service_name)
207+
_startup(
208+
app=app,
209+
tracing_settings=tracing_settings,
210+
service_name=service_name,
211+
add_response_trace_id_header=add_response_trace_id_header,
212+
)
181213

182214
async def tracing_lifespan(app: web.Application):
183215
assert app # nosec

packages/service-library/src/servicelib/fastapi/tracing.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from collections.abc import AsyncIterator
55

6-
from fastapi import FastAPI
6+
from fastapi import FastAPI, Request
77
from fastapi_lifespan_manager import State
88
from httpx import AsyncClient, Client
99
from opentelemetry import trace
@@ -13,10 +13,12 @@
1313
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
1414
from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor
1515
from opentelemetry.sdk.resources import Resource
16-
from opentelemetry.sdk.trace import TracerProvider
16+
from opentelemetry.sdk.trace import SpanProcessor, TracerProvider
1717
from opentelemetry.sdk.trace.export import BatchSpanProcessor
1818
from servicelib.logging_utils import log_context
19+
from servicelib.tracing import get_trace_id_header
1920
from settings_library.tracing import TracingSettings
21+
from starlette.middleware.base import BaseHTTPMiddleware
2022
from yarl import URL
2123

2224
_logger = logging.getLogger(__name__)
@@ -70,6 +72,11 @@
7072
HAS_AIOPIKA_INSTRUMENTOR = False
7173

7274

75+
def _create_span_processor(tracing_destination: str) -> SpanProcessor:
76+
otlp_exporter = OTLPSpanExporterHTTP(endpoint=tracing_destination)
77+
return BatchSpanProcessor(otlp_exporter)
78+
79+
7380
def _startup(tracing_settings: TracingSettings, service_name: str) -> None:
7481
if (
7582
not tracing_settings.TRACING_OPENTELEMETRY_COLLECTOR_ENDPOINT
@@ -96,10 +103,10 @@ def _startup(tracing_settings: TracingSettings, service_name: str) -> None:
96103
service_name,
97104
tracing_destination,
98105
)
99-
# Configure OTLP exporter to send spans to the collector
100-
otlp_exporter = OTLPSpanExporterHTTP(endpoint=tracing_destination)
101-
span_processor = BatchSpanProcessor(otlp_exporter)
102-
global_tracer_provider.add_span_processor(span_processor)
106+
# Add the span processor to the tracer provider
107+
global_tracer_provider.add_span_processor(
108+
_create_span_processor(tracing_destination)
109+
)
103110

104111
if HAS_AIOPG:
105112
with log_context(
@@ -180,7 +187,11 @@ def _shutdown() -> None:
180187
_logger.exception("Failed to uninstrument RequestsInstrumentor")
181188

182189

183-
def initialize_fastapi_app_tracing(app: FastAPI):
190+
def initialize_fastapi_app_tracing(
191+
app: FastAPI, *, add_response_trace_id_header: bool = False
192+
):
193+
if add_response_trace_id_header:
194+
app.add_middleware(ResponseTraceIdHeaderMiddleware)
184195
FastAPIInstrumentor.instrument_app(app)
185196

186197

@@ -216,3 +227,13 @@ async def tracing_instrumentation_lifespan(
216227
_shutdown()
217228

218229
return tracing_instrumentation_lifespan
230+
231+
232+
class ResponseTraceIdHeaderMiddleware(BaseHTTPMiddleware):
233+
234+
async def dispatch(self, request: Request, call_next):
235+
response = await call_next(request)
236+
trace_id_header = get_trace_id_header()
237+
if trace_id_header:
238+
response.headers.update(trace_id_header)
239+
return response

packages/service-library/src/servicelib/tracing.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
TracingContext: TypeAlias = otcontext.Context | None
1010

11+
_OSPARC_TRACE_ID_HEADER = "x-osparc-trace-id"
12+
1113

1214
def _is_tracing() -> bool:
1315
return trace.get_current_span().is_recording()
@@ -34,3 +36,15 @@ def use_tracing_context(context: TracingContext):
3436
def setup_log_tracing(tracing_settings: TracingSettings):
3537
_ = tracing_settings
3638
LoggingInstrumentor().instrument(set_logging_format=False)
39+
40+
41+
def get_trace_id_header() -> dict[str, str] | None:
42+
"""Generates a dictionary containing the trace ID header if tracing is active."""
43+
span = trace.get_current_span()
44+
if span.is_recording():
45+
trace_id = span.get_span_context().trace_id
46+
trace_id_hex = format(
47+
trace_id, "032x"
48+
) # Convert trace_id to 32-character hex string
49+
return {_OSPARC_TRACE_ID_HEADER: trace_id_hex}
50+
return None
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,20 @@
11
# pylint: disable=redefined-outer-name
22
# pylint: disable=unused-argument
3+
4+
5+
from collections.abc import Iterator
6+
7+
import pytest
8+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
9+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
10+
from pytest_mock import MockerFixture
11+
12+
13+
@pytest.fixture
14+
def mock_otel_collector(mocker: MockerFixture) -> Iterator[InMemorySpanExporter]:
15+
memory_exporter = InMemorySpanExporter()
16+
span_processor = SimpleSpanProcessor(memory_exporter)
17+
mocker.patch(
18+
"servicelib.aiohttp.tracing._create_span_processor", return_value=span_processor
19+
)
20+
yield memory_exporter

packages/service-library/tests/aiohttp/test_tracing.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44

55
import importlib
66
from collections.abc import Callable, Iterator
7+
from functools import partial
78
from typing import Any
89

910
import pip
1011
import pytest
1112
from aiohttp import web
1213
from aiohttp.test_utils import TestClient
14+
from opentelemetry import trace
15+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
1316
from pydantic import ValidationError
1417
from servicelib.aiohttp.tracing import get_tracing_lifespan
18+
from servicelib.tracing import _OSPARC_TRACE_ID_HEADER
1519
from settings_library.tracing import TracingSettings
1620

1721

@@ -60,7 +64,7 @@ async def test_valid_tracing_settings(
6064
service_name = "simcore_service_webserver"
6165
tracing_settings = TracingSettings()
6266
async for _ in get_tracing_lifespan(
63-
app, service_name=service_name, tracing_settings=tracing_settings
67+
app=app, service_name=service_name, tracing_settings=tracing_settings
6468
)(app):
6569
pass
6670

@@ -137,14 +141,64 @@ async def test_tracing_setup_package_detection(
137141
service_name = "simcore_service_webserver"
138142
tracing_settings = TracingSettings()
139143
async for _ in get_tracing_lifespan(
140-
app,
144+
app=app,
141145
service_name=service_name,
142146
tracing_settings=tracing_settings,
143147
)(app):
144148
# idempotency
145149
async for _ in get_tracing_lifespan(
146-
app,
150+
app=app,
147151
service_name=service_name,
148152
tracing_settings=tracing_settings,
149153
)(app):
150154
pass
155+
156+
157+
@pytest.mark.parametrize(
158+
"tracing_settings_in",
159+
[
160+
("http://opentelemetry-collector", 4318),
161+
],
162+
indirect=True,
163+
)
164+
@pytest.mark.parametrize(
165+
"server_response", [web.Response(text="Hello, world!"), web.HTTPNotFound()]
166+
)
167+
async def test_trace_id_in_response_header(
168+
mock_otel_collector: InMemorySpanExporter,
169+
aiohttp_client: Callable,
170+
set_and_clean_settings_env_vars: Callable,
171+
tracing_settings_in,
172+
uninstrument_opentelemetry: Iterator[None],
173+
server_response: web.Response | web.HTTPException,
174+
) -> None:
175+
app = web.Application()
176+
service_name = "simcore_service_webserver"
177+
tracing_settings = TracingSettings()
178+
179+
async def handler(handler_data: dict, request: web.Request) -> web.Response:
180+
current_span = trace.get_current_span()
181+
handler_data[_OSPARC_TRACE_ID_HEADER] = format(
182+
current_span.get_span_context().trace_id, "032x"
183+
)
184+
if isinstance(server_response, web.HTTPException):
185+
raise server_response
186+
return server_response
187+
188+
handler_data = dict()
189+
app.router.add_get("/", partial(handler, handler_data))
190+
191+
async for _ in get_tracing_lifespan(
192+
app=app,
193+
service_name=service_name,
194+
tracing_settings=tracing_settings,
195+
add_response_trace_id_header=True,
196+
)(app):
197+
client = await aiohttp_client(app)
198+
response = await client.get("/")
199+
assert _OSPARC_TRACE_ID_HEADER in response.headers
200+
trace_id = response.headers[_OSPARC_TRACE_ID_HEADER]
201+
assert len(trace_id) == 32 # Ensure trace ID is a 32-character hex string
202+
assert (
203+
trace_id == handler_data[_OSPARC_TRACE_ID_HEADER]
204+
) # Ensure trace IDs match

packages/service-library/tests/fastapi/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33
# pylint: disable=unused-variable
44

55
import socket
6-
from collections.abc import AsyncIterator, Callable
6+
from collections.abc import AsyncIterator, Callable, Iterator
77
from typing import cast
88

99
import arrow
1010
import pytest
1111
from fastapi import APIRouter, FastAPI
1212
from fastapi.params import Query
1313
from httpx import ASGITransport, AsyncClient
14+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
15+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
1416
from pydantic.types import PositiveFloat
17+
from pytest_mock import MockerFixture
1518

1619

1720
@pytest.fixture
@@ -55,3 +58,13 @@ def go() -> int:
5558
return cast(int, s.getsockname()[1])
5659

5760
return go
61+
62+
63+
@pytest.fixture
64+
def mock_otel_collector(mocker: MockerFixture) -> Iterator[InMemorySpanExporter]:
65+
memory_exporter = InMemorySpanExporter()
66+
span_processor = SimpleSpanProcessor(memory_exporter)
67+
mocker.patch(
68+
"servicelib.fastapi.tracing._create_span_processor", return_value=span_processor
69+
)
70+
yield memory_exporter

0 commit comments

Comments
 (0)