Skip to content

Commit 7915adc (parent: ade0101)

feat(platform): record TTFT metric on first streamed token

File tree

2 files changed: +115 additions, −0 deletions

src/any_llm/providers/platform/platform.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ async def _acompletion(
171171

172172
await self._ensure_provider_initialized()
173173
start_time_ns = time.time_ns()
174+
start_perf_counter_ns = time.perf_counter_ns()
174175
any_llm_key = self.any_llm_key
175176
if any_llm_key is None:
176177
msg = "any_llm_key is required for platform provider"
@@ -258,6 +259,7 @@ async def _acompletion(
258259
params.user,
259260
session_trace_label,
260261
user_session_label,
262+
start_perf_counter_ns,
261263
any_llm_key,
262264
llm_span,
263265
trace_id,
@@ -335,6 +337,7 @@ async def _stream_with_usage_tracking(
335337
conversation_id: str | None,
336338
session_label: str,
337339
user_session_label: str | None,
340+
start_perf_counter_ns: int,
338341
any_llm_key: str,
339342
llm_span: trace.Span,
340343
trace_id: int,
@@ -343,10 +346,16 @@ async def _stream_with_usage_tracking(
343346
) -> AsyncIterator[ChatCompletionChunk]:
344347
"""Wrap the stream to export a trace after completion."""
345348
chunks: list[ChatCompletionChunk] = []
349+
first_chunk_received = False
346350

347351
try:
348352
with trace.use_span(llm_span, end_on_exit=False):
349353
async for chunk in stream:
354+
if not first_chunk_received:
355+
first_chunk_received = True
356+
ttft_ms = (time.perf_counter_ns() - start_perf_counter_ns) / 1_000_000
357+
llm_span.set_attribute("anyllm.performance.ttft_ms", ttft_ms)
358+
llm_span.add_event("llm.first_token", {"anyllm.performance.ttft_ms": ttft_ms})
350359
chunks.append(chunk)
351360
yield chunk
352361

tests/unit/providers/test_platform_provider.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,6 +2249,7 @@ async def empty_stream() -> AsyncIterator[ChatCompletionChunk]:
22492249
conversation_id=None,
22502250
session_label="session",
22512251
user_session_label=None,
2252+
start_perf_counter_ns=100,
22522253
any_llm_key=any_llm_key,
22532254
llm_span=llm_span,
22542255
trace_id=123,
@@ -2259,6 +2260,111 @@ async def empty_stream() -> AsyncIterator[ChatCompletionChunk]:
22592260
collected = [chunk async for chunk in result]
22602261
assert collected == []
22612262
llm_span.end.assert_called_once()
2263+
llm_span.add_event.assert_not_called()
2264+
assert not any(
2265+
call.args and call.args[0] == "anyllm.performance.ttft_ms" for call in llm_span.set_attribute.call_args_list
2266+
)
2267+
2268+
2269+
@pytest.mark.asyncio
2270+
@patch("any_llm.providers.platform.platform.export_completion_trace", new_callable=AsyncMock)
2271+
async def test_stream_with_usage_tracking_records_ttft_once(
2272+
mock_export_trace: AsyncMock,
2273+
any_llm_key: str,
2274+
) -> None:
2275+
provider_instance = PlatformProvider(api_key=any_llm_key)
2276+
provider_instance._provider = Mock(PROVIDER_NAME="openai")
2277+
llm_span = Mock()
2278+
2279+
chunks = [
2280+
ChatCompletionChunk(
2281+
id="chatcmpl-123",
2282+
model="gpt-4",
2283+
created=1234567890,
2284+
object="chat.completion.chunk",
2285+
choices=[ChunkChoice(index=0, delta=ChoiceDelta(content="Hello"), finish_reason=None)],
2286+
),
2287+
ChatCompletionChunk(
2288+
id="chatcmpl-123",
2289+
model="gpt-4",
2290+
created=1234567890,
2291+
object="chat.completion.chunk",
2292+
choices=[ChunkChoice(index=0, delta=ChoiceDelta(), finish_reason="stop")],
2293+
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
2294+
),
2295+
]
2296+
2297+
async def mock_stream() -> AsyncIterator[ChatCompletionChunk]:
2298+
for chunk in chunks:
2299+
yield chunk
2300+
2301+
with patch("any_llm.providers.platform.platform.time.perf_counter_ns", return_value=1_120_000_000):
2302+
result = provider_instance._stream_with_usage_tracking(
2303+
stream=mock_stream(),
2304+
start_time_ns=100,
2305+
request_model="gpt-4",
2306+
conversation_id=None,
2307+
session_label="session",
2308+
user_session_label=None,
2309+
start_perf_counter_ns=1_000_000_000,
2310+
any_llm_key=any_llm_key,
2311+
llm_span=llm_span,
2312+
trace_id=123,
2313+
access_token=None,
2314+
trace_export_activated=False,
2315+
)
2316+
collected = [chunk async for chunk in result]
2317+
2318+
assert collected == chunks
2319+
ttft_calls = [
2320+
call
2321+
for call in llm_span.set_attribute.call_args_list
2322+
if call.args and call.args[0] == "anyllm.performance.ttft_ms"
2323+
]
2324+
assert len(ttft_calls) == 1
2325+
assert ttft_calls[0].args[1] == 120.0
2326+
llm_span.add_event.assert_called_once_with("llm.first_token", {"anyllm.performance.ttft_ms": 120.0})
2327+
mock_export_trace.assert_awaited_once()
2328+
2329+
2330+
@pytest.mark.asyncio
2331+
@patch("any_llm.providers.platform.platform.export_completion_trace")
2332+
async def test_acompletion_non_streaming_does_not_set_ttft_attribute(
2333+
mock_export_trace: AsyncMock,
2334+
any_llm_key: str,
2335+
mock_decrypted_provider_key: DecryptedProviderKey,
2336+
mock_completion: ChatCompletion,
2337+
) -> None:
2338+
provider_instance = PlatformProvider(api_key=any_llm_key)
2339+
provider_instance.provider = OpenaiProvider
2340+
await _init_provider(provider_instance, mock_decrypted_provider_key)
2341+
provider_instance.provider._acompletion = AsyncMock(return_value=mock_completion) # type: ignore[method-assign]
2342+
2343+
params = CompletionParams(
2344+
model_id="gpt-4",
2345+
messages=[{"role": "user", "content": "Hello"}],
2346+
stream=False,
2347+
)
2348+
2349+
mock_span = Mock()
2350+
mock_span.get_span_context.return_value = Mock(trace_id=456)
2351+
mock_tracer = Mock()
2352+
mock_tracer.start_span.return_value = mock_span
2353+
mock_provider_tp = Mock()
2354+
mock_provider_tp.get_tracer.return_value = mock_tracer
2355+
2356+
with (
2357+
patch.object(provider_instance.platform_client, "_aensure_valid_token", AsyncMock(return_value="jwt-token")),
2358+
patch("any_llm.providers.platform.platform._get_or_create_tracer_provider", return_value=mock_provider_tp),
2359+
patch("any_llm.providers.platform.platform.activate_trace_export"),
2360+
patch("any_llm.providers.platform.platform.deactivate_trace_export"),
2361+
):
2362+
await provider_instance._acompletion(params)
2363+
2364+
assert not any(
2365+
call.args and call.args[0] == "anyllm.performance.ttft_ms" for call in mock_span.set_attribute.call_args_list
2366+
)
2367+
mock_export_trace.assert_awaited_once()
22622368

22632369

22642370
@pytest.mark.asyncio

0 commit comments

Comments (0)