@@ -2249,6 +2249,7 @@ async def empty_stream() -> AsyncIterator[ChatCompletionChunk]:
22492249 conversation_id = None ,
22502250 session_label = "session" ,
22512251 user_session_label = None ,
2252+ start_perf_counter_ns = 100 ,
22522253 any_llm_key = any_llm_key ,
22532254 llm_span = llm_span ,
22542255 trace_id = 123 ,
@@ -2259,6 +2260,111 @@ async def empty_stream() -> AsyncIterator[ChatCompletionChunk]:
22592260 collected = [chunk async for chunk in result ]
22602261 assert collected == []
22612262 llm_span .end .assert_called_once ()
2263+ llm_span .add_event .assert_not_called ()
2264+ assert not any (
2265+ call .args and call .args [0 ] == "anyllm.performance.ttft_ms" for call in llm_span .set_attribute .call_args_list
2266+ )
2267+
2268+
@pytest.mark.asyncio
@patch("any_llm.providers.platform.platform.export_completion_trace", new_callable=AsyncMock)
async def test_stream_with_usage_tracking_records_ttft_once(
    mock_export_trace: AsyncMock,
    any_llm_key: str,
) -> None:
    """Streaming records the time-to-first-token exactly once, on the first chunk.

    The stream starts at perf counter 1.0s and the patched clock reads 1.12s,
    so the expected TTFT is 120.0 ms, set as a span attribute and emitted as a
    single ``llm.first_token`` event.
    """
    provider_instance = PlatformProvider(api_key=any_llm_key)
    provider_instance._provider = Mock(PROVIDER_NAME="openai")
    llm_span = Mock()

    content_chunk = ChatCompletionChunk(
        id="chatcmpl-123",
        model="gpt-4",
        created=1234567890,
        object="chat.completion.chunk",
        choices=[ChunkChoice(index=0, delta=ChoiceDelta(content="Hello"), finish_reason=None)],
    )
    final_chunk = ChatCompletionChunk(
        id="chatcmpl-123",
        model="gpt-4",
        created=1234567890,
        object="chat.completion.chunk",
        choices=[ChunkChoice(index=0, delta=ChoiceDelta(), finish_reason="stop")],
        usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
    )
    expected_chunks = [content_chunk, final_chunk]

    async def mock_stream() -> AsyncIterator[ChatCompletionChunk]:
        for item in expected_chunks:
            yield item

    # Freeze the clock at 1.12s; with start_perf_counter_ns at 1.0s this yields 120ms.
    with patch("any_llm.providers.platform.platform.time.perf_counter_ns", return_value=1_120_000_000):
        result = provider_instance._stream_with_usage_tracking(
            stream=mock_stream(),
            start_time_ns=100,
            request_model="gpt-4",
            conversation_id=None,
            session_label="session",
            user_session_label=None,
            start_perf_counter_ns=1_000_000_000,
            any_llm_key=any_llm_key,
            llm_span=llm_span,
            trace_id=123,
            access_token=None,
            trace_export_activated=False,
        )
        collected = [chunk async for chunk in result]

    assert collected == expected_chunks
    ttft_calls = [
        call
        for call in llm_span.set_attribute.call_args_list
        if call.args and call.args[0] == "anyllm.performance.ttft_ms"
    ]
    assert len(ttft_calls) == 1
    assert ttft_calls[0].args[1] == 120.0
    llm_span.add_event.assert_called_once_with("llm.first_token", {"anyllm.performance.ttft_ms": 120.0})
    mock_export_trace.assert_awaited_once()
2329+
@pytest.mark.asyncio
@patch("any_llm.providers.platform.platform.export_completion_trace", new_callable=AsyncMock)
async def test_acompletion_non_streaming_does_not_set_ttft_attribute(
    mock_export_trace: AsyncMock,
    any_llm_key: str,
    mock_decrypted_provider_key: DecryptedProviderKey,
    mock_completion: ChatCompletion,
) -> None:
    """A non-streaming completion must not record a TTFT span attribute.

    TTFT is a streaming-only metric; after a plain ``_acompletion`` call the
    span should carry no ``anyllm.performance.ttft_ms`` attribute, while the
    trace is still exported once.

    Fix: pass ``new_callable=AsyncMock`` explicitly instead of relying on
    ``patch``'s async auto-detection — this matches the sibling streaming test,
    the ``AsyncMock`` annotation, and guarantees ``assert_awaited_once`` exists.
    """
    provider_instance = PlatformProvider(api_key=any_llm_key)
    provider_instance.provider = OpenaiProvider
    await _init_provider(provider_instance, mock_decrypted_provider_key)
    provider_instance.provider._acompletion = AsyncMock(return_value=mock_completion)  # type: ignore[method-assign]

    params = CompletionParams(
        model_id="gpt-4",
        messages=[{"role": "user", "content": "Hello"}],
        stream=False,
    )

    # Hand-built tracer plumbing so the span under inspection is our mock.
    mock_span = Mock()
    mock_span.get_span_context.return_value = Mock(trace_id=456)
    mock_tracer = Mock()
    mock_tracer.start_span.return_value = mock_span
    mock_provider_tp = Mock()
    mock_provider_tp.get_tracer.return_value = mock_tracer

    with (
        patch.object(provider_instance.platform_client, "_aensure_valid_token", AsyncMock(return_value="jwt-token")),
        patch("any_llm.providers.platform.platform._get_or_create_tracer_provider", return_value=mock_provider_tp),
        patch("any_llm.providers.platform.platform.activate_trace_export"),
        patch("any_llm.providers.platform.platform.deactivate_trace_export"),
    ):
        await provider_instance._acompletion(params)

    # No TTFT attribute may be set for a non-streaming call.
    assert not any(
        call.args and call.args[0] == "anyllm.performance.ttft_ms" for call in mock_span.set_attribute.call_args_list
    )
    mock_export_trace.assert_awaited_once()
22622368
22632369
22642370@pytest .mark .asyncio
0 commit comments