Skip to content

Commit 7b9e4d5

Browse files
committed
chore: refactor into a base with utility functions
1 parent b8df34b commit 7b9e4d5

File tree

10 files changed

+992
-1052
lines changed

10 files changed

+992
-1052
lines changed
Lines changed: 163 additions & 189 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import os
21
import json
3-
from typing import Any
2+
from typing import Any, cast
43

54
try:
65
from anthropic import Anthropic
@@ -11,205 +10,180 @@
1110

1211
from openai.types.chat.chat_completion import ChatCompletion, Choice
1312
from openai.types.completion_usage import CompletionUsage
14-
from openai.types.chat.chat_completion_message import ChatCompletionMessage
15-
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
16-
from any_llm.provider import Provider, ApiConfig
17-
from any_llm.exceptions import MissingApiKeyError
1813

19-
# Define a constant for the default max_tokens value
14+
from any_llm.provider import ApiConfig
15+
from any_llm.providers.base_framework import (
16+
BaseCustomProvider,
17+
create_openai_tool_call,
18+
create_openai_message,
19+
create_openai_completion,
20+
convert_openai_tools_to_generic,
21+
extract_system_message,
22+
remove_unsupported_params,
23+
)
24+
2025
DEFAULT_MAX_TOKENS = 4096
2126

2227

23-
def _convert_kwargs(kwargs: dict[str, Any]) -> dict[str, Any]:
24-
"""Format the kwargs for Anthropic."""
25-
kwargs = kwargs.copy()
26-
kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
27-
28-
# Convert tools if present
29-
if "tools" in kwargs:
30-
kwargs["tools"] = _convert_tool_spec(kwargs["tools"])
31-
32-
# Handle parallel_tool_calls parameter
33-
if "parallel_tool_calls" in kwargs:
34-
parallel_tool_calls = kwargs.pop("parallel_tool_calls")
35-
# If parallel_tool_calls is False, set disable_parallel_tool_use to True
36-
if parallel_tool_calls is False:
37-
tool_choice = {"type": kwargs.get("tool_choice", "any"), "disable_parallel_tool_use": True}
38-
kwargs["tool_choice"] = tool_choice
39-
# If parallel_tool_calls is True or not specified, don't set disable_parallel_tool_use
40-
# (Anthropic defaults to parallel tool use enabled)
41-
42-
if "response_format" in kwargs:
43-
error_msg = (
44-
"response_format is not supported for Anthropic, see their documentation "
45-
"for tips on how to achieve structured output: "
46-
"https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/increase-consistency#example-standardizing-customer-feedback"
47-
)
48-
raise ValueError(error_msg)
28+
class AnthropicProvider(BaseCustomProvider):
29+
"""
30+
Anthropic Provider using enhanced BaseCustomProvider framework.
4931
50-
return kwargs
32+
Handles conversion between OpenAI format and Anthropic's native format.
33+
"""
5134

35+
PROVIDER_NAME = "Anthropic"
36+
ENV_API_KEY_NAME = "ANTHROPIC_API_KEY"
5237

53-
def _convert_tool_spec(openai_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
54-
"""Convert OpenAI tool specification to Anthropic format."""
55-
anthropic_tools = []
38+
def _initialize_client(self, config: ApiConfig) -> None:
39+
"""Initialize the Anthropic client."""
40+
self.client = Anthropic(api_key=config.api_key, base_url=config.api_base)
5641

57-
for tool in openai_tools:
58-
if tool.get("type") != "function":
59-
continue
42+
def _convert_kwargs(self, kwargs: dict[str, Any]) -> dict[str, Any]:
43+
"""Convert kwargs to Anthropic format."""
44+
kwargs = kwargs.copy()
45+
kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)
46+
47+
# Remove unsupported parameters
48+
kwargs = remove_unsupported_params(kwargs, ["response_format"])
49+
50+
# Convert tools if present
51+
if "tools" in kwargs:
52+
kwargs["tools"] = self._convert_tool_spec(kwargs["tools"])
53+
54+
# Handle parallel_tool_calls
55+
if "parallel_tool_calls" in kwargs:
56+
parallel_tool_calls = kwargs.pop("parallel_tool_calls")
57+
if parallel_tool_calls is False:
58+
tool_choice = {"type": kwargs.get("tool_choice", "any"), "disable_parallel_tool_use": True}
59+
kwargs["tool_choice"] = tool_choice
60+
61+
return kwargs
62+
63+
def _convert_messages(self, messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
64+
"""Convert messages to Anthropic format, extracting system message."""
65+
# Extract system message using the utility
66+
system_message, remaining_messages = extract_system_message(messages)
67+
68+
converted_messages = []
69+
for message in remaining_messages:
70+
if message["role"] == "tool":
71+
converted_message = {
72+
"role": "user",
73+
"content": [
74+
{
75+
"type": "tool_result",
76+
"tool_use_id": message["tool_call_id"],
77+
"content": message["content"],
78+
}
79+
],
80+
}
81+
converted_messages.append(converted_message)
82+
elif message["role"] == "assistant" and "tool_calls" in message:
83+
message_content = []
84+
if message.get("content"):
85+
message_content.append({"type": "text", "text": message["content"]})
86+
87+
for tool_call in message.get("tool_calls") or []:
88+
message_content.append(
89+
{
90+
"type": "tool_use",
91+
"id": tool_call["id"],
92+
"name": tool_call["function"]["name"],
93+
"input": json.loads(tool_call["function"]["arguments"]),
94+
}
95+
)
96+
97+
converted_message = {"role": "assistant", "content": message_content}
98+
converted_messages.append(converted_message)
99+
else:
100+
converted_message = {"role": message["role"], "content": message["content"]}
101+
converted_messages.append(converted_message)
102+
103+
return system_message, converted_messages
104+
105+
def _make_api_call(self, model: str, messages: tuple[str, list[dict[str, Any]]], **kwargs: Any) -> Message:
106+
"""Make the API call to Anthropic."""
107+
system_message, converted_messages = messages
108+
109+
return self.client.messages.create(
110+
model=model,
111+
system=system_message,
112+
messages=converted_messages, # type: ignore[arg-type]
113+
**kwargs,
114+
)
60115

61-
function = tool["function"]
62-
anthropic_tool = {
63-
"name": function["name"],
64-
"description": function["description"],
65-
"input_schema": {
66-
"type": "object",
67-
"properties": function["parameters"]["properties"],
68-
"required": function["parameters"].get("required", []),
69-
},
116+
def _convert_response(self, response: Message) -> ChatCompletion:
117+
"""Convert Anthropic response to OpenAI format."""
118+
finish_reason_mapping = {
119+
"end_turn": "stop",
120+
"max_tokens": "length",
121+
"tool_use": "tool_calls",
70122
}
71-
anthropic_tools.append(anthropic_tool)
72-
73-
return anthropic_tools
74-
75-
76-
def _convert_messages(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
77-
"""Convert messages to Anthropic format, extracting system message."""
78-
system_message = ""
79-
converted_messages = []
80-
81-
for message in messages:
82-
if message["role"] == "system":
83-
system_message = message["content"]
84-
continue
85-
elif message["role"] == "tool":
86-
# Convert tool message to Anthropic format
87-
converted_message = {
88-
"role": "user",
89-
"content": [
90-
{
91-
"type": "tool_result",
92-
"tool_use_id": message["tool_call_id"],
93-
"content": message["content"],
94-
}
95-
],
96-
}
97-
converted_messages.append(converted_message)
98-
elif message["role"] == "assistant" and "tool_calls" in message:
99-
# Convert assistant message with tool calls
100-
message_content = []
101-
if message.get("content"):
102-
message_content.append({"type": "text", "text": message["content"]})
103-
104-
for tool_call in message.get("tool_calls") or []:
105-
message_content.append(
106-
{
107-
"type": "tool_use",
108-
"id": tool_call["id"],
109-
"name": tool_call["function"]["name"],
110-
"input": json.loads(tool_call["function"]["arguments"]),
111-
}
112-
)
113123

114-
converted_message = {"role": "assistant", "content": message_content}
115-
converted_messages.append(converted_message)
116-
else:
117-
# Regular message
118-
converted_message = {"role": message["role"], "content": message["content"]}
119-
converted_messages.append(converted_message)
120-
121-
return system_message, converted_messages
122-
123-
124-
def _convert_response(response: Message) -> ChatCompletion:
125-
"""Convert Anthropic response directly to OpenAI ChatCompletion format."""
126-
# Finish reason mapping
127-
finish_reason_mapping = {
128-
"end_turn": "stop",
129-
"max_tokens": "length",
130-
"tool_use": "tool_calls",
131-
}
132-
133-
# Process content blocks
134-
tool_calls = []
135-
content = ""
136-
137-
for content_block in response.content:
138-
if content_block.type == "text":
139-
content = content_block.text
140-
elif content_block.type == "tool_use":
141-
tool_calls.append(
142-
ChatCompletionMessageToolCall(
143-
id=content_block.id,
144-
type="function",
145-
function=Function(name=content_block.name, arguments=json.dumps(content_block.input)),
124+
# Process content blocks
125+
tool_calls = []
126+
content = ""
127+
128+
for content_block in response.content:
129+
if content_block.type == "text":
130+
content = content_block.text
131+
elif content_block.type == "tool_use":
132+
tool_calls.append(
133+
create_openai_tool_call(
134+
tool_call_id=content_block.id,
135+
name=content_block.name,
136+
arguments=json.dumps(content_block.input),
137+
)
146138
)
147-
)
148-
149-
# Create the message
150-
message = ChatCompletionMessage(
151-
content=content or None,
152-
role="assistant",
153-
tool_calls=tool_calls if tool_calls else None,
154-
)
155-
156-
# Create the choice
157-
if not response.stop_reason:
158-
response.stop_reason = "end_turn"
159-
mapped_finish_reason = finish_reason_mapping.get(response.stop_reason, "stop")
160-
choice = Choice(
161-
finish_reason=mapped_finish_reason, # type: ignore
162-
index=0,
163-
message=message,
164-
)
165-
166-
# Create usage information
167-
usage = CompletionUsage(
168-
completion_tokens=response.usage.output_tokens,
169-
prompt_tokens=response.usage.input_tokens,
170-
total_tokens=response.usage.input_tokens + response.usage.output_tokens,
171-
)
172-
173-
# Build the final ChatCompletion object
174-
return ChatCompletion(
175-
id=response.id,
176-
model=response.model,
177-
object="chat.completion",
178-
created=int(response.created_at.timestamp()) if hasattr(response, "created_at") else 0,
179-
choices=[choice],
180-
usage=usage,
181-
)
182-
183-
184-
class AnthropicProvider(Provider):
185-
def __init__(self, config: ApiConfig) -> None:
186-
"""Initialize Anthropic provider."""
187-
if not config.api_key:
188-
config.api_key = os.getenv("ANTHROPIC_API_KEY")
189-
if not config.api_key:
190-
raise MissingApiKeyError(
191-
"Anthropic",
192-
"ANTHROPIC_API_KEY",
193-
)
194-
self.client = Anthropic(api_key=config.api_key, base_url=config.api_base)
195139

196-
def completion(
197-
self,
198-
model: str,
199-
messages: list[dict[str, Any]],
200-
**kwargs: Any,
201-
) -> ChatCompletion:
202-
"""Create a chat completion using Anthropic."""
203-
kwargs = _convert_kwargs(kwargs)
204-
system_message, converted_messages = _convert_messages(messages)
205-
206-
# Make the request to Anthropic
207-
response = self.client.messages.create(
208-
model=model,
209-
system=system_message,
210-
messages=converted_messages, # type: ignore
211-
**kwargs,
140+
# Create the message
141+
message = create_openai_message(
142+
role="assistant",
143+
content=content or None,
144+
tool_calls=tool_calls if tool_calls else None,
145+
)
146+
147+
# Create the choice
148+
mapped_finish_reason = finish_reason_mapping.get(response.stop_reason or "end_turn", "stop")
149+
choice = Choice(
150+
finish_reason=cast(Any, mapped_finish_reason),
151+
index=0,
152+
message=message,
153+
)
154+
155+
# Create usage information
156+
usage = CompletionUsage(
157+
completion_tokens=response.usage.output_tokens,
158+
prompt_tokens=response.usage.input_tokens,
159+
total_tokens=response.usage.input_tokens + response.usage.output_tokens,
160+
)
161+
162+
return create_openai_completion(
163+
id=response.id,
164+
model=response.model,
165+
choices=[choice],
166+
usage=usage,
167+
created=int(response.created_at.timestamp()) if hasattr(response, "created_at") else 0,
212168
)
213169

214-
# Convert to OpenAI format
215-
return _convert_response(response)
170+
def _convert_tool_spec(self, openai_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
171+
"""Convert OpenAI tool specification to Anthropic format."""
172+
# Use the generic utility first
173+
generic_tools = convert_openai_tools_to_generic(openai_tools)
174+
175+
# Convert to Anthropic-specific format
176+
anthropic_tools = []
177+
for tool in generic_tools:
178+
anthropic_tool = {
179+
"name": tool["name"],
180+
"description": tool["description"],
181+
"input_schema": {
182+
"type": "object",
183+
"properties": tool["parameters"]["properties"],
184+
"required": tool["parameters"].get("required", []),
185+
},
186+
}
187+
anthropic_tools.append(anthropic_tool)
188+
189+
return anthropic_tools

0 commit comments

Comments
 (0)