
Commit ce362bf

fix llm client by force
1 parent 7a7d4bc commit ce362bf

7 files changed: +50 -25 lines changed

patchwork/common/client/llm/aio.py

Lines changed: 26 additions & 17 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+from pathlib import Path
 
 from openai.types.chat import (
     ChatCompletion,
@@ -54,10 +55,11 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         for client in self.__clients:
             if client.is_model_supported(model):
-                return client.is_prompt_supported(
+                inputs = dict(
                     messages=messages,
                     model=model,
                     frequency_penalty=frequency_penalty,
@@ -74,6 +76,9 @@ def is_prompt_supported(
                     top_logprobs=top_logprobs,
                     top_p=top_p,
                 )
+                if file is not NotGiven:
+                    inputs["file"] = file
+                return client.is_prompt_supported(**inputs)
         return -1
 
     def truncate_messages(
@@ -101,27 +106,31 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         for client in self.__clients:
             if client.is_model_supported(model):
                 logger.debug(f"Using {client.__class__.__name__} for model {model}")
-                return client.chat_completion(
-                    messages,
-                    model,
-                    frequency_penalty,
-                    logit_bias,
-                    logprobs,
-                    max_tokens,
-                    n,
-                    presence_penalty,
-                    response_format,
-                    stop,
-                    temperature,
-                    tools,
-                    tool_choice,
-                    top_logprobs,
-                    top_p,
+                inputs = dict(
+                    messages=messages,
+                    model=model,
+                    frequency_penalty=frequency_penalty,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
+                    max_tokens=max_tokens,
+                    n=n,
+                    presence_penalty=presence_penalty,
+                    response_format=response_format,
+                    stop=stop,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    top_logprobs=top_logprobs,
+                    top_p=top_p,
                 )
+                if file is not NotGiven:
+                    inputs["file"] = file
+                return client.chat_completion(**inputs)
         client_names = [client.__class__.__name__ for client in self.__original_clients]
         raise ValueError(
             f"Model {model} is not supported by {client_names} clients. "

patchwork/common/client/llm/anthropic.py

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@
 import json
 import time
 from functools import lru_cache
+from pathlib import Path
 
 from anthropic import Anthropic
 from anthropic.types import Message, MessageParam, TextBlockParam
@@ -224,6 +225,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         model_limit = self.__get_model_limit(model)
         input_kwargs = self.__adapt_chat_completion_request(
@@ -273,6 +275,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         input_kwargs = self.__adapt_chat_completion_request(
             messages=messages,

patchwork/common/client/llm/google.py

Lines changed: 11 additions & 5 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import hashlib
 import time
 from functools import lru_cache
 from pathlib import Path
@@ -62,15 +63,16 @@ def __upload(self, file: Path | NotGiven) -> File | None:
         if file is NotGiven:
             return None
 
+        md5_name = hashlib.md5(file.read_bytes()).hexdigest()
         try:
-            file_ref = self.client.files.get(file.name)
+            file_ref = self.client.files.get(name=md5_name)
             if file_ref.error is None:
                 return file_ref
         except Exception as e:
             pass
 
         try:
-            file_ref = self.client.files.upload(file=file)
+            file_ref = self.client.files.upload(file=file, config=dict(name=md5_name))
             if file_ref.error is None:
                 return file_ref
         except Exception as e:
@@ -97,16 +99,16 @@ def is_prompt_supported(
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
         file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
-        system, chat = self.__openai_messages_to_google_messages(messages)
+        system, contents = self.__openai_messages_to_google_messages(messages)
 
         file_ref = self.__upload(file)
         if file_ref is not None:
-            chat.append(file_ref)
+            contents.append(file_ref)
 
         try:
             token_response = self.client.models.count_tokens(
                 model=model,
-                contents=chat,
+                contents=contents,
                 config=CountTokensConfig(
                     system_instruction=system,
                 ),
@@ -155,6 +157,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         generation_dict = dict(
             stop_sequences=[stop] if isinstance(stop, str) else stop,
@@ -174,6 +177,9 @@
         )
 
         system_content, contents = self.__openai_messages_to_google_messages(messages)
+        file_ref = self.__upload(file)
+        if file_ref is not None:
+            contents.append(file_ref)
 
         response = self.client.models.generate_content(
             model=model,
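
The __upload change names the remote file after the md5 of its bytes, so the same local file always maps to the same remote object and a later run can look it up instead of uploading again. A rough sketch of that get-or-upload idea, with hypothetical lookup and upload callables standing in for the Google file store:

import hashlib
from pathlib import Path


def content_name(file: Path) -> str:
    # Identical bytes always produce the same name, so repeat uploads collapse to one object.
    return hashlib.md5(file.read_bytes()).hexdigest()


def get_or_upload(file: Path, lookup, upload):
    # lookup(name) and upload(file, name) are hypothetical remote-store callables.
    name = content_name(file)
    try:
        existing = lookup(name)
        if existing is not None:
            return existing
    except Exception:
        pass
    return upload(file, name)


# In-memory stand-in for the remote store:
store = {}
first = get_or_upload(Path(__file__), store.get, lambda f, n: store.setdefault(n, f.name))
second = get_or_upload(Path(__file__), store.get, lambda f, n: store.setdefault(n, f.name))
assert first == second  # the second call reuses the stored reference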

patchwork/common/client/llm/openai_.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import functools
+from pathlib import Path
 
 import tiktoken
 from openai import OpenAI
@@ -82,6 +83,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         # might not implement model endpoint
         if self.__is_not_openai_url():
@@ -125,6 +127,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         input_kwargs = dict(
             messages=messages,

patchwork/common/client/llm/protocol.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from pathlib import Path
+
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessageParam,
@@ -51,6 +53,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         ...
 
@@ -135,5 +138,6 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         ...

patchwork/steps/CallLLM/CallLLM.py

Lines changed: 2 additions & 2 deletions
@@ -121,14 +121,14 @@ def __call(self, prompts: list[list[dict]]) -> list[_InnerCallLLMResponse]:
             kwargs["file"] = Path(self.file)
 
         for prompt in prompts:
-            is_input_accepted = self.client.is_prompt_supported(prompt, self.model) > 0
+            is_input_accepted = self.client.is_prompt_supported(model=self.model, messages=prompt, **kwargs) > 0
             if not is_input_accepted:
                 self.set_status(StepStatus.WARNING, "Input token limit exceeded.")
                 prompt = self.client.truncate_messages(prompt, self.model)
 
             logger.trace(f"Message sent: \n{escape(indent(pformat(prompt), ' '))}")
             try:
-                completion = self.client.chat_completion(model=self.model, messages=prompt, **parsed_model_args)
+                completion = self.client.chat_completion(model=self.model, messages=prompt, **kwargs)
             except Exception as e:
                 logger.error(e)
                 completion = None
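
With this change CallLLM builds one kwargs dict, optionally carrying a Path, and reuses it for both the token check and the completion call, so a supplied file reaches whichever client handles the model. A small sketch of that flow; FakeClient and its return values are illustrative, not the repository's client:

from __future__ import annotations

from pathlib import Path


class FakeClient:
    # Illustrative stand-in for the aggregated LLM client.
    def is_prompt_supported(self, messages, model, file: Path | None = None) -> int:
        return 100

    def chat_completion(self, messages, model, file: Path | None = None) -> dict:
        return {"model": model, "file": None if file is None else file.name}


def call(client, prompts, model, file_path: str | None = None) -> list:
    kwargs = {}
    if file_path:
        kwargs["file"] = Path(file_path)  # same optional-file plumbing as the step above

    completions = []
    for prompt in prompts:
        if client.is_prompt_supported(model=model, messages=prompt, **kwargs) > 0:
            completions.append(client.chat_completion(model=model, messages=prompt, **kwargs))
    return completions


print(call(FakeClient(), [[{"role": "user", "content": "hi"}]], "gpt-4o"))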

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "patchwork-cli"
-version = "0.0.99.dev3"
+version = "0.0.99.dev4"
 description = ""
 authors = ["patched.codes"]
 license = "AGPL"
