Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions aidial_adapter_openai/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@

import aidial_adapter_openai.endpoints as endpoints
from aidial_adapter_openai.configuration.app_config import ApplicationConfig
from aidial_adapter_openai.exception_handlers import (
from aidial_adapter_openai.exceptions.handlers import (
adapter_exception_handler,
fastapi_exception_handler,
)
from aidial_adapter_openai.utils.http_client import get_http_client
from aidial_adapter_openai.utils.auth import close_azure_credential
from aidial_adapter_openai.utils.http_client import (
get_anthropic_httpx_client,
get_http_client,
)
from aidial_adapter_openai.utils.log_config import configure_loggers, logger
from aidial_adapter_openai.utils.request import set_app_config

Expand All @@ -23,6 +27,8 @@ async def lifespan(app: FastAPI):
yield
logger.info("Application shutdown")
await get_http_client().aclose()
await get_anthropic_httpx_client().aclose()
await close_azure_credential()


def create_app(
Expand Down
72 changes: 72 additions & 0 deletions aidial_adapter_openai/chat_completions/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os

import fastapi
from aidial_adapter_anthropic.adapter import ChatCompletionAdapter, UserError
from aidial_adapter_anthropic.adapter.claude import ApproximateTokenizer
from aidial_adapter_anthropic.adapter.claude import (
create_adapter as create_anthropic_adapter,
)
from aidial_adapter_anthropic.dial.consumer import ChoiceConsumer
from aidial_adapter_anthropic.dial.request import ModelParameters
from aidial_adapter_anthropic.dial.storage import FileStorage
from aidial_sdk.chat_completion import Request as DIALRequest
from aidial_sdk.chat_completion import Response as DIALResponse
from anthropic import AsyncAnthropicFoundry
from fastapi.responses import StreamingResponse

from aidial_adapter_openai.dial_api.sdk_adapter import sdk_adapter
from aidial_adapter_openai.utils.env import get_env_int

DIAL_URL = os.getenv("DIAL_URL")


def _create_file_storage(api_key: str | None) -> FileStorage | None:
if api_key is None or DIAL_URL is None:
return None

return FileStorage(dial_url=DIAL_URL, api_key=api_key)


CLAUDE_DEFAULT_MAX_TOKENS = get_env_int("CLAUDE_DEFAULT_MAX_TOKENS", 1536)


async def create_adapter(
deployment: str, api_key: str, client: AsyncAnthropicFoundry
) -> ChatCompletionAdapter:
return await create_anthropic_adapter(
deployment=deployment,
storage=_create_file_storage(api_key),
client=client,
custom_tokenizer=ApproximateTokenizer(),
default_max_tokens=CLAUDE_DEFAULT_MAX_TOKENS,
supports_thinking=True,
supports_documents=True,
)


async def chat_completion(
*,
request: fastapi.Request,
deployment_id: str,
client: AsyncAnthropicFoundry,
) -> StreamingResponse | dict:

async def _handler(request: DIALRequest, response: DIALResponse) -> None:
model = await create_adapter(deployment_id, request.api_key, client)
response.set_model(deployment_id)

params = ModelParameters.create(request)

with ChoiceConsumer(response) as consumer:
try:
await model.chat(consumer, params, request.messages)
except UserError as e:
await e.report_usage(consumer.choice)
await response.aflush()
raise e

return await sdk_adapter(
request=request,
deployment_id=deployment_id,
chat_completion=_handler,
)
11 changes: 10 additions & 1 deletion aidial_adapter_openai/configuration/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
)
from aidial_adapter_openai.utils.json import remove_nones
from aidial_adapter_openai.utils.parsers import (
AnthropicEndpoint,
AzureOpenAIEndpoint,
OpenAIEndpoint,
anthropic_messages_parser,
azure_video_api_parser,
chat_completions_parser,
completions_parser,
Expand All @@ -35,7 +37,7 @@

class DeploymentAPIType(ExtraForbidModel):
deployment_type: D
endpoint: AzureOpenAIEndpoint | OpenAIEndpoint
endpoint: AzureOpenAIEndpoint | OpenAIEndpoint | AnthropicEndpoint


class ApplicationConfig(ExtraForbidModel):
Expand Down Expand Up @@ -63,6 +65,12 @@ class ApplicationConfig(ExtraForbidModel):
def get_chat_completion_deployment_type(
self, deployment_id: str, upstream_endpoint: str
) -> DeploymentAPIType:
if endpoint := anthropic_messages_parser.try_parse(upstream_endpoint):
return DeploymentAPIType(
deployment_type=D.ANTHROPIC_MESSAGES_API,
endpoint=endpoint,
)

if endpoint := completions_parser.try_parse(upstream_endpoint):
return DeploymentAPIType(
deployment_type=D.COMPLETIONS_API,
Expand Down Expand Up @@ -171,6 +179,7 @@ def add_deployment(
| D.AZURE_VIDEO_API
| D.AUDIO_SPEECH_API
| D.AUDIO_TRANSCRIPTIONS_API
| D.ANTHROPIC_MESSAGES_API
):
pass
case _:
Expand Down
1 change: 1 addition & 0 deletions aidial_adapter_openai/configuration/deployment_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ class ChatCompletionDeploymentType(StrEnum):
AZURE_VIDEO_API = "VIDEO_API"
AUDIO_SPEECH_API = "SPEECH_API"
AUDIO_TRANSCRIPTIONS_API = "TRANSCRIPTIONS_API"
ANTHROPIC_MESSAGES_API = "ANTHROPIC_MESSAGES_API"
2 changes: 1 addition & 1 deletion aidial_adapter_openai/dial_api/sdk_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from fastapi.responses import StreamingResponse

from aidial_adapter_openai.dial_api.storage import DIAL_URL
from aidial_adapter_openai.exception_handlers import dial_exception_decorator
from aidial_adapter_openai.exceptions.handlers import dial_exception_decorator


async def sdk_adapter(
Expand Down
16 changes: 16 additions & 0 deletions aidial_adapter_openai/endpoints/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Mapping, assert_never

import fastapi
from anthropic import AsyncAnthropicFoundry
from fastapi import Request
from openai import AsyncAzureOpenAI

Expand All @@ -10,6 +11,9 @@
from aidial_adapter_openai.audio_api.transcribe.adapter import (
chat_completion as audio_transcriptions_gen,
)
from aidial_adapter_openai.chat_completions.anthropic import (
chat_completion as anthropic_chat_completions,
)
from aidial_adapter_openai.chat_completions.gpt import (
chat_completion as gpt_chat_completion,
)
Expand Down Expand Up @@ -95,6 +99,13 @@ def _get_tokenizer() -> Tokenizer:
image_tokenizer = get_image_tokenizer(deployment_type)
return Tokenizer(model=tiktoken_model, image_tokenizer=image_tokenizer)

if isinstance(client, AsyncAnthropicFoundry):
return await anthropic_chat_completions(
request=request,
client=client,
deployment_id=deployment_id,
)

match deployment_type:
case D.COMPLETIONS_API:
templates = app_config.COMPLETION_DEPLOYMENTS_PROMPT_TEMPLATES
Expand Down Expand Up @@ -183,6 +194,11 @@ def _get_tokenizer() -> Tokenizer:
response.body = extract_reasoning_tokens(response.body)
return response

case D.ANTHROPIC_MESSAGES_API:
raise RuntimeError(
"Anthropic API endpoint must have resulted in Anthropic client"
)

case _:
assert_never(deployment_type)

Expand Down
79 changes: 79 additions & 0 deletions aidial_adapter_openai/exceptions/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import re

from aidial_sdk.exceptions import HTTPException as DialException
from openai import APIStatusError

from aidial_adapter_openai.utils.adapter_exception import AdapterException


def _create_error(
status_code: int, message: str, headers: dict[str, str] | None = None
) -> DialException:
return DialException(
status_code=status_code,
type=_get_exception_type(status_code),
message=message,
headers=headers,
)


def _get_exception_type(status_code: int) -> str | None:
if status_code in {400, 422}:
return "invalid_request_error"
if status_code == 500:
return "internal_server_error"
return None


def _get_error_message(e: APIStatusError) -> str:
if isinstance(body := e.body, dict):
if isinstance((msg := body.get("message")), str):
return msg
return e.message


def _parse_streaming_error(text: str) -> DialException | None:
# Unfortunately, anthropic SDK obscures the original error message:
# https://github.com/anthropics/anthropic-sdk-python/blob/8b244157a7d03766bec645b0e1dc213c6d462165/src/anthropic/lib/bedrock/_stream_decoder.py#L57-L58
# So we have to parse it manually.

prefix = "Bad response code, expected 200: "
if not text.startswith(prefix):
return None
text = text.removeprefix(prefix)

code_pattern = re.search(r"'status_code':\s*(\d+)", text)
message_pattern = re.search(r"\"message\":\s*\"(.*?)\"", text)

code = int(code_pattern.group(1)) if code_pattern else None
message = str(message_pattern.group(1)) if message_pattern else None

if code and message:
message = message.replace("\\'", "'")
return _create_error(code, message)
return None


def _copy_headers(e: APIStatusError, keys: list[str]) -> dict[str, str] | None:
copied_headers: dict[str, str] = {}
for key in keys:
if key in e.response.headers:
copied_headers[key] = e.response.headers[key]
return copied_headers or None


def convert_anthropic_errors(e: Exception) -> AdapterException | None:
if isinstance(e, APIStatusError):
message = _get_error_message(e)
# We want to save Retry-After header if it's present:
# https://platform.claude.com/docs/en/api/rate-limits#tier-1

headers = _copy_headers(e, ["Retry-After"])
return _create_error(e.status_code, message, headers)

if isinstance(e, ValueError):
exc = _parse_streaming_error(str(e))
if exc is not None:
return exc

return None
30 changes: 30 additions & 0 deletions aidial_adapter_openai/exceptions/application.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import httpx
from aidial_adapter_anthropic.adapter import UserError, ValidationError
from aidial_sdk.exceptions import HTTPException as DialException

from aidial_adapter_openai.utils.adapter_exception import (
AdapterException,
parse_adapter_exception,
)


def convert_application_errors(e: Exception) -> AdapterException | None:
if isinstance(e, httpx.HTTPStatusError):
r = e.response
if ret := parse_adapter_exception(
status_code=r.status_code,
headers={},
content=r.text,
):
return ret

if isinstance(e, ValidationError):
return e.to_dial_exception()

if isinstance(e, UserError):
return e.to_dial_exception()

if isinstance(e, DialException):
return e

return None
Loading
Loading