diff --git a/src/any_llm/gateway/routes/chat.py b/src/any_llm/gateway/routes/chat.py index 22a79a5f..fe50883c 100644 --- a/src/any_llm/gateway/routes/chat.py +++ b/src/any_llm/gateway/routes/chat.py @@ -1,4 +1,3 @@ -import json import uuid from collections.abc import AsyncIterator from datetime import UTC, datetime @@ -18,6 +17,7 @@ from any_llm.gateway.db import APIKey, ModelPricing, UsageLog, User, get_db from any_llm.gateway.log_config import logger from any_llm.gateway.rate_limit import RateLimitInfo, check_rate_limit +from any_llm.gateway.streaming import OPENAI_STREAM_FORMAT, streaming_generator from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionUsage router = APIRouter(prefix="/v1/chat", tags=["chat"]) @@ -230,67 +230,55 @@ async def chat_completions( try: if request.stream: - async def generate() -> AsyncIterator[str]: - prompt_tokens = 0 - completion_tokens = 0 - total_tokens = 0 - - try: - stream: AsyncIterator[ChatCompletionChunk] = await acompletion(**completion_kwargs) # type: ignore[assignment] - async for chunk in stream: - if chunk.usage: - # Take the last non-zero value for each field. This works for - # providers that report cumulative totals (last = total) and - # providers that only report usage on the final chunk. - if chunk.usage.prompt_tokens: - prompt_tokens = chunk.usage.prompt_tokens - if chunk.usage.completion_tokens: - completion_tokens = chunk.usage.completion_tokens - if chunk.usage.total_tokens: - total_tokens = chunk.usage.total_tokens - - yield f"data: {chunk.model_dump_json()}\n\n" - yield "data: [DONE]\n\n" - - # Log aggregated usage - if prompt_tokens or completion_tokens or total_tokens: - usage_data = CompletionUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - ) - await log_usage( - db=db, - api_key_obj=api_key, - model=model, - provider=provider, - endpoint="/v1/chat/completions", - user_id=user_id, - usage_override=usage_data, - ) - else: - # This should never happen. 
- logger.warning(f"No usage data received from streaming response for model {model}") - except Exception as e: - error_data = {"error": {"message": "An error occurred during streaming", "type": "server_error"}} - yield f"data: {json.dumps(error_data)}\n\n" - yield "data: [DONE]\n\n" - try: - await log_usage( - db=db, - api_key_obj=api_key, - model=model, - provider=provider, - endpoint="/v1/chat/completions", - user_id=user_id, - error=str(e), - ) - except Exception as log_err: - logger.error(f"Failed to log streaming error usage: {log_err}") - logger.error(f"Streaming error for {provider}:{model}: {e}") + def _format_chunk(chunk: ChatCompletionChunk) -> str: + return f"data: {chunk.model_dump_json()}\n\n" + def _extract_usage(chunk: ChatCompletionChunk) -> CompletionUsage | None: + if not chunk.usage: + return None + return CompletionUsage( + prompt_tokens=chunk.usage.prompt_tokens or 0, + completion_tokens=chunk.usage.completion_tokens or 0, + total_tokens=chunk.usage.total_tokens or 0, + ) + + async def _on_complete(usage_data: CompletionUsage) -> None: + await log_usage( + db=db, + api_key_obj=api_key, + model=model, + provider=provider, + endpoint="/v1/chat/completions", + user_id=user_id, + usage_override=usage_data, + ) + + async def _on_error(error: str) -> None: + await log_usage( + db=db, + api_key_obj=api_key, + model=model, + provider=provider, + endpoint="/v1/chat/completions", + user_id=user_id, + error=error, + ) + + stream: AsyncIterator[ChatCompletionChunk] = await acompletion(**completion_kwargs) # type: ignore[assignment] rl_headers = rate_limit_headers(rate_limit_info) if rate_limit_info else {} - return StreamingResponse(generate(), media_type="text/event-stream", headers=rl_headers) + return StreamingResponse( + streaming_generator( + stream=stream, + format_chunk=_format_chunk, + extract_usage=_extract_usage, + fmt=OPENAI_STREAM_FORMAT, + on_complete=_on_complete, + on_error=_on_error, + label=f"{provider}:{model}", + ), + media_type="text/event-stream", + headers=rl_headers, + ) completion: ChatCompletion = await acompletion(**completion_kwargs) # type: ignore[assignment] await log_usage( diff --git a/src/any_llm/gateway/routes/messages.py b/src/any_llm/gateway/routes/messages.py index 4cd7a1e7..8a816f68 100644 --- a/src/any_llm/gateway/routes/messages.py +++ b/src/any_llm/gateway/routes/messages.py @@ -1,5 +1,3 @@ -import json -from collections.abc import AsyncIterator from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Request, Response, status @@ -16,8 +14,9 @@ from any_llm.gateway.log_config import logger from any_llm.gateway.rate_limit import check_rate_limit from any_llm.gateway.routes.chat import get_provider_kwargs, log_usage, rate_limit_headers +from any_llm.gateway.streaming import ANTHROPIC_STREAM_FORMAT, streaming_generator from any_llm.types.completion import CompletionUsage -from any_llm.types.messages import MessageResponse, MessageStreamEvent # noqa: TC001 +from any_llm.types.messages import MessageResponse, MessageStreamEvent router = APIRouter(prefix="/v1", tags=["messages"]) @@ -112,58 +111,57 @@ async def create_message( if request.stream: call_kwargs["stream"] = True - async def generate() -> AsyncIterator[str]: - input_tokens = 0 - output_tokens = 0 - - try: - stream: AsyncIterator[MessageStreamEvent] = await amessages(**call_kwargs) # type: ignore[assignment] - async for event in stream: - if event.usage: - if event.usage.input_tokens: - input_tokens = event.usage.input_tokens - if 
event.usage.output_tokens: - output_tokens = event.usage.output_tokens - yield f"event: {event.type}\ndata: {event.model_dump_json(exclude_none=True)}\n\n" - yield "event: done\ndata: {}\n\n" - - if input_tokens or output_tokens: - usage_data = CompletionUsage( - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - ) - await log_usage( - db=db, - api_key_obj=api_key, - model=model, - provider=provider, - endpoint="/v1/messages", - user_id=user_id, - usage_override=usage_data, - ) - except Exception as e: - error_data = { - "type": "error", - "error": {"type": "api_error", "message": "An error occurred during streaming"}, - } - yield f"event: error\ndata: {json.dumps(error_data)}\n\n" - try: - await log_usage( - db=db, - api_key_obj=api_key, - model=model, - provider=provider, - endpoint="/v1/messages", - user_id=user_id, - error=str(e), - ) - except Exception as log_err: - logger.error(f"Failed to log streaming error usage: {log_err}") - logger.error(f"Streaming error for {provider}:{model}: {e}") - + def _format_chunk(event: MessageStreamEvent) -> str: + return f"event: {event.type}\ndata: {event.model_dump_json(exclude_none=True)}\n\n" + + def _extract_usage(event: MessageStreamEvent) -> CompletionUsage | None: + if not event.usage: + return None + input_tokens = event.usage.input_tokens or 0 + output_tokens = event.usage.output_tokens or 0 + return CompletionUsage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=input_tokens + output_tokens, + ) + + async def _on_complete(usage_data: CompletionUsage) -> None: + await log_usage( + db=db, + api_key_obj=api_key, + model=model, + provider=provider, + endpoint="/v1/messages", + user_id=user_id, + usage_override=usage_data, + ) + + async def _on_error(error: str) -> None: + await log_usage( + db=db, + api_key_obj=api_key, + model=model, + provider=provider, + endpoint="/v1/messages", + user_id=user_id, + error=error, + ) + + msg_stream = await amessages(**call_kwargs) rl_headers = rate_limit_headers(rate_limit_info) if rate_limit_info else {} - return StreamingResponse(generate(), media_type="text/event-stream", headers=rl_headers) + return StreamingResponse( + streaming_generator( + stream=msg_stream, # type: ignore[arg-type] + format_chunk=_format_chunk, + extract_usage=_extract_usage, + fmt=ANTHROPIC_STREAM_FORMAT, + on_complete=_on_complete, + on_error=_on_error, + label=f"{provider}:{model}", + ), + media_type="text/event-stream", + headers=rl_headers, + ) result: MessageResponse = await amessages(**call_kwargs) # type: ignore[assignment] diff --git a/src/any_llm/gateway/streaming.py b/src/any_llm/gateway/streaming.py new file mode 100644 index 00000000..c013739f --- /dev/null +++ b/src/any_llm/gateway/streaming.py @@ -0,0 +1,91 @@ +"""Shared SSE streaming utilities for gateway routes.""" + +import json +from collections.abc import AsyncIterator, Awaitable, Callable +from dataclasses import dataclass +from typing import Any + +from any_llm.gateway.log_config import logger +from any_llm.types.completion import CompletionUsage + + +@dataclass(frozen=True) +class StreamFormat: + """SSE formatting configuration for a streaming protocol.""" + + done_marker: str + error_payload: str + yield_done_on_error: bool + + +_OPENAI_ERROR = json.dumps({"error": {"message": "An error occurred during streaming", "type": "server_error"}}) +_ANTHROPIC_ERROR = json.dumps( + {"type": "error", "error": {"type": "api_error", "message": "An error occurred during streaming"}} 
+) + +OPENAI_STREAM_FORMAT = StreamFormat( + done_marker="data: [DONE]\n\n", + error_payload=f"data: {_OPENAI_ERROR}\n\n", + yield_done_on_error=True, +) + +ANTHROPIC_STREAM_FORMAT = StreamFormat( + done_marker="event: done\ndata: {}\n\n", + error_payload=f"event: error\ndata: {_ANTHROPIC_ERROR}\n\n", + yield_done_on_error=False, +) + + +def _merge_usage(current: CompletionUsage, update: CompletionUsage) -> CompletionUsage: + """Merge usage data, keeping the last non-zero value for each field.""" + return CompletionUsage( + prompt_tokens=update.prompt_tokens or current.prompt_tokens, + completion_tokens=update.completion_tokens or current.completion_tokens, + total_tokens=update.total_tokens or current.total_tokens, + ) + + +async def streaming_generator( + stream: AsyncIterator[Any], + format_chunk: Callable[[Any], str], + extract_usage: Callable[[Any], CompletionUsage | None], + fmt: StreamFormat, + on_complete: Callable[[CompletionUsage], Awaitable[None]], + on_error: Callable[[str], Awaitable[None]], + label: str, +) -> AsyncIterator[str]: + """Shared SSE streaming generator with usage tracking and error handling. + + Args: + stream: Async iterator of chunks from the provider + format_chunk: Formats a chunk into an SSE string + extract_usage: Extracts usage from a chunk, or returns None if no usage present + fmt: SSE format configuration (done marker, error payload, etc.) + on_complete: Called with aggregated usage after successful streaming + on_error: Called with error message on failure + label: Identifier for error log messages (e.g., "openai:gpt-4") + + """ + usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + has_usage = False + + try: + async for chunk in stream: + chunk_usage = extract_usage(chunk) + if chunk_usage: + usage = _merge_usage(usage, chunk_usage) + has_usage = True + yield format_chunk(chunk) + yield fmt.done_marker + + if has_usage: + await on_complete(usage) + except Exception as e: + yield fmt.error_payload + if fmt.yield_done_on_error: + yield fmt.done_marker + try: + await on_error(str(e)) + except Exception as log_err: + logger.error("Failed to log streaming error usage: %s", log_err) + logger.error("Streaming error for %s: %s", label, e) diff --git a/tests/gateway/test_streaming_error_event.py b/tests/gateway/test_streaming_error_event.py index db96f28f..9fb0932f 100644 --- a/tests/gateway/test_streaming_error_event.py +++ b/tests/gateway/test_streaming_error_event.py @@ -1,17 +1,21 @@ -"""Tests for SSE error event emission on streaming failures.""" +"""Tests for error handling on streaming requests.""" -import json from typing import Any from fastapi.testclient import TestClient -def test_streaming_error_emits_sse_error_event( +def test_streaming_creation_error_returns_http_error( client: TestClient, api_key_header: dict[str, str], test_user: dict[str, Any], ) -> None: - """Test that a streaming request to an invalid model emits an SSE error event.""" + """Test that a streaming request to an invalid model returns an HTTP error. + + When the stream cannot be created (e.g., invalid model, missing API key), + the gateway returns a proper HTTP error response rather than starting a + stream and emitting an SSE error event. 
+ """ response = client.post( "/v1/chat/completions", json={ @@ -23,33 +27,5 @@ def test_streaming_error_emits_sse_error_event( headers=api_key_header, ) - # The response should be 200 (streaming started) but contain an error event - assert response.status_code == 200 - - events: list[str] = [] - found_error_event = False - for line in response.iter_lines(): - if line.startswith("data: "): - data_str = line[6:] - events.append(data_str) - if data_str == "[DONE]": - break - try: - chunk = json.loads(data_str) - if "error" in chunk: - found_error_event = True - assert "message" in chunk["error"] - assert "type" in chunk["error"] - except json.JSONDecodeError: - continue - - assert found_error_event, "Should have received an SSE error event for invalid model" - - # [DONE] must be the last event and must come after the error event - assert events[-1] == "[DONE]", "Stream should end with [DONE] after error event" - - # No [DONE] should appear before the error event - done_indices = [i for i, e in enumerate(events) if e == "[DONE]"] - error_indices = [i for i, e in enumerate(events) if e != "[DONE]" and "error" in e] - assert error_indices, "Should have found an error event" - assert done_indices[-1] > error_indices[0], "[DONE] must come after the error event" + assert response.status_code == 500 + assert "provider" in response.json()["detail"].lower() diff --git a/tests/gateway/test_streaming_generator.py b/tests/gateway/test_streaming_generator.py new file mode 100644 index 00000000..1f8c32f2 --- /dev/null +++ b/tests/gateway/test_streaming_generator.py @@ -0,0 +1,183 @@ +"""Tests for the shared streaming_generator utility.""" + +from collections.abc import AsyncIterator + +import pytest + +from any_llm.gateway.streaming import ( + ANTHROPIC_STREAM_FORMAT, + OPENAI_STREAM_FORMAT, + streaming_generator, +) +from any_llm.types.completion import CompletionUsage + +_PROVIDER_CRASHED = "provider crashed" +_LOGGING_FAILED = "logging failed too" + + +def _format_chunk(chunk: str) -> str: + return f"data: {chunk}\n\n" + + +def _extract_usage(chunk: str) -> CompletionUsage | None: + if chunk == "usage": + return CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15) + return None + + +async def _items(*values: str) -> AsyncIterator[str]: + for v in values: + yield v + + +@pytest.mark.asyncio +async def test_streaming_generator_success_with_usage() -> None: + """Test successful streaming with usage tracking.""" + completed_usage: list[CompletionUsage] = [] + + async def on_complete(usage: CompletionUsage) -> None: + completed_usage.append(usage) + + async def on_error(error: str) -> None: + pytest.fail("on_error should not be called") + + events: list[str] = [] + async for event in streaming_generator( + stream=_items("hello", "usage"), + format_chunk=_format_chunk, + extract_usage=_extract_usage, + fmt=OPENAI_STREAM_FORMAT, + on_complete=on_complete, + on_error=on_error, + label="test:model", + ): + events.append(event) + + assert events == ["data: hello\n\n", "data: usage\n\n", "data: [DONE]\n\n"] + assert len(completed_usage) == 1 + assert completed_usage[0].prompt_tokens == 10 + assert completed_usage[0].completion_tokens == 5 + + +@pytest.mark.asyncio +async def test_streaming_generator_no_usage_skips_on_complete() -> None: + """Test that on_complete is not called when no usage data is received.""" + completed = False + + async def on_complete(usage: CompletionUsage) -> None: + nonlocal completed + completed = True + + async def on_error(error: str) -> None: + pytest.fail("on_error 
should not be called") + + events: list[str] = [] + async for event in streaming_generator( + stream=_items("hello"), + format_chunk=_format_chunk, + extract_usage=lambda _: None, + fmt=OPENAI_STREAM_FORMAT, + on_complete=on_complete, + on_error=on_error, + label="test:model", + ): + events.append(event) + + assert events == ["data: hello\n\n", "data: [DONE]\n\n"] + assert not completed + + +@pytest.mark.asyncio +async def test_streaming_generator_error_openai_format() -> None: + """Test error handling emits OpenAI-style error and [DONE].""" + error_logged: list[str] = [] + + async def on_complete(usage: CompletionUsage) -> None: + pytest.fail("on_complete should not be called on error") + + async def on_error(error: str) -> None: + error_logged.append(error) + + async def _failing_stream() -> AsyncIterator[str]: + yield "hello" + raise RuntimeError(_PROVIDER_CRASHED) + + events: list[str] = [] + async for event in streaming_generator( + stream=_failing_stream(), + format_chunk=_format_chunk, + extract_usage=lambda _: None, + fmt=OPENAI_STREAM_FORMAT, + on_complete=on_complete, + on_error=on_error, + label="test:model", + ): + events.append(event) + + assert events[0] == "data: hello\n\n" + assert "server_error" in events[1] + assert events[2] == "data: [DONE]\n\n" + assert error_logged == [_PROVIDER_CRASHED] + + +@pytest.mark.asyncio +async def test_streaming_generator_error_anthropic_format() -> None: + """Test error handling emits Anthropic-style error without done marker.""" + error_logged: list[str] = [] + + async def on_complete(usage: CompletionUsage) -> None: + pytest.fail("on_complete should not be called on error") + + async def on_error(error: str) -> None: + error_logged.append(error) + + async def _failing_stream() -> AsyncIterator[str]: + raise RuntimeError(_PROVIDER_CRASHED) + yield # pragma: no cover + + events: list[str] = [] + async for event in streaming_generator( + stream=_failing_stream(), + format_chunk=_format_chunk, + extract_usage=lambda _: None, + fmt=ANTHROPIC_STREAM_FORMAT, + on_complete=on_complete, + on_error=on_error, + label="test:model", + ): + events.append(event) + + assert len(events) == 1 + assert "api_error" in events[0] + assert events[0].startswith("event: error\n") + assert error_logged == [_PROVIDER_CRASHED] + + +@pytest.mark.asyncio +async def test_streaming_generator_error_logging_failure_is_swallowed() -> None: + """Test that failures in on_error don't propagate to the caller.""" + + async def on_complete(usage: CompletionUsage) -> None: + pytest.fail("on_complete should not be called on error") + + async def on_error(error: str) -> None: + raise RuntimeError(_LOGGING_FAILED) + + async def _failing_stream() -> AsyncIterator[str]: + raise RuntimeError(_PROVIDER_CRASHED) + yield # pragma: no cover + + events: list[str] = [] + async for event in streaming_generator( + stream=_failing_stream(), + format_chunk=_format_chunk, + extract_usage=lambda _: None, + fmt=OPENAI_STREAM_FORMAT, + on_complete=on_complete, + on_error=on_error, + label="test:model", + ): + events.append(event) + + assert "server_error" in events[0] + assert events[1] == "data: [DONE]\n\n" diff --git a/tests/gateway/test_streaming_token_aggregation.py b/tests/gateway/test_streaming_token_aggregation.py index c1353275..6b254f0a 100644 --- a/tests/gateway/test_streaming_token_aggregation.py +++ b/tests/gateway/test_streaming_token_aggregation.py @@ -1,57 +1,45 @@ """Tests for streaming token aggregation logic.""" +from any_llm.gateway.streaming import _merge_usage from 
any_llm.types.completion import CompletionUsage -def test_last_nonzero_value_aggregation_cumulative() -> None: - """Test that last-non-zero-value works for cumulative reporting providers. - - Cumulative providers send increasing totals: [10, 20, 30]. - Taking the last non-zero value gives 30 (correct total). - """ - chunks_usage = [ +def test_merge_usage_cumulative() -> None: + """Test that _merge_usage keeps last non-zero value for cumulative providers.""" + chunks = [ CompletionUsage(prompt_tokens=100, completion_tokens=10, total_tokens=110), CompletionUsage(prompt_tokens=100, completion_tokens=20, total_tokens=120), CompletionUsage(prompt_tokens=100, completion_tokens=30, total_tokens=130), ] - prompt_tokens = 0 - completion_tokens = 0 - total_tokens = 0 - - for usage in chunks_usage: - if usage.prompt_tokens: - prompt_tokens = usage.prompt_tokens - if usage.completion_tokens: - completion_tokens = usage.completion_tokens - if usage.total_tokens: - total_tokens = usage.total_tokens - - assert prompt_tokens == 100 - assert completion_tokens == 30 # Last value (cumulative total) - assert total_tokens == 130 - - -def test_last_nonzero_value_aggregation_final_chunk_only() -> None: - """Test that last-non-zero-value works for final-chunk-only providers. - - Providers that only report on the final chunk send usage once. - Taking the last non-zero value gives that single value (correct). - """ - # Most chunks have no usage, final chunk has the totals - prompt_tokens = 0 - completion_tokens = 0 - total_tokens = 0 - - # Simulate: only the last chunk has usage - final_usage = CompletionUsage(prompt_tokens=50, completion_tokens=200, total_tokens=250) - if final_usage.prompt_tokens: - prompt_tokens = final_usage.prompt_tokens - if final_usage.completion_tokens: - completion_tokens = final_usage.completion_tokens - if final_usage.total_tokens: - total_tokens = final_usage.total_tokens - - assert prompt_tokens == 50 - assert completion_tokens == 200 - assert total_tokens == 250 + result = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + for chunk in chunks: + result = _merge_usage(result, chunk) + + assert result.prompt_tokens == 100 + assert result.completion_tokens == 30 + assert result.total_tokens == 130 + + +def test_merge_usage_final_chunk_only() -> None: + """Test that _merge_usage works when only the last chunk has usage.""" + base = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + final = CompletionUsage(prompt_tokens=50, completion_tokens=200, total_tokens=250) + + result = _merge_usage(base, final) + + assert result.prompt_tokens == 50 + assert result.completion_tokens == 200 + assert result.total_tokens == 250 + + +def test_merge_usage_preserves_current_on_zero_update() -> None: + """Test that _merge_usage preserves current values when update has zeros.""" + current = CompletionUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150) + update = CompletionUsage(prompt_tokens=0, completion_tokens=75, total_tokens=0) + + result = _merge_usage(current, update) + + assert result.prompt_tokens == 100 + assert result.completion_tokens == 75 + assert result.total_tokens == 150
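
For reviewers who want to exercise the new helper outside the gateway routes, the sketch below drives streaming_generator with a hand-rolled chunk stream and print-based callbacks. Only streaming_generator, OPENAI_STREAM_FORMAT, and CompletionUsage are taken from this patch; the fake dict chunks, the dict-based _format_chunk/_extract_usage, and the print callbacks are illustrative stand-ins for the acompletion() stream and log_usage() calls used by the real routes.

# Illustrative sketch only -- not part of the patch.
import asyncio
import json
from collections.abc import AsyncIterator

from any_llm.gateway.streaming import OPENAI_STREAM_FORMAT, streaming_generator
from any_llm.types.completion import CompletionUsage


async def _fake_stream() -> AsyncIterator[dict]:
    # Stand-in for the provider stream returned by acompletion(stream=True).
    yield {"delta": "Hel"}
    yield {"delta": "lo", "usage": {"prompt_tokens": 10, "completion_tokens": 2, "total_tokens": 12}}


def _format_chunk(chunk: dict) -> str:
    # The real routes call chunk.model_dump_json(); plain dicts keep the sketch dependency-free.
    return f"data: {json.dumps(chunk)}\n\n"


def _extract_usage(chunk: dict) -> CompletionUsage | None:
    usage = chunk.get("usage")
    if not usage:
        return None
    return CompletionUsage(**usage)


async def _on_complete(usage: CompletionUsage) -> None:
    # The gateway routes call log_usage() here; printing keeps the sketch self-contained.
    print("aggregated usage:", usage)


async def _on_error(error: str) -> None:
    print("stream failed:", error)


async def main() -> None:
    async for sse_line in streaming_generator(
        stream=_fake_stream(),
        format_chunk=_format_chunk,
        extract_usage=_extract_usage,
        fmt=OPENAI_STREAM_FORMAT,
        on_complete=_on_complete,
        on_error=_on_error,
        label="example:model",
    ):
        print(sse_line, end="")


if __name__ == "__main__":
    asyncio.run(main())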