
Commit 748c689

init upgrade resolve issue
1 parent 8295513 commit 748c689

File tree

26 files changed: +2059 -1372 lines changed


patchwork/common/client/llm/aio.py

Lines changed: 81 additions & 4 deletions
@@ -1,14 +1,24 @@
 from __future__ import annotations
 
+import os
+from logging import getLogger
+
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
     completion_create_params,
 )
 from typing_extensions import Dict, Iterable, List, Optional, Union
 
-from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
-from patchwork.logger import logger
+from ...constants import DEFAULT_PATCH_URL
+from .anthropic import AnthropicLlmClient
+from .google import GoogleLlmClient
+from .openai_ import OpenAiLlmClient
+from .protocol import NOT_GIVEN, LlmClient, NotGiven
+
+logger = getLogger(__name__)
 
 
 class AioLlmClient(LlmClient):
@@ -29,10 +39,43 @@ def get_models(self) -> set[str]:
     def is_model_supported(self, model: str) -> bool:
         return any(client.is_model_supported(model) for client in self.__clients)
 
-    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
+    def is_prompt_supported(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ) -> int:
         for client in self.__clients:
             if client.is_model_supported(model):
-                return client.is_prompt_supported(messages, model)
+                return client.is_prompt_supported(
+                    messages=messages,
+                    model=model,
+                    frequency_penalty=frequency_penalty,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
+                    max_tokens=max_tokens,
+                    n=n,
+                    presence_penalty=presence_penalty,
+                    response_format=response_format,
+                    stop=stop,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    top_logprobs=top_logprobs,
+                    top_p=top_p,
+                )
         return -1
 
     def truncate_messages(
@@ -56,6 +99,8 @@ def chat_completion(
         response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
@@ -74,6 +119,8 @@ def chat_completion(
                     response_format,
                     stop,
                     temperature,
+                    tools,
+                    tool_choice,
                     top_logprobs,
                     top_p,
                 )
@@ -82,3 +129,33 @@ def chat_completion(
             f"Model {model} is not supported by {client_names} clients. "
             f"Please ensure that the respective API keys are correct."
         )
+
+    @staticmethod
+    def create_aio_client(inputs) -> "AioLlmClient" | None:
+        clients = []
+
+        patched_key = inputs.get("patched_api_key")
+        if patched_key is not None:
+            client = OpenAiLlmClient(patched_key, DEFAULT_PATCH_URL)
+            clients.append(client)
+
+        openai_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
+        if openai_key is not None:
+            client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
+            client = OpenAiLlmClient(openai_key, **client_args)
+            clients.append(client)
+
+        google_key = inputs.get("google_api_key")
+        if google_key is not None:
+            client = GoogleLlmClient(google_key)
+            clients.append(client)
+
+        anthropic_key = inputs.get("anthropic_api_key")
+        if anthropic_key is not None:
+            client = AnthropicLlmClient(anthropic_key)
+            clients.append(client)
+
+        if len(clients) == 0:
+            return None
+
+        return AioLlmClient(*clients)
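
The new create_aio_client factory above builds one aggregate client from whichever API keys appear in the step inputs (falling back to the OPENAI_API_KEY environment variable for OpenAI). A minimal usage sketch, assuming a plain dict of inputs; the key names and the "client_" prefix stripping come from the diff, while the placeholder values and the assumption that OpenAiLlmClient accepts the forwarded keyword are illustrative:

from patchwork.common.client.llm.aio import AioLlmClient

inputs = {
    "openai_api_key": "sk-...",          # falls back to the OPENAI_API_KEY env var when absent
    "anthropic_api_key": "sk-ant-...",   # optional; adds an AnthropicLlmClient
    "client_base_url": "https://example.invalid/v1",  # "client_" prefix is stripped and the rest
                                                      # is forwarded to OpenAiLlmClient (assumed kwarg)
}

aio_client = AioLlmClient.create_aio_client(inputs)
if aio_client is None:
    raise ValueError("no LLM API keys configured")

# The aggregate client delegates to the first wrapped client that supports the model.
print(aio_client.is_model_supported("gpt-4o"))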

patchwork/common/client/llm/anthropic.py

Lines changed: 175 additions & 37 deletions
@@ -5,11 +5,13 @@
 from functools import lru_cache
 
 from anthropic import Anthropic
-from anthropic.types import Message, TextBlockParam
+from anthropic.types import Message, MessageParam, TextBlockParam
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessage,
     ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
     completion_create_params,
 )
 from openai.types.chat.chat_completion import Choice, CompletionUsage
@@ -20,7 +22,7 @@
 from openai.types.completion_usage import CompletionUsage
 from typing_extensions import Dict, Iterable, List, Optional, Union
 
-from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
+from .protocol import NOT_GIVEN, LlmClient, NotGiven
 
 
 def _anthropic_to_openai_response(model: str, anthropic_response: Message) -> ChatCompletion:
@@ -81,23 +83,169 @@ def __get_model_limit(self, model: str) -> int:
             return 100_000 - safety_margin
         return 200_000 - safety_margin
 
+    def __adapt_input_messages(self, messages: Iterable[ChatCompletionMessageParam]) -> list[MessageParam]:
+        new_messages = []
+        for message in messages:
+            if message.get("role") == "system":
+                if system is NOT_GIVEN:
+                    system = list()
+                system.append(TextBlockParam(text=message.get("content"), type="text"))
+            elif message.get("role") == "tool":
+                new_messages.append(
+                    dict(
+                        role="user",
+                        content=[
+                            dict(
+                                type="tool_result",
+                                tool_use_id=message.get("tool_call_id"),
+                                content=message.get("content"),
+                            )
+                        ],
+                    )
+                )
+            elif message.get("role") == "assistant" and len(message.get("tool_calls", [])) > 0:
+                tool_calls = message["tool_calls"]
+                tool_calls_as_content = [
+                    dict(
+                        type="tool_use",
+                        id=tool_call["id"],
+                        name=tool_call["function"]["name"],
+                        input=json.loads(tool_call["function"]["arguments"]),
+                    )
+                    for tool_call in tool_calls
+                ]
+                new_messages.append(
+                    dict(
+                        role="assistant",
+                        content=[
+                            *tool_calls_as_content,
+                        ],
+                    )
+                )
+            else:
+                new_messages.append(message)
+
+        return new_messages
+
+    def __adapt_chat_completion_request(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ):
+        system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
+        adapted_messages = self.__adapt_input_messages(messages)
+        default_max_token = 1000
+
+        if tool_choice is not NOT_GIVEN:
+            # openai tool choice to anthropic tool choice mapping:
+            # openai   : none, auto, required , required
+            # anthropic: NA  , auto, any      , tool
+            if isinstance(tool_choice, str):
+                if tool_choice == "required":
+                    tool_choice = dict(type="any")
+                elif tool_choice == "none":
+                    tool_choice = NOT_GIVEN
+                else:
+                    tool_choice = dict(type=tool_choice)
+            else:
+                tool_choice_type = tool_choice.get("type")
+                if tool_choice_type == "required":
+                    if tool_choice.get("function") is not None:
+                        tool_choice["type"] = "tool"
+                        tool_choice["name"] = tool_choice["function"]["name"]
+                    else:
+                        tool_choice["type"] = "any"
+                elif tool_choice_type == "none":
+                    tool_choice = NOT_GIVEN
+
+        input_kwargs = dict(
+            messages=adapted_messages,
+            system=system,
+            max_tokens=default_max_token if max_tokens is None or max_tokens is NOT_GIVEN else max_tokens,
+            model=model,
+            stop_sequences=[stop] if isinstance(stop, str) else stop,
+            temperature=temperature,
+            tools=[tool.get("function") for tool in tools if tool.get("function") is not None],
+            tool_choice=tool_choice,
+            top_p=top_p,
+        )
+
+        if response_format is not NOT_GIVEN and response_format.get("type") == "json_schema":
+            input_kwargs["tool_choice"] = dict(type="tool", name="response_format")
+            if input_kwargs.get("tools") is NOT_GIVEN:
+                input_kwargs["tools"] = list()
+            response_format_tool = dict(
+                name="response_format",
+                description="The response format to use",
+                input_schema=response_format["json_schema"]["schema"],
+            )
+            input_kwargs["tools"] = [*input_kwargs["tools"], response_format_tool]
+
+        return NotGiven.remove_not_given(input_kwargs)
+
     @lru_cache(maxsize=None)
     def get_models(self) -> set[str]:
         return self.__definitely_allowed_models.union(set(f"{self.__allowed_model_prefix}*"))
 
     def is_model_supported(self, model: str) -> bool:
         return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)
 
-    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
+    def is_prompt_supported(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ) -> int:
         model_limit = self.__get_model_limit(model)
-        token_count = 0
-        for message in messages:
-            message_token_count = self.client.count_tokens(message.get("content"))
-            token_count = token_count + message_token_count
-            if token_count > model_limit:
-                return -1
-
-        return model_limit - token_count
+        input_kwargs = self.__adapt_chat_completion_request(
+            messages=messages,
+            model=model,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            stop=stop,
+            temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+        )
+        count_token_input_kwargs = {
+            k: v
+            for k, v in input_kwargs.items()
+            if k in {"messages", "model", "system", "tool_choice", "tools", "beta"}
+        }
+        message_token_count = self.client.beta.messages.count_tokens(**count_token_input_kwargs)
+        return model_limit - message_token_count.input_tokens
 
     def truncate_messages(
         self, messages: Iterable[ChatCompletionMessageParam], model: str
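
The hunk above replaces the old per-message count_tokens loop with the Anthropic SDK's beta count-tokens endpoint and subtracts the result from the model limit. A minimal sketch of that call in isolation, assuming ANTHROPIC_API_KEY is set; the model name and safety margin here are illustrative, and the real method builds its keyword arguments via __adapt_chat_completion_request:

from anthropic import Anthropic

client = Anthropic()  # reads ANTHROPIC_API_KEY from the environment

count = client.beta.messages.count_tokens(
    model="claude-3-5-sonnet-latest",
    messages=[{"role": "user", "content": "Hello, Claude"}],
)

model_limit = 200_000 - 512  # context window minus an assumed safety margin
remaining = model_limit - count.input_tokens
print(remaining)  # a negative value means the prompt does not fit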
@@ -117,38 +265,28 @@ def chat_completion(
         response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
-        system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
-        other_messages = []
-        for message in messages:
-            if message.get("role") == "system":
-                if system is NOT_GIVEN:
-                    system = list()
-                system.append(TextBlockParam(text=message.get("content"), type="text"))
-            else:
-                other_messages.append(message)
-
-        default_max_token = 1000
-        input_kwargs = dict(
-            messages=other_messages,
-            system=system,
-            max_tokens=default_max_token if max_tokens is None or max_tokens is NOT_GIVEN else max_tokens,
+        input_kwargs = self.__adapt_chat_completion_request(
+            messages=messages,
             model=model,
-            stop_sequences=[stop] if isinstance(stop, str) else stop,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            stop=stop,
             temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+            top_logprobs=top_logprobs,
             top_p=top_p,
         )
-        if response_format is not NOT_GIVEN and response_format.get("type") == "json_schema":
-            input_kwargs["tool_choice"] = dict(type="tool", name="response_format")
-            input_kwargs["tools"] = [
-                dict(
-                    name="response_format",
-                    description="The response format to use",
-                    input_schema=response_format["json_schema"]["schema"],
-                )
-            ]
 
-        response = self.client.messages.create(**NotGiven.remove_not_given(input_kwargs))
+        response = self.client.messages.create(**input_kwargs)
         return _anthropic_to_openai_response(model, response)
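
The comment inside __adapt_chat_completion_request summarises how OpenAI tool_choice values are translated for Anthropic: "none" is dropped, "auto" passes through, "required" becomes "any", and a required choice naming a function becomes a type="tool" choice. A standalone sketch of that same mapping, separate from the private method, with None standing in for NOT_GIVEN:

from __future__ import annotations

from typing import Any, Optional


def map_tool_choice(tool_choice: str | dict[str, Any]) -> Optional[dict[str, Any]]:
    # Mirrors the mapping in the diff; returning None means "omit tool_choice entirely".
    if isinstance(tool_choice, str):
        if tool_choice == "required":
            return {"type": "any"}
        if tool_choice == "none":
            return None
        return {"type": tool_choice}  # e.g. "auto"
    if tool_choice.get("type") == "required":
        function = tool_choice.get("function")
        if function is not None:
            return {"type": "tool", "name": function["name"]}
        return {"type": "any"}
    if tool_choice.get("type") == "none":
        return None
    return tool_choice


print(map_tool_choice({"type": "required", "function": {"name": "response_format"}}))
# {'type': 'tool', 'name': 'response_format'}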

patchwork/common/client/llm/google.py

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from typing_extensions import Any, Dict, Iterable, List, Optional, Union
 
-from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
-from patchwork.common.client.llm.utils import json_schema_to_model
+from .protocol import NOT_GIVEN, LlmClient, NotGiven
+from .utils import json_schema_to_model
 
 
 @functools.lru_cache

0 commit comments
