Skip to content

Commit 1f3744b

Browse files
committed
checkpoint: otel platform provider trace export with module-level URL config
1 parent 618db08 commit 1f3744b

File tree

3 files changed

+152
-72
lines changed

3 files changed

+152
-72
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .platform import PlatformProvider
2-
from .utils import export_completion_trace
2+
from .utils import export_completion_trace, shutdown_telemetry
33

4-
__all__ = ["PlatformProvider", "export_completion_trace"]
4+
__all__ = ["PlatformProvider", "export_completion_trace", "shutdown_telemetry"]

src/any_llm/providers/platform/utils.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3+
import atexit
34
import os
5+
import threading
46
from typing import TYPE_CHECKING
57

68
import httpx # noqa: TC002
@@ -10,8 +12,8 @@
1012
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
1113
from opentelemetry.sdk.resources import Resource
1214
from opentelemetry.sdk.trace import TracerProvider
15+
from opentelemetry.sdk.trace.export import BatchSpanProcessor
1316
from opentelemetry.trace import SpanKind
14-
from opentelemetry.sdk.trace.export import SimpleSpanProcessor, SpanExporter
1517

1618
from any_llm import __version__
1719

@@ -26,21 +28,47 @@
2628
_trace_base_url = _trace_base_url[: -len("/api/v1")]
2729
ANY_LLM_PLATFORM_TRACE_URL = f"{_trace_base_url}{TRACE_API_PATH}"
2830

31+
# Module-level cache: access_token -> TracerProvider
32+
_providers: dict[str, TracerProvider] = {}
33+
_providers_lock = threading.Lock()
2934

30-
def _build_span_exporter(access_token: str) -> SpanExporter:
31-
return OTLPSpanExporter(
32-
endpoint=ANY_LLM_PLATFORM_TRACE_URL,
33-
headers={
34-
"Authorization": f"Bearer {access_token}",
35-
"User-Agent": f"python-any-llm/{__version__}",
36-
},
37-
)
3835

36+
def _get_or_create_tracer_provider(access_token: str) -> TracerProvider:
37+
"""Get or create a TracerProvider for the given access token.
3938
40-
def _build_tracer_provider(access_token: str) -> TracerProvider:
41-
provider = TracerProvider(resource=Resource.create({"service.name": "any-llm"}))
42-
provider.add_span_processor(SimpleSpanProcessor(_build_span_exporter(access_token)))
43-
return provider
39+
Providers are cached by token and reused across requests.
40+
When a token expires and a new one is issued, a new provider is created.
41+
"""
42+
with _providers_lock:
43+
if access_token in _providers:
44+
return _providers[access_token]
45+
46+
provider = TracerProvider(resource=Resource.create({"service.name": "any-llm"}))
47+
exporter = OTLPSpanExporter(
48+
endpoint=ANY_LLM_PLATFORM_TRACE_URL,
49+
headers={
50+
"Authorization": f"Bearer {access_token}",
51+
"User-Agent": f"python-any-llm/{__version__}",
52+
},
53+
)
54+
provider.add_span_processor(BatchSpanProcessor(exporter))
55+
_providers[access_token] = provider
56+
return provider
57+
58+
59+
def shutdown_telemetry() -> None:
60+
"""Shutdown all cached tracer providers.
61+
62+
Called automatically at process exit via atexit, but can also be called
63+
manually to ensure all pending spans are flushed before shutdown.
64+
"""
65+
with _providers_lock:
66+
for provider in _providers.values():
67+
provider.shutdown()
68+
_providers.clear()
69+
70+
71+
atexit.register(shutdown_telemetry)
4472

4573

