108 changes: 48 additions & 60 deletions src/any_llm/gateway/routes/chat.py
@@ -1,4 +1,3 @@
import json
import uuid
from collections.abc import AsyncIterator
from datetime import UTC, datetime
@@ -18,6 +17,7 @@
from any_llm.gateway.db import APIKey, ModelPricing, UsageLog, User, get_db
from any_llm.gateway.log_config import logger
from any_llm.gateway.rate_limit import RateLimitInfo, check_rate_limit
from any_llm.gateway.streaming import OPENAI_STREAM_FORMAT, streaming_generator
from any_llm.types.completion import ChatCompletion, ChatCompletionChunk, CompletionUsage

router = APIRouter(prefix="/v1/chat", tags=["chat"])
@@ -230,67 +230,55 @@ async def chat_completions(
    try:
        if request.stream:

            async def generate() -> AsyncIterator[str]:
                prompt_tokens = 0
                completion_tokens = 0
                total_tokens = 0

                try:
                    stream: AsyncIterator[ChatCompletionChunk] = await acompletion(**completion_kwargs)  # type: ignore[assignment]
                    async for chunk in stream:
                        if chunk.usage:
                            # Take the last non-zero value for each field. This works for
                            # providers that report cumulative totals (last = total) and
                            # providers that only report usage on the final chunk.
                            if chunk.usage.prompt_tokens:
                                prompt_tokens = chunk.usage.prompt_tokens
                            if chunk.usage.completion_tokens:
                                completion_tokens = chunk.usage.completion_tokens
                            if chunk.usage.total_tokens:
                                total_tokens = chunk.usage.total_tokens

                        yield f"data: {chunk.model_dump_json()}\n\n"
                    yield "data: [DONE]\n\n"

                    # Log aggregated usage
                    if prompt_tokens or completion_tokens or total_tokens:
                        usage_data = CompletionUsage(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=completion_tokens,
                            total_tokens=total_tokens,
                        )
                        await log_usage(
                            db=db,
                            api_key_obj=api_key,
                            model=model,
                            provider=provider,
                            endpoint="/v1/chat/completions",
                            user_id=user_id,
                            usage_override=usage_data,
                        )
                    else:
                        # This should never happen.
                        logger.warning(f"No usage data received from streaming response for model {model}")
                except Exception as e:
                    error_data = {"error": {"message": "An error occurred during streaming", "type": "server_error"}}
                    yield f"data: {json.dumps(error_data)}\n\n"
                    yield "data: [DONE]\n\n"
                    try:
                        await log_usage(
                            db=db,
                            api_key_obj=api_key,
                            model=model,
                            provider=provider,
                            endpoint="/v1/chat/completions",
                            user_id=user_id,
                            error=str(e),
                        )
                    except Exception as log_err:
                        logger.error(f"Failed to log streaming error usage: {log_err}")
                    logger.error(f"Streaming error for {provider}:{model}: {e}")
            def _format_chunk(chunk: ChatCompletionChunk) -> str:
                return f"data: {chunk.model_dump_json()}\n\n"

            def _extract_usage(chunk: ChatCompletionChunk) -> CompletionUsage | None:
                if not chunk.usage:
                    return None
                return CompletionUsage(
                    prompt_tokens=chunk.usage.prompt_tokens or 0,
                    completion_tokens=chunk.usage.completion_tokens or 0,
                    total_tokens=chunk.usage.total_tokens or 0,
                )

            async def _on_complete(usage_data: CompletionUsage) -> None:
                await log_usage(
                    db=db,
                    api_key_obj=api_key,
                    model=model,
                    provider=provider,
                    endpoint="/v1/chat/completions",
                    user_id=user_id,
                    usage_override=usage_data,
                )

            async def _on_error(error: str) -> None:
                await log_usage(
                    db=db,
                    api_key_obj=api_key,
                    model=model,
                    provider=provider,
                    endpoint="/v1/chat/completions",
                    user_id=user_id,
                    error=error,
                )

            stream: AsyncIterator[ChatCompletionChunk] = await acompletion(**completion_kwargs)  # type: ignore[assignment]
            rl_headers = rate_limit_headers(rate_limit_info) if rate_limit_info else {}
            return StreamingResponse(generate(), media_type="text/event-stream", headers=rl_headers)
            return StreamingResponse(
                streaming_generator(
                    stream=stream,
                    format_chunk=_format_chunk,
                    extract_usage=_extract_usage,
                    fmt=OPENAI_STREAM_FORMAT,
                    on_complete=_on_complete,
                    on_error=_on_error,
                    label=f"{provider}:{model}",
                ),
                media_type="text/event-stream",
                headers=rl_headers,
            )

        completion: ChatCompletion = await acompletion(**completion_kwargs)  # type: ignore[assignment]
        await log_usage(
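To make the new wire contract concrete: below is a minimal client-side sketch that consumes this route's OpenAI-style SSE stream, where each chunk arrives as a "data: <json>" frame and the stream ends with "data: [DONE]". It assumes httpx as the HTTP client; the base URL, bearer token, and request body are placeholders for illustration, not part of this diff.

import asyncio

import httpx


async def consume_chat_stream() -> None:
    # Placeholder gateway address and key; substitute real values.
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        body = {
            "model": "openai:gpt-4o-mini",
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        }
        headers = {"Authorization": "Bearer gw-example-key"}
        async with client.stream("POST", "/v1/chat/completions", json=body, headers=headers) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip the blank separator lines between frames
                data = line.removeprefix("data: ")
                if data == "[DONE]":  # OPENAI_STREAM_FORMAT.done_marker
                    break
                print(data)  # one serialized ChatCompletionChunk per frame


asyncio.run(consume_chat_stream())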
106 changes: 52 additions & 54 deletions src/any_llm/gateway/routes/messages.py
@@ -1,5 +1,3 @@
import json
from collections.abc import AsyncIterator
from typing import Annotated, Any

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
@@ -16,8 +14,9 @@
from any_llm.gateway.log_config import logger
from any_llm.gateway.rate_limit import check_rate_limit
from any_llm.gateway.routes.chat import get_provider_kwargs, log_usage, rate_limit_headers
from any_llm.gateway.streaming import ANTHROPIC_STREAM_FORMAT, streaming_generator
from any_llm.types.completion import CompletionUsage
from any_llm.types.messages import MessageResponse, MessageStreamEvent # noqa: TC001
from any_llm.types.messages import MessageResponse, MessageStreamEvent

router = APIRouter(prefix="/v1", tags=["messages"])

@@ -112,58 +111,57 @@ async def create_message(
    if request.stream:
        call_kwargs["stream"] = True

        async def generate() -> AsyncIterator[str]:
            input_tokens = 0
            output_tokens = 0

            try:
                stream: AsyncIterator[MessageStreamEvent] = await amessages(**call_kwargs)  # type: ignore[assignment]
                async for event in stream:
                    if event.usage:
                        if event.usage.input_tokens:
                            input_tokens = event.usage.input_tokens
                        if event.usage.output_tokens:
                            output_tokens = event.usage.output_tokens
                    yield f"event: {event.type}\ndata: {event.model_dump_json(exclude_none=True)}\n\n"
                yield "event: done\ndata: {}\n\n"

                if input_tokens or output_tokens:
                    usage_data = CompletionUsage(
                        prompt_tokens=input_tokens,
                        completion_tokens=output_tokens,
                        total_tokens=input_tokens + output_tokens,
                    )
                    await log_usage(
                        db=db,
                        api_key_obj=api_key,
                        model=model,
                        provider=provider,
                        endpoint="/v1/messages",
                        user_id=user_id,
                        usage_override=usage_data,
                    )
            except Exception as e:
                error_data = {
                    "type": "error",
                    "error": {"type": "api_error", "message": "An error occurred during streaming"},
                }
                yield f"event: error\ndata: {json.dumps(error_data)}\n\n"
                try:
                    await log_usage(
                        db=db,
                        api_key_obj=api_key,
                        model=model,
                        provider=provider,
                        endpoint="/v1/messages",
                        user_id=user_id,
                        error=str(e),
                    )
                except Exception as log_err:
                    logger.error(f"Failed to log streaming error usage: {log_err}")
                logger.error(f"Streaming error for {provider}:{model}: {e}")

        def _format_chunk(event: MessageStreamEvent) -> str:
            return f"event: {event.type}\ndata: {event.model_dump_json(exclude_none=True)}\n\n"

        def _extract_usage(event: MessageStreamEvent) -> CompletionUsage | None:
            if not event.usage:
                return None
            input_tokens = event.usage.input_tokens or 0
            output_tokens = event.usage.output_tokens or 0
            return CompletionUsage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
            )

        async def _on_complete(usage_data: CompletionUsage) -> None:
            await log_usage(
                db=db,
                api_key_obj=api_key,
                model=model,
                provider=provider,
                endpoint="/v1/messages",
                user_id=user_id,
                usage_override=usage_data,
            )

        async def _on_error(error: str) -> None:
            await log_usage(
                db=db,
                api_key_obj=api_key,
                model=model,
                provider=provider,
                endpoint="/v1/messages",
                user_id=user_id,
                error=error,
            )

        msg_stream = await amessages(**call_kwargs)
        rl_headers = rate_limit_headers(rate_limit_info) if rate_limit_info else {}
        return StreamingResponse(generate(), media_type="text/event-stream", headers=rl_headers)
        return StreamingResponse(
            streaming_generator(
                stream=msg_stream,  # type: ignore[arg-type]
                format_chunk=_format_chunk,
                extract_usage=_extract_usage,
                fmt=ANTHROPIC_STREAM_FORMAT,
                on_complete=_on_complete,
                on_error=_on_error,
                label=f"{provider}:{model}",
            ),
            media_type="text/event-stream",
            headers=rl_headers,
        )

    result: MessageResponse = await amessages(**call_kwargs)  # type: ignore[assignment]

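The Anthropic-style framing used here differs from the OpenAI one above: each event goes out as an "event: <type>" line followed by a "data: <json>" line, and the stream is terminated by "event: done" with an empty JSON object rather than a [DONE] sentinel. A small self-contained sketch of parsing that framing follows; the event names in the sample are illustrative.

import json


def parse_sse_events(raw: str) -> list[tuple[str, dict]]:
    # Each frame is "event: <type>\ndata: <json>" followed by a blank line.
    events: list[tuple[str, dict]] = []
    for block in raw.split("\n\n"):
        lines = block.splitlines()
        if len(lines) < 2:
            continue
        event_type = lines[0].removeprefix("event: ")
        payload = json.loads(lines[1].removeprefix("data: "))
        events.append((event_type, payload))
    return events


raw = 'event: message_stop\ndata: {"type": "message_stop"}\n\nevent: done\ndata: {}\n\n'
print(parse_sse_events(raw))  # [('message_stop', {'type': 'message_stop'}), ('done', {})]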
91 changes: 91 additions & 0 deletions src/any_llm/gateway/streaming.py
@@ -0,0 +1,91 @@
"""Shared SSE streaming utilities for gateway routes."""

import json
from collections.abc import AsyncIterator, Awaitable, Callable
from dataclasses import dataclass
from typing import Any

from any_llm.gateway.log_config import logger
from any_llm.types.completion import CompletionUsage


@dataclass(frozen=True)
class StreamFormat:
"""SSE formatting configuration for a streaming protocol."""

done_marker: str
error_payload: str
yield_done_on_error: bool


_OPENAI_ERROR = json.dumps({"error": {"message": "An error occurred during streaming", "type": "server_error"}})
_ANTHROPIC_ERROR = json.dumps(
{"type": "error", "error": {"type": "api_error", "message": "An error occurred during streaming"}}
)

OPENAI_STREAM_FORMAT = StreamFormat(
done_marker="data: [DONE]\n\n",
error_payload=f"data: {_OPENAI_ERROR}\n\n",
yield_done_on_error=True,
)

ANTHROPIC_STREAM_FORMAT = StreamFormat(
done_marker="event: done\ndata: {}\n\n",
error_payload=f"event: error\ndata: {_ANTHROPIC_ERROR}\n\n",
yield_done_on_error=False,
)


def _merge_usage(current: CompletionUsage, update: CompletionUsage) -> CompletionUsage:
"""Merge usage data, keeping the last non-zero value for each field."""
return CompletionUsage(
prompt_tokens=update.prompt_tokens or current.prompt_tokens,
completion_tokens=update.completion_tokens or current.completion_tokens,
total_tokens=update.total_tokens or current.total_tokens,
)


async def streaming_generator(
    stream: AsyncIterator[Any],
    format_chunk: Callable[[Any], str],
    extract_usage: Callable[[Any], CompletionUsage | None],
    fmt: StreamFormat,
    on_complete: Callable[[CompletionUsage], Awaitable[None]],
    on_error: Callable[[str], Awaitable[None]],
    label: str,
) -> AsyncIterator[str]:
    """Shared SSE streaming generator with usage tracking and error handling.

    Args:
        stream: Async iterator of chunks from the provider
        format_chunk: Formats a chunk into an SSE string
        extract_usage: Extracts usage from a chunk, or returns None if no usage present
        fmt: SSE format configuration (done marker, error payload, etc.)
        on_complete: Called with aggregated usage after successful streaming
        on_error: Called with error message on failure
        label: Identifier for error log messages (e.g., "openai:gpt-4")

    """
    usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
    has_usage = False

    try:
        async for chunk in stream:
            chunk_usage = extract_usage(chunk)
            if chunk_usage:
                usage = _merge_usage(usage, chunk_usage)
                has_usage = True
            yield format_chunk(chunk)
        yield fmt.done_marker

        if has_usage:
            await on_complete(usage)
    except Exception as e:
        yield fmt.error_payload
        if fmt.yield_done_on_error:
            yield fmt.done_marker
        try:
            await on_error(str(e))
        except Exception as log_err:
            logger.error("Failed to log streaming error usage: %s", log_err)
        logger.error("Streaming error for %s: %s", label, e)