
Commit e9b7ee5

CTY-git, patched-admin, and patched.codes[bot]
authored
Upgrade ResolveIssue (#1040)
* init upgrade resolve issue
* update
* update lock and lint
* refactor common module
* enable
* fix resolve issue
* fix aio
* update imports
* lint
* fix output type
* fix tracing
* switch to set
* update set with set
* Patched patchwork/steps/FixIssue/README.md (#1046)
  Co-authored-by: patched.codes[bot] <298395+patched.codes[bot]@users.noreply.github.com>
* fix test

---------

Co-authored-by: Patched <[email protected]>
Co-authored-by: patched.codes[bot] <298395+patched.codes[bot]@users.noreply.github.com>
1 parent 8295513 commit e9b7ee5

File tree

31 files changed: +2221 and -1390 lines changed

.github/workflows/test.yml

Lines changed: 8 additions & 7 deletions
@@ -143,13 +143,14 @@ jobs:
         if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
         run: poetry install --no-interaction --only main --extras rag

-      - name: Propose relevant file to issues
-        run: |
-          poetry run patchwork ResolveIssue --log debug \
-            --patched_api_key=${{ secrets.PATCHED_API_KEY }} \
-            --github_api_key=${{ secrets.SCM_GITHUB_KEY }} \
-            --issue_url=https://github.com/patched-codes/patchwork/issues/20 \
-            --disable_telemetry
+      # disabled because this currently takes too long
+      # - name: Resolve issue
+      #   run: |
+      #     poetry run patchwork ResolveIssue --log debug \
+      #       --patched_api_key=${{ secrets.PATCHED_API_KEY }} \
+      #       --github_api_key=${{ secrets.SCM_GITHUB_KEY }} \
+      #       --issue_url=https://github.com/patched-codes/patchwork/issues/1039 \
+      #       --disable_telemetry

   main-test:
     runs-on: ubuntu-latest

patchwork/common/client/llm/aio.py

Lines changed: 77 additions & 2 deletions
@@ -1,13 +1,21 @@
 from __future__ import annotations

+import os
+
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
     completion_create_params,
 )
 from typing_extensions import Dict, Iterable, List, Optional, Union

+from patchwork.common.client.llm.anthropic import AnthropicLlmClient
+from patchwork.common.client.llm.google import GoogleLlmClient
+from patchwork.common.client.llm.openai_ import OpenAiLlmClient
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
+from patchwork.common.constants import DEFAULT_PATCH_URL
 from patchwork.logger import logger


@@ -29,10 +37,43 @@ def get_models(self) -> set[str]:
     def is_model_supported(self, model: str) -> bool:
         return any(client.is_model_supported(model) for client in self.__clients)

-    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
+    def is_prompt_supported(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ) -> int:
         for client in self.__clients:
             if client.is_model_supported(model):
-                return client.is_prompt_supported(messages, model)
+                return client.is_prompt_supported(
+                    messages=messages,
+                    model=model,
+                    frequency_penalty=frequency_penalty,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
+                    max_tokens=max_tokens,
+                    n=n,
+                    presence_penalty=presence_penalty,
+                    response_format=response_format,
+                    stop=stop,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    top_logprobs=top_logprobs,
+                    top_p=top_p,
+                )
         return -1

     def truncate_messages(
@@ -56,6 +97,8 @@ def chat_completion(
         response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
@@ -74,6 +117,8 @@ def chat_completion(
                     response_format,
                     stop,
                     temperature,
+                    tools,
+                    tool_choice,
                     top_logprobs,
                     top_p,
                 )
@@ -82,3 +127,33 @@ def chat_completion(
             f"Model {model} is not supported by {client_names} clients. "
             f"Please ensure that the respective API keys are correct."
         )
+
+    @staticmethod
+    def create_aio_client(inputs) -> "AioLlmClient" | None:
+        clients = []
+
+        patched_key = inputs.get("patched_api_key")
+        if patched_key is not None:
+            client = OpenAiLlmClient(patched_key, DEFAULT_PATCH_URL)
+            clients.append(client)
+
+        openai_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
+        if openai_key is not None:
+            client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
+            client = OpenAiLlmClient(openai_key, **client_args)
+            clients.append(client)
+
+        google_key = inputs.get("google_api_key")
+        if google_key is not None:
+            client = GoogleLlmClient(google_key)
+            clients.append(client)
+
+        anthropic_key = inputs.get("anthropic_api_key")
+        if anthropic_key is not None:
+            client = AnthropicLlmClient(anthropic_key)
+            clients.append(client)
+
+        if len(clients) == 0:
+            return None
+
+        return AioLlmClient(*clients)
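
For reference, a minimal sketch of how the new AioLlmClient.create_aio_client factory could be driven from a step's inputs dict; the key names follow the diff above, while the surrounding caller code and placeholder values are hypothetical:

# Hypothetical caller; key names match create_aio_client above, values are placeholders.
from patchwork.common.client.llm.aio import AioLlmClient

inputs = {
    # any "client_"-prefixed inputs would be forwarded to OpenAiLlmClient as kwargs
    "openai_api_key": "sk-placeholder",  # falls back to the OPENAI_API_KEY env var if omitted
}

client = AioLlmClient.create_aio_client(inputs)
if client is None:
    # the factory returns None when no API key is found in inputs or the environment
    raise ValueError("no LLM API key configured")

# the aggregate client dispatches to whichever wrapped client supports the requested model
if client.is_model_supported("gpt-4o"):
    remaining_tokens = client.is_prompt_supported(
        messages=[{"role": "user", "content": "hello"}],
        model="gpt-4o",
    )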

patchwork/common/client/llm/anthropic.py

Lines changed: 174 additions & 36 deletions
@@ -5,11 +5,13 @@
 from functools import lru_cache

 from anthropic import Anthropic
-from anthropic.types import Message, TextBlockParam
+from anthropic.types import Message, MessageParam, TextBlockParam
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessage,
     ChatCompletionMessageParam,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
     completion_create_params,
 )
 from openai.types.chat.chat_completion import Choice, CompletionUsage
@@ -81,23 +83,169 @@ def __get_model_limit(self, model: str) -> int:
             return 100_000 - safety_margin
         return 200_000 - safety_margin

+    def __adapt_input_messages(self, messages: Iterable[ChatCompletionMessageParam]) -> list[MessageParam]:
+        new_messages = []
+        for message in messages:
+            if message.get("role") == "system":
+                if system is NOT_GIVEN:
+                    system = list()
+                system.append(TextBlockParam(text=message.get("content"), type="text"))
+            elif message.get("role") == "tool":
+                new_messages.append(
+                    dict(
+                        role="user",
+                        content=[
+                            dict(
+                                type="tool_result",
+                                tool_use_id=message.get("tool_call_id"),
+                                content=message.get("content"),
+                            )
+                        ],
+                    )
+                )
+            elif message.get("role") == "assistant" and len(message.get("tool_calls", [])) > 0:
+                tool_calls = message["tool_calls"]
+                tool_calls_as_content = [
+                    dict(
+                        type="tool_use",
+                        id=tool_call["id"],
+                        name=tool_call["function"]["name"],
+                        input=json.loads(tool_call["function"]["arguments"]),
+                    )
+                    for tool_call in tool_calls
+                ]
+                new_messages.append(
+                    dict(
+                        role="assistant",
+                        content=[
+                            *tool_calls_as_content,
+                        ],
+                    )
+                )
+            else:
+                new_messages.append(message)
+
+        return new_messages
+
+    def __adapt_chat_completion_request(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ):
+        system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
+        adapted_messages = self.__adapt_input_messages(messages)
+        default_max_token = 1000
+
+        if tool_choice is not NOT_GIVEN:
+            # openai tool choice to anthropic tool choice mapping:
+            # openai   : none, auto, required, required
+            # anthropic: NA  , auto, any     , tool
+            if isinstance(tool_choice, str):
+                if tool_choice == "required":
+                    tool_choice = dict(type="any")
+                elif tool_choice == "none":
+                    tool_choice = NOT_GIVEN
+                else:
+                    tool_choice = dict(type=tool_choice)
+            else:
+                tool_choice_type = tool_choice.get("type")
+                if tool_choice_type == "required":
+                    if tool_choice.get("function") is not None:
+                        tool_choice["type"] = "tool"
+                        tool_choice["name"] = tool_choice["function"]["name"]
+                    else:
+                        tool_choice["type"] = "any"
+                elif tool_choice_type == "none":
+                    tool_choice = NOT_GIVEN
+
+        input_kwargs = dict(
+            messages=adapted_messages,
+            system=system,
+            max_tokens=default_max_token if max_tokens is None or max_tokens is NOT_GIVEN else max_tokens,
+            model=model,
+            stop_sequences=[stop] if isinstance(stop, str) else stop,
+            temperature=temperature,
+            tools=[tool.get("function") for tool in tools if tool.get("function") is not None],
+            tool_choice=tool_choice,
+            top_p=top_p,
+        )
+
+        if response_format is not NOT_GIVEN and response_format.get("type") == "json_schema":
+            input_kwargs["tool_choice"] = dict(type="tool", name="response_format")
+            if input_kwargs.get("tools") is NOT_GIVEN:
+                input_kwargs["tools"] = list()
+            response_format_tool = dict(
+                name="response_format",
+                description="The response format to use",
+                input_schema=response_format["json_schema"]["schema"],
+            )
+            input_kwargs["tools"] = [*input_kwargs["tools"], response_format_tool]
+
+        return NotGiven.remove_not_given(input_kwargs)
+
     @lru_cache(maxsize=None)
     def get_models(self) -> set[str]:
         return self.__definitely_allowed_models.union(set(f"{self.__allowed_model_prefix}*"))

     def is_model_supported(self, model: str) -> bool:
         return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)

-    def is_prompt_supported(self, messages: Iterable[ChatCompletionMessageParam], model: str) -> int:
+    def is_prompt_supported(
+        self,
+        messages: Iterable[ChatCompletionMessageParam],
+        model: str,
+        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
+        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
+        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
+        n: Optional[int] | NotGiven = NOT_GIVEN,
+        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
+        response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
+        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
+        temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
+        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
+        top_p: Optional[float] | NotGiven = NOT_GIVEN,
+    ) -> int:
         model_limit = self.__get_model_limit(model)
-        token_count = 0
-        for message in messages:
-            message_token_count = self.client.count_tokens(message.get("content"))
-            token_count = token_count + message_token_count
-            if token_count > model_limit:
-                return -1
-
-        return model_limit - token_count
+        input_kwargs = self.__adapt_chat_completion_request(
+            messages=messages,
+            model=model,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            stop=stop,
+            temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+        )
+        count_token_input_kwargs = {
+            k: v
+            for k, v in input_kwargs.items()
+            if k in {"messages", "model", "system", "tool_choice", "tools", "beta"}
+        }
+        message_token_count = self.client.beta.messages.count_tokens(**count_token_input_kwargs)
+        return model_limit - message_token_count.input_tokens

     def truncate_messages(
         self, messages: Iterable[ChatCompletionMessageParam], model: str
@@ -117,38 +265,28 @@ def chat_completion(
         response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
         stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
         temperature: Optional[float] | NotGiven = NOT_GIVEN,
+        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
-        system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN
-        other_messages = []
-        for message in messages:
-            if message.get("role") == "system":
-                if system is NOT_GIVEN:
-                    system = list()
-                system.append(TextBlockParam(text=message.get("content"), type="text"))
-            else:
-                other_messages.append(message)
-
-        default_max_token = 1000
-        input_kwargs = dict(
-            messages=other_messages,
-            system=system,
-            max_tokens=default_max_token if max_tokens is None or max_tokens is NOT_GIVEN else max_tokens,
+        input_kwargs = self.__adapt_chat_completion_request(
+            messages=messages,
             model=model,
-            stop_sequences=[stop] if isinstance(stop, str) else stop,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            stop=stop,
             temperature=temperature,
+            tools=tools,
+            tool_choice=tool_choice,
+            top_logprobs=top_logprobs,
             top_p=top_p,
         )
-        if response_format is not NOT_GIVEN and response_format.get("type") == "json_schema":
-            input_kwargs["tool_choice"] = dict(type="tool", name="response_format")
-            input_kwargs["tools"] = [
-                dict(
-                    name="response_format",
-                    description="The response format to use",
-                    input_schema=response_format["json_schema"]["schema"],
-                )
-            ]

-        response = self.client.messages.create(**NotGiven.remove_not_given(input_kwargs))
+        response = self.client.messages.create(**input_kwargs)
         return _anthropic_to_openai_response(model, response)
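
The OpenAI-to-Anthropic tool_choice mapping that __adapt_chat_completion_request implements above can also be read in isolation; the standalone helper below is a hypothetical illustration of that mapping, not code from the commit, and None stands in for NOT_GIVEN:

# Hypothetical standalone sketch of the OpenAI -> Anthropic tool_choice translation
# mirrored from __adapt_chat_completion_request above.
def map_tool_choice(tool_choice):
    if isinstance(tool_choice, str):
        if tool_choice == "required":
            return {"type": "any"}          # openai "required" -> anthropic "any"
        if tool_choice == "none":
            return None                     # "none" has no anthropic equivalent; drop it
        return {"type": tool_choice}        # e.g. "auto" -> {"type": "auto"}
    if tool_choice.get("type") == "required":
        if tool_choice.get("function") is not None:
            # a specific function was requested -> anthropic "tool" with that name
            return {"type": "tool", "name": tool_choice["function"]["name"]}
        return {"type": "any"}
    if tool_choice.get("type") == "none":
        return None
    return tool_choice

assert map_tool_choice("required") == {"type": "any"}
assert map_tool_choice({"type": "required", "function": {"name": "lookup"}}) == {"type": "tool", "name": "lookup"}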