4674
async def export_completion_trace(
@@ -70,7 +98,7 @@ async def export_completion_trace(
7098
"""
7199
access_token = await platform_client._aensure_valid_token(any_llm_key)
72100

73-
provider_instance = _build_tracer_provider(access_token)
101+
provider_instance = _get_or_create_tracer_provider(access_token)
74102
tracer = provider_instance.get_tracer("any-llm", __version__)
75103

76104
span = tracer.start_span("llm.request", kind=SpanKind.CLIENT, start_time=start_time_ns)
@@ -111,5 +139,3 @@ async def export_completion_trace(
111139
)
112140

113141
span.end(end_time=end_time_ns)
114-
provider_instance.force_flush()
115-
provider_instance.shutdown()

tests/unit/providers/test_platform_provider.py

Lines changed: 108 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import httpx
88
import pytest
99
from any_llm_platform_client import DecryptedProviderKey
10+
from opentelemetry.sdk.trace import TracerProvider
11+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
1012
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
1113
from pydantic import ValidationError
1214

@@ -494,14 +496,12 @@ async def test_export_completion_trace_success(
494496
any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
495497

496498
exporter = InMemorySpanExporter()
497-
498-
def _build_span_exporter_stub(access_token: str): # type: ignore[no-untyped-def]
499-
assert access_token == "mock-jwt-token-12345"
500-
return exporter
499+
test_provider = TracerProvider()
500+
test_provider.add_span_processor(SimpleSpanProcessor(exporter))
501501

502502
client = AsyncMock(spec=httpx.AsyncClient)
503503

504-
with patch("any_llm.providers.platform.utils._build_span_exporter", _build_span_exporter_stub):
504+
with patch("any_llm.providers.platform.utils._get_or_create_tracer_provider", return_value=test_provider):
505505
await export_completion_trace(
506506
platform_client=mock_platform_client,
507507
client=client,
@@ -537,14 +537,12 @@ async def test_export_completion_trace_with_client_name(
537537
client_name = "my-test-client"
538538

539539
exporter = InMemorySpanExporter()
540-
541-
def _build_span_exporter_stub(access_token: str): # type: ignore[no-untyped-def]
542-
assert access_token == "mock-jwt-token-12345"
543-
return exporter
540+
test_provider = TracerProvider()
541+
test_provider.add_span_processor(SimpleSpanProcessor(exporter))
544542

545543
client = AsyncMock(spec=httpx.AsyncClient)
546544

547-
with patch("any_llm.providers.platform.utils._build_span_exporter", _build_span_exporter_stub):
545+
with patch("any_llm.providers.platform.utils._get_or_create_tracer_provider", return_value=test_provider):
548546
await export_completion_trace(
549547
platform_client=mock_platform_client,
550548
client=client,
@@ -890,14 +888,12 @@ async def test_export_completion_trace_with_performance_metrics(
890888
any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
891889

892890
exporter = InMemorySpanExporter()
891+
test_provider = TracerProvider()
892+
test_provider.add_span_processor(SimpleSpanProcessor(exporter))
893893

894894
client = AsyncMock(spec=httpx.AsyncClient)
895895

896-
def _build_span_exporter_stub(access_token: str): # type: ignore[no-untyped-def]
897-
assert access_token == "mock-jwt-token-12345"
898-
return exporter
899-
900-
with patch("any_llm.providers.platform.utils._build_span_exporter", _build_span_exporter_stub):
896+
with patch("any_llm.providers.platform.utils._get_or_create_tracer_provider", return_value=test_provider):
901897
await export_completion_trace(
902898
platform_client=mock_platform_client,
903899
client=client,
@@ -937,14 +933,12 @@ async def test_export_completion_trace_with_partial_performance_metrics(
937933
any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
938934

939935
exporter = InMemorySpanExporter()
936+
test_provider = TracerProvider()
937+
test_provider.add_span_processor(SimpleSpanProcessor(exporter))
940938

941939
client = AsyncMock(spec=httpx.AsyncClient)
942940

943-
def _build_span_exporter_stub(access_token: str): # type: ignore[no-untyped-def]
944-
assert access_token == "mock-jwt-token-12345"
945-
return exporter
946-
947-
with patch("any_llm.providers.platform.utils._build_span_exporter", _build_span_exporter_stub):
941+
with patch("any_llm.providers.platform.utils._get_or_create_tracer_provider", return_value=test_provider):
948942
await export_completion_trace(
949943
platform_client=mock_platform_client,
950944
client=client,
@@ -975,14 +969,12 @@ async def test_export_completion_trace_without_performance_metrics(
975969
any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
976970

977971
exporter = InMemorySpanExporter()
972+
test_provider = TracerProvider()
973+
test_provider.add_span_processor(SimpleSpanProcessor(exporter))
978974

979975
client = AsyncMock(spec=httpx.AsyncClient)
980976

981-
def _build_span_exporter_stub(access_token: str): # type: ignore[no-untyped-def]
982-
assert access_token == "mock-jwt-token-12345"
983-
return exporter
984-
985-
with patch("any_llm.providers.platform.utils._build_span_exporter", _build_span_exporter_stub):
977+
with patch("any_llm.providers.platform.utils._get_or_create_tracer_provider", return_value=test_provider):
986978
await export_completion_trace(
987979
platform_client=mock_platform_client,
988980
client=client,
@@ -1021,14 +1013,12 @@ async def test_export_completion_trace_skips_when_no_usage(
10211013
)
10221014

10231015
exporter = InMemorySpanExporter()
1024-
1025-
def _build_span_exporter_stub(access_token: str): # type: ignore[no-untyped-def]
1026-
assert access_token == "mock-jwt-token-12345"
1027-
return exporter
1016+
test_provider = TracerProvider()
1017+
test_provider.add_span_processor(SimpleSpanProcessor(exporter))
10281018

10291019
client = AsyncMock(spec=httpx.AsyncClient)
10301020

1031-
with patch("any_llm.providers.platform.utils._build_span_exporter", _build_span_exporter_stub):
1021+
with patch("any_llm.providers.platform.utils._get_or_create_tracer_provider", return_value=test_provider):
10321022
await export_completion_trace(
10331023
platform_client=mock_platform_client,
10341024
client=client,
@@ -1282,26 +1272,38 @@ async def test_trace_export_uses_bearer_token(
12821272
mock_completion: ChatCompletion,
12831273
) -> None:
12841274
"""Test that trace export uses Bearer token authentication."""
1275+
from any_llm.providers.platform import utils as platform_utils
1276+
12851277
mock_http_client = AsyncMock(spec=httpx.AsyncClient)
12861278
captured: dict[str, Any] = {}
12871279

12881280
def _exporter_factory(*args: Any, **kwargs: Any) -> InMemorySpanExporter: # type: ignore[type-arg]
12891281
captured.update(kwargs)
12901282
return InMemorySpanExporter()
12911283

1292-
with patch("any_llm.providers.platform.utils.OTLPSpanExporter", side_effect=_exporter_factory):
1293-
await export_completion_trace(
1294-
platform_client=mock_platform_client,
1295-
client=mock_http_client,
1296-
any_llm_key=any_llm_key,
1297-
provider="openai",
1298-
request_model="gpt-4",
1299-
completion=mock_completion,
1300-
start_time_ns=100,
1301-
end_time_ns=200,
1302-
client_name="test-client",
1303-
total_duration_ms=100.0,
1304-
)
1284+
# Clear the provider cache so a new provider is created with our mocked exporter
1285+
original_providers = platform_utils._providers.copy()
1286+
platform_utils._providers.clear()
1287+
try:
1288+
with patch("any_llm.providers.platform.utils.OTLPSpanExporter", side_effect=_exporter_factory):
1289+
await export_completion_trace(
1290+
platform_client=mock_platform_client,
1291+
client=mock_http_client,
1292+
any_llm_key=any_llm_key,
1293+
provider="openai",
1294+
request_model="gpt-4",
1295+
completion=mock_completion,
1296+
start_time_ns=100,
1297+
end_time_ns=200,
1298+
client_name="test-client",
1299+
total_duration_ms=100.0,
1300+
)
1301+
finally:
1302+
# Shutdown any providers created during the test and restore original state
1303+
for provider in platform_utils._providers.values():
1304+
provider.shutdown()
1305+
platform_utils._providers.clear()
1306+
platform_utils._providers.update(original_providers)
13051307

13061308
mock_platform_client._aensure_valid_token.assert_called_once_with(any_llm_key)
13071309

@@ -1321,6 +1323,7 @@ async def test_trace_export_includes_version_header(
13211323
) -> None:
13221324
"""Test that trace export includes library version in User-Agent header."""
13231325
from any_llm import __version__
1326+
from any_llm.providers.platform import utils as platform_utils
13241327

13251328
mock_http_client = AsyncMock(spec=httpx.AsyncClient)
13261329
captured: dict[str, Any] = {}
@@ -1329,17 +1332,26 @@ def _exporter_factory(*args: Any, **kwargs: Any) -> InMemorySpanExporter: # typ
13291332
captured.update(kwargs)
13301333
return InMemorySpanExporter()
13311334

1332-
with patch("any_llm.providers.platform.utils.OTLPSpanExporter", side_effect=_exporter_factory):
1333-
await export_completion_trace(
1334-
platform_client=mock_platform_client,
1335-
client=mock_http_client,
1336-
any_llm_key=any_llm_key,
1337-
provider="openai",
1338-
request_model="gpt-4",
1339-
completion=mock_completion,
1340-
start_time_ns=100,
1341-
end_time_ns=200,
1342-
)
1335+
# Clear the provider cache so a new provider is created with our mocked exporter
1336+
original_providers = platform_utils._providers.copy()
1337+
platform_utils._providers.clear()
1338+
try:
1339+
with patch("any_llm.providers.platform.utils.OTLPSpanExporter", side_effect=_exporter_factory):
1340+
await export_completion_trace(
1341+
platform_client=mock_platform_client,
1342+
client=mock_http_client,
1343+
any_llm_key=any_llm_key,
1344+
provider="openai",
1345+
request_model="gpt-4",
1346+
completion=mock_completion,
1347+
start_time_ns=100,
1348+
end_time_ns=200,
1349+
)
1350+
finally:
1351+
for provider in platform_utils._providers.values():
1352+
provider.shutdown()
1353+
platform_utils._providers.clear()
1354+
platform_utils._providers.update(original_providers)
13431355

13441356
headers = captured["headers"]
13451357

@@ -1561,3 +1573,45 @@ async def mock_stream() -> AsyncIterator[ChatCompletionChunk]:
15611573
assert len(chunks) == len(mock_streaming_chunks)
15621574
mock_export_trace.assert_called_once()
15631575
assert mock_export_trace.call_args.kwargs["provider"] == "mzai"
1576+
1577+
1578+
def test_tracer_provider_reused_for_same_token() -> None:
1579+
"""Test that _get_or_create_tracer_provider returns the same provider for the same token."""
1580+
from any_llm.providers.platform import utils as platform_utils
1581+
from any_llm.providers.platform.utils import _get_or_create_tracer_provider
1582+
1583+
original_providers = platform_utils._providers.copy()
1584+
platform_utils._providers.clear()
1585+
try:
1586+
provider_a = _get_or_create_tracer_provider("token-aaa")
1587+
provider_b = _get_or_create_tracer_provider("token-aaa")
1588+
assert provider_a is provider_b
1589+
1590+
provider_c = _get_or_create_tracer_provider("token-bbb")
1591+
assert provider_c is not provider_a
1592+
finally:
1593+
for provider in platform_utils._providers.values():
1594+
provider.shutdown()
1595+
platform_utils._providers.clear()
1596+
platform_utils._providers.update(original_providers)
1597+
1598+
1599+
def test_shutdown_telemetry_clears_providers() -> None:
1600+
"""Test that shutdown_telemetry shuts down all providers and clears the cache."""
1601+
from any_llm.providers.platform import utils as platform_utils
1602+
from any_llm.providers.platform.utils import _get_or_create_tracer_provider, shutdown_telemetry
1603+
1604+
original_providers = platform_utils._providers.copy()
1605+
platform_utils._providers.clear()
1606+
try:
1607+
_get_or_create_tracer_provider("token-xxx")
1608+
_get_or_create_tracer_provider("token-yyy")
1609+
assert len(platform_utils._providers) == 2
1610+
1611+
shutdown_telemetry()
1612+
1613+
assert len(platform_utils._providers) == 0
1614+
finally:
1615+
# Restore original state (shutdown_telemetry already cleared, but be safe)
1616+
platform_utils._providers.clear()
1617+
platform_utils._providers.update(original_providers)

0 commit comments

Comments
 (0)