77import httpx
88import pytest
99from any_llm_platform_client import DecryptedProviderKey
10+ from opentelemetry .sdk .trace import TracerProvider
11+ from opentelemetry .sdk .trace .export import SimpleSpanProcessor
1012from opentelemetry .sdk .trace .export .in_memory_span_exporter import InMemorySpanExporter
1113from pydantic import ValidationError
1214
@@ -494,14 +496,12 @@ async def test_export_completion_trace_success(
494496 any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
495497
496498 exporter = InMemorySpanExporter ()
497-
498- def _build_span_exporter_stub (access_token : str ): # type: ignore[no-untyped-def]
499- assert access_token == "mock-jwt-token-12345"
500- return exporter
499+ test_provider = TracerProvider ()
500+ test_provider .add_span_processor (SimpleSpanProcessor (exporter ))
501501
502502 client = AsyncMock (spec = httpx .AsyncClient )
503503
504- with patch ("any_llm.providers.platform.utils._build_span_exporter " , _build_span_exporter_stub ):
504+ with patch ("any_llm.providers.platform.utils._get_or_create_tracer_provider " , return_value = test_provider ):
505505 await export_completion_trace (
506506 platform_client = mock_platform_client ,
507507 client = client ,
@@ -537,14 +537,12 @@ async def test_export_completion_trace_with_client_name(
537537 client_name = "my-test-client"
538538
539539 exporter = InMemorySpanExporter ()
540-
541- def _build_span_exporter_stub (access_token : str ): # type: ignore[no-untyped-def]
542- assert access_token == "mock-jwt-token-12345"
543- return exporter
540+ test_provider = TracerProvider ()
541+ test_provider .add_span_processor (SimpleSpanProcessor (exporter ))
544542
545543 client = AsyncMock (spec = httpx .AsyncClient )
546544
547- with patch ("any_llm.providers.platform.utils._build_span_exporter " , _build_span_exporter_stub ):
545+ with patch ("any_llm.providers.platform.utils._get_or_create_tracer_provider " , return_value = test_provider ):
548546 await export_completion_trace (
549547 platform_client = mock_platform_client ,
550548 client = client ,
@@ -890,14 +888,12 @@ async def test_export_completion_trace_with_performance_metrics(
890888 any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
891889
892890 exporter = InMemorySpanExporter ()
891+ test_provider = TracerProvider ()
892+ test_provider .add_span_processor (SimpleSpanProcessor (exporter ))
893893
894894 client = AsyncMock (spec = httpx .AsyncClient )
895895
896- def _build_span_exporter_stub (access_token : str ): # type: ignore[no-untyped-def]
897- assert access_token == "mock-jwt-token-12345"
898- return exporter
899-
900- with patch ("any_llm.providers.platform.utils._build_span_exporter" , _build_span_exporter_stub ):
896+ with patch ("any_llm.providers.platform.utils._get_or_create_tracer_provider" , return_value = test_provider ):
901897 await export_completion_trace (
902898 platform_client = mock_platform_client ,
903899 client = client ,
@@ -937,14 +933,12 @@ async def test_export_completion_trace_with_partial_performance_metrics(
937933 any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
938934
939935 exporter = InMemorySpanExporter ()
936+ test_provider = TracerProvider ()
937+ test_provider .add_span_processor (SimpleSpanProcessor (exporter ))
940938
941939 client = AsyncMock (spec = httpx .AsyncClient )
942940
943- def _build_span_exporter_stub (access_token : str ): # type: ignore[no-untyped-def]
944- assert access_token == "mock-jwt-token-12345"
945- return exporter
946-
947- with patch ("any_llm.providers.platform.utils._build_span_exporter" , _build_span_exporter_stub ):
941+ with patch ("any_llm.providers.platform.utils._get_or_create_tracer_provider" , return_value = test_provider ):
948942 await export_completion_trace (
949943 platform_client = mock_platform_client ,
950944 client = client ,
@@ -975,14 +969,12 @@ async def test_export_completion_trace_without_performance_metrics(
975969 any_llm_key = "ANY.v1.kid123.fingerprint456-YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXoxMjM0NTY="
976970
977971 exporter = InMemorySpanExporter ()
972+ test_provider = TracerProvider ()
973+ test_provider .add_span_processor (SimpleSpanProcessor (exporter ))
978974
979975 client = AsyncMock (spec = httpx .AsyncClient )
980976
981- def _build_span_exporter_stub (access_token : str ): # type: ignore[no-untyped-def]
982- assert access_token == "mock-jwt-token-12345"
983- return exporter
984-
985- with patch ("any_llm.providers.platform.utils._build_span_exporter" , _build_span_exporter_stub ):
977+ with patch ("any_llm.providers.platform.utils._get_or_create_tracer_provider" , return_value = test_provider ):
986978 await export_completion_trace (
987979 platform_client = mock_platform_client ,
988980 client = client ,
@@ -1021,14 +1013,12 @@ async def test_export_completion_trace_skips_when_no_usage(
10211013 )
10221014
10231015 exporter = InMemorySpanExporter ()
1024-
1025- def _build_span_exporter_stub (access_token : str ): # type: ignore[no-untyped-def]
1026- assert access_token == "mock-jwt-token-12345"
1027- return exporter
1016+ test_provider = TracerProvider ()
1017+ test_provider .add_span_processor (SimpleSpanProcessor (exporter ))
10281018
10291019 client = AsyncMock (spec = httpx .AsyncClient )
10301020
1031- with patch ("any_llm.providers.platform.utils._build_span_exporter " , _build_span_exporter_stub ):
1021+ with patch ("any_llm.providers.platform.utils._get_or_create_tracer_provider " , return_value = test_provider ):
10321022 await export_completion_trace (
10331023 platform_client = mock_platform_client ,
10341024 client = client ,
@@ -1282,26 +1272,38 @@ async def test_trace_export_uses_bearer_token(
12821272 mock_completion : ChatCompletion ,
12831273) -> None :
12841274 """Test that trace export uses Bearer token authentication."""
1275+ from any_llm .providers .platform import utils as platform_utils
1276+
12851277 mock_http_client = AsyncMock (spec = httpx .AsyncClient )
12861278 captured : dict [str , Any ] = {}
12871279
12881280 def _exporter_factory (* args : Any , ** kwargs : Any ) -> InMemorySpanExporter : # type: ignore[type-arg]
12891281 captured .update (kwargs )
12901282 return InMemorySpanExporter ()
12911283
1292- with patch ("any_llm.providers.platform.utils.OTLPSpanExporter" , side_effect = _exporter_factory ):
1293- await export_completion_trace (
1294- platform_client = mock_platform_client ,
1295- client = mock_http_client ,
1296- any_llm_key = any_llm_key ,
1297- provider = "openai" ,
1298- request_model = "gpt-4" ,
1299- completion = mock_completion ,
1300- start_time_ns = 100 ,
1301- end_time_ns = 200 ,
1302- client_name = "test-client" ,
1303- total_duration_ms = 100.0 ,
1304- )
1284+ # Clear the provider cache so a new provider is created with our mocked exporter
1285+ original_providers = platform_utils ._providers .copy ()
1286+ platform_utils ._providers .clear ()
1287+ try :
1288+ with patch ("any_llm.providers.platform.utils.OTLPSpanExporter" , side_effect = _exporter_factory ):
1289+ await export_completion_trace (
1290+ platform_client = mock_platform_client ,
1291+ client = mock_http_client ,
1292+ any_llm_key = any_llm_key ,
1293+ provider = "openai" ,
1294+ request_model = "gpt-4" ,
1295+ completion = mock_completion ,
1296+ start_time_ns = 100 ,
1297+ end_time_ns = 200 ,
1298+ client_name = "test-client" ,
1299+ total_duration_ms = 100.0 ,
1300+ )
1301+ finally :
1302+ # Shutdown any providers created during the test and restore original state
1303+ for provider in platform_utils ._providers .values ():
1304+ provider .shutdown ()
1305+ platform_utils ._providers .clear ()
1306+ platform_utils ._providers .update (original_providers )
13051307
13061308 mock_platform_client ._aensure_valid_token .assert_called_once_with (any_llm_key )
13071309
@@ -1321,6 +1323,7 @@ async def test_trace_export_includes_version_header(
13211323) -> None :
13221324 """Test that trace export includes library version in User-Agent header."""
13231325 from any_llm import __version__
1326+ from any_llm .providers .platform import utils as platform_utils
13241327
13251328 mock_http_client = AsyncMock (spec = httpx .AsyncClient )
13261329 captured : dict [str , Any ] = {}
@@ -1329,17 +1332,26 @@ def _exporter_factory(*args: Any, **kwargs: Any) -> InMemorySpanExporter: # typ
13291332 captured .update (kwargs )
13301333 return InMemorySpanExporter ()
13311334
1332- with patch ("any_llm.providers.platform.utils.OTLPSpanExporter" , side_effect = _exporter_factory ):
1333- await export_completion_trace (
1334- platform_client = mock_platform_client ,
1335- client = mock_http_client ,
1336- any_llm_key = any_llm_key ,
1337- provider = "openai" ,
1338- request_model = "gpt-4" ,
1339- completion = mock_completion ,
1340- start_time_ns = 100 ,
1341- end_time_ns = 200 ,
1342- )
1335+ # Clear the provider cache so a new provider is created with our mocked exporter
1336+ original_providers = platform_utils ._providers .copy ()
1337+ platform_utils ._providers .clear ()
1338+ try :
1339+ with patch ("any_llm.providers.platform.utils.OTLPSpanExporter" , side_effect = _exporter_factory ):
1340+ await export_completion_trace (
1341+ platform_client = mock_platform_client ,
1342+ client = mock_http_client ,
1343+ any_llm_key = any_llm_key ,
1344+ provider = "openai" ,
1345+ request_model = "gpt-4" ,
1346+ completion = mock_completion ,
1347+ start_time_ns = 100 ,
1348+ end_time_ns = 200 ,
1349+ )
1350+ finally :
1351+ for provider in platform_utils ._providers .values ():
1352+ provider .shutdown ()
1353+ platform_utils ._providers .clear ()
1354+ platform_utils ._providers .update (original_providers )
13431355
13441356 headers = captured ["headers" ]
13451357
@@ -1561,3 +1573,45 @@ async def mock_stream() -> AsyncIterator[ChatCompletionChunk]:
15611573 assert len (chunks ) == len (mock_streaming_chunks )
15621574 mock_export_trace .assert_called_once ()
15631575 assert mock_export_trace .call_args .kwargs ["provider" ] == "mzai"
1576+
1577+
1578+ def test_tracer_provider_reused_for_same_token () -> None :
1579+ """Test that _get_or_create_tracer_provider returns the same provider for the same token."""
1580+ from any_llm .providers .platform import utils as platform_utils
1581+ from any_llm .providers .platform .utils import _get_or_create_tracer_provider
1582+
1583+ original_providers = platform_utils ._providers .copy ()
1584+ platform_utils ._providers .clear ()
1585+ try :
1586+ provider_a = _get_or_create_tracer_provider ("token-aaa" )
1587+ provider_b = _get_or_create_tracer_provider ("token-aaa" )
1588+ assert provider_a is provider_b
1589+
1590+ provider_c = _get_or_create_tracer_provider ("token-bbb" )
1591+ assert provider_c is not provider_a
1592+ finally :
1593+ for provider in platform_utils ._providers .values ():
1594+ provider .shutdown ()
1595+ platform_utils ._providers .clear ()
1596+ platform_utils ._providers .update (original_providers )
1597+
1598+
1599+ def test_shutdown_telemetry_clears_providers () -> None :
1600+ """Test that shutdown_telemetry shuts down all providers and clears the cache."""
1601+ from any_llm .providers .platform import utils as platform_utils
1602+ from any_llm .providers .platform .utils import _get_or_create_tracer_provider , shutdown_telemetry
1603+
1604+ original_providers = platform_utils ._providers .copy ()
1605+ platform_utils ._providers .clear ()
1606+ try :
1607+ _get_or_create_tracer_provider ("token-xxx" )
1608+ _get_or_create_tracer_provider ("token-yyy" )
1609+ assert len (platform_utils ._providers ) == 2
1610+
1611+ shutdown_telemetry ()
1612+
1613+ assert len (platform_utils ._providers ) == 0
1614+ finally :
1615+ # Restore original state (shutdown_telemetry already cleared, but be safe)
1616+ platform_utils ._providers .clear ()
1617+ platform_utils ._providers .update (original_providers )
0 commit comments