
Commit 47fe2b3

Merge remote-tracking branch 'origin/main' into log-analysis
2 parents 9746b8b + 4cc8d3e commit 47fe2b3

21 files changed, +300 −116 lines changed


patchwork/common/client/llm/aio.py

Lines changed: 26 additions, 17 deletions

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import os
+from pathlib import Path

 from openai.types.chat import (
     ChatCompletion,
@@ -116,10 +117,11 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         for client in self.__clients:
             if client.is_model_supported(model):
-                return client.is_prompt_supported(
+                inputs = dict(
                     messages=messages,
                     model=model,
                     frequency_penalty=frequency_penalty,
@@ -136,6 +138,9 @@ def is_prompt_supported(
                     top_logprobs=top_logprobs,
                     top_p=top_p,
                 )
+                if file is not NotGiven:
+                    inputs["file"] = file
+                return client.is_prompt_supported(**inputs)
         return -1

     def truncate_messages(
@@ -163,27 +168,31 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         for client in self.__clients:
             if client.is_model_supported(model):
                 logger.debug(f"Using {client.__class__.__name__} for model {model}")
-                return client.chat_completion(
-                    messages,
-                    model,
-                    frequency_penalty,
-                    logit_bias,
-                    logprobs,
-                    max_tokens,
-                    n,
-                    presence_penalty,
-                    response_format,
-                    stop,
-                    temperature,
-                    tools,
-                    tool_choice,
-                    top_logprobs,
-                    top_p,
+                inputs = dict(
+                    messages=messages,
+                    model=model,
+                    frequency_penalty=frequency_penalty,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
+                    max_tokens=max_tokens,
+                    n=n,
+                    presence_penalty=presence_penalty,
+                    response_format=response_format,
+                    stop=stop,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    top_logprobs=top_logprobs,
+                    top_p=top_p,
                 )
+                if file is not NotGiven:
+                    inputs["file"] = file
+                return client.chat_completion(**inputs)
         client_names = [client.__class__.__name__ for client in self.__original_clients]
         raise ValueError(
             f"Model {model} is not supported by {client_names} clients. "

patchwork/common/client/llm/anthropic.py

Lines changed: 3 additions, 0 deletions

@@ -3,6 +3,7 @@
 import json
 import time
 from functools import cached_property, lru_cache
+from pathlib import Path

 from anthropic import Anthropic
 from anthropic.types import Message, MessageParam, TextBlockParam
@@ -268,6 +269,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         model_limit = self.__get_model_limit(model)
         input_kwargs = self.__adapt_chat_completion_request(
@@ -317,6 +319,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         input_kwargs = self.__adapt_chat_completion_request(
             messages=messages,
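The Anthropic hunks only widen the two signatures; no file handling is added in the lines shown. The widening still matters because the aggregator forwards arguments by keyword, so a client that does not accept `file` would fail the moment a caller passes one. An illustrative sketch of that failure mode (names are made up, not from this repository):

from __future__ import annotations

from pathlib import Path


def chat_completion_old(*, messages, model):
    """Pre-change signature: no `file` keyword."""
    return {"messages": messages, "model": model}


def chat_completion_new(*, messages, model, file: Path | None = None):
    """Post-change signature: accepts `file` even if the backend ignores it."""
    return {"messages": messages, "model": model}


kwargs = {"messages": [], "model": "claude-3-5-sonnet-latest", "file": Path("notes.txt")}
chat_completion_new(**kwargs)  # accepted (and ignored here)
try:
    chat_completion_old(**kwargs)
except TypeError as err:
    print(err)  # unexpected keyword argument 'file'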

patchwork/common/client/llm/google.py

Lines changed: 83 additions, 41 deletions

@@ -1,17 +1,20 @@
 from __future__ import annotations

-import functools
 import time
-
-from google import generativeai
-from google.generativeai.types.content_types import (
-    add_object_type,
-    convert_to_nullable,
-    strip_titles,
-    unpack_defs,
+from functools import lru_cache
+from pathlib import Path
+
+import magic
+from google import genai
+from google.genai import types
+from google.genai.types import (
+    CountTokensConfig,
+    File,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    Model,
+    Part,
 )
-from google.generativeai.types.generation_types import GenerateContentResponse
-from google.generativeai.types.model_types import Model
 from openai.types import CompletionUsage
 from openai.types.chat import (
     ChatCompletionMessage,
@@ -27,16 +30,12 @@
 from pydantic_ai.settings import ModelSettings
 from pydantic_ai.usage import Usage
 from typing_extensions import Any, AsyncIterator, Dict, Iterable, List, Optional, Union
+from pydantic import BaseModel

 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 from patchwork.common.client.llm.utils import json_schema_to_model


-@functools.lru_cache
-def _cached_list_model_from_google() -> list[Model]:
-    return list(generativeai.list_models())
-
-
 class GoogleLlmClient(LlmClient):
     __SAFETY_SETTINGS = [
         dict(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
@@ -48,7 +47,11 @@ class GoogleLlmClient(LlmClient):

     def __init__(self, api_key: str):
         self.__api_key = api_key
-        generativeai.configure(api_key=api_key)
+        self.client = genai.Client(api_key=api_key)
+
+    @lru_cache(maxsize=1)
+    def __get_models_info(self) -> list[Model]:
+        return list(self.client.models.list())

     def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
         if model_settings is None:
@@ -86,17 +89,47 @@ def system(self) -> str:
         return "google-gla"

     def __get_model_limits(self, model: str) -> int:
-        for model_info in _cached_list_model_from_google():
-            if model_info.name == f"{self.__MODEL_PREFIX}{model}":
+        for model_info in self.__get_models_info():
+            if model_info.name == f"{self.__MODEL_PREFIX}{model}" and model_info.input_token_limit is not None:
                 return model_info.input_token_limit
         return 1_000_000

+    @lru_cache
     def get_models(self) -> set[str]:
-        return {model.name.removeprefix(self.__MODEL_PREFIX) for model in _cached_list_model_from_google()}
+        return {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}

     def is_model_supported(self, model: str) -> bool:
         return model in self.get_models()

+    def __upload(self, file: Path | NotGiven) -> Part | File | None:
+        if file is NotGiven:
+            return None
+
+        file_bytes = file.read_bytes()
+
+        try:
+            mime_type = magic.Magic(mime=True, uncompress=True).from_buffer(file_bytes)
+            return types.Part.from_bytes(data=file_bytes, mime_type=mime_type)
+        except Exception as e:
+            pass
+
+        cleaned_name = "".join([char for char in file.name.lower() if char.isalnum()])
+        try:
+            file_ref = self.client.files.get(name=cleaned_name)
+            if file_ref.error is None:
+                return file_ref
+        except Exception as e:
+            pass
+
+        try:
+            file_ref = self.client.files.upload(file=file, config=dict(name=cleaned_name))
+            if file_ref.error is None:
+                return file_ref
+        except Exception as e:
+            pass
+
+        return None
+
     def is_prompt_supported(
         self,
         messages: Iterable[ChatCompletionMessageParam],
@@ -114,11 +147,23 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
-        system, chat = self.__openai_messages_to_google_messages(messages)
-        gen_model = generativeai.GenerativeModel(model_name=model, system_instruction=system)
+        system, contents = self.__openai_messages_to_google_messages(messages)
+
+        file_ref = self.__upload(file)
+        if file_ref is not None:
+            contents.append(file_ref)
+
         try:
-            token_count = gen_model.count_tokens(chat).total_tokens
+            token_response = self.client.models.count_tokens(
+                model=model,
+                contents=contents,
+                config=CountTokensConfig(
+                    system_instruction=system,
+                ),
+            )
+            token_count = token_response.total_tokens
         except Exception as e:
             return -1
         model_limit = self.__get_model_limits(model)
@@ -162,6 +207,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         generation_dict = dict(
             stop_sequences=[stop] if isinstance(stop, str) else stop,
@@ -181,20 +227,25 @@ def chat_completion(
         )

         system_content, contents = self.__openai_messages_to_google_messages(messages)
+        file_ref = self.__upload(file)
+        if file_ref is not None:
+            contents.append(file_ref)

-        model_client = generativeai.GenerativeModel(
-            model_name=model,
-            safety_settings=self.__SAFETY_SETTINGS,
-            generation_config=NOT_GIVEN.remove_not_given(generation_dict),
-            system_instruction=system_content,
+        response = self.client.models.generate_content(
+            model=model,
+            contents=contents,
+            config=GenerateContentConfig(
+                system_instruction=system_content,
+                safety_settings=self.__SAFETY_SETTINGS,
+                **NotGiven.remove_not_given(generation_dict),
+            ),
         )
-        response = model_client.generate_content(contents=contents)
         return self.__google_response_to_openai_response(response, model)

     @staticmethod
     def __google_response_to_openai_response(google_response: GenerateContentResponse, model: str) -> ChatCompletion:
         choices = []
-        for candidate in google_response.candidates:
+        for index, candidate in enumerate(google_response.candidates):
             # note that instead of system, from openai, its model, from google.
             parts = [part.text or part.inline_data for part in candidate.content.parts]

@@ -207,7 +258,7 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons

             choice = Choice(
                 finish_reason=finish_reason_map.get(candidate.finish_reason, "stop"),
-                index=candidate.index,
+                index=index,
                 message=ChatCompletionMessage(
                     content="\n".join(parts),
                     role="assistant",
@@ -231,18 +282,9 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
         )

     @staticmethod
-    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> dict[str, Any] | None:
+    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> Type[BaseModel] | None:
         if json_schema is None:
             return None

         model = json_schema_to_model(json_schema)
-        parameters = model.model_json_schema()
-        defs = parameters.pop("$defs", {})
-
-        for name, value in defs.items():
-            unpack_defs(value, defs)
-        unpack_defs(parameters, defs)
-        convert_to_nullable(parameters)
-        add_object_type(parameters)
-        strip_titles(parameters)
-        return parameters
+        return model
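The new `__upload` helper tries three strategies in order: embed the file bytes inline as a `Part` with a MIME type sniffed by python-magic, reuse a previously uploaded file stored under a sanitised name, and finally upload the file through the Files API; any failure falls through to the next option. Below is a standalone sketch of the same fallback pattern against the `google-genai` client, mirroring the calls used in this diff; the API key, model name, and file path are placeholders:

from pathlib import Path

import magic
from google import genai
from google.genai import types

client = genai.Client(api_key="...")  # placeholder key
path = Path("logs/app.log")           # placeholder file


def upload(file: Path) -> types.Part | types.File | None:
    data = file.read_bytes()
    # 1) Prefer sending the bytes inline with a sniffed MIME type.
    try:
        mime = magic.Magic(mime=True, uncompress=True).from_buffer(data)
        return types.Part.from_bytes(data=data, mime_type=mime)
    except Exception:
        pass
    # 2) Reuse an earlier upload stored under a sanitised name.
    name = "".join(c for c in file.name.lower() if c.isalnum())
    try:
        ref = client.files.get(name=name)
        if ref.error is None:
            return ref
    except Exception:
        pass
    # 3) Upload the file via the Files API as a last resort.
    try:
        ref = client.files.upload(file=file, config=dict(name=name))
        if ref.error is None:
            return ref
    except Exception:
        pass
    return None


part = upload(path)
contents = ["Summarise the attached log"]
if part is not None:
    contents.append(part)
response = client.models.generate_content(model="gemini-2.0-flash", contents=contents)
print(response.text)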

patchwork/common/client/llm/openai_.py

Lines changed: 3 additions, 0 deletions

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import functools
+from pathlib import Path
 from functools import cached_property

 import tiktoken
@@ -127,6 +128,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         # might not implement model endpoint
         if self.__is_not_openai_url():
@@ -170,6 +172,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         input_kwargs = dict(
             messages=messages,

patchwork/common/client/llm/protocol.py

Lines changed: 3 additions, 1 deletion

@@ -1,7 +1,7 @@
 from __future__ import annotations

 from abc import abstractmethod
-from typing import Any, Dict, List
+from pathlib import Path

 from openai.types.chat import (
     ChatCompletion,
@@ -58,6 +58,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         ...

@@ -144,5 +145,6 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         ...
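With the protocol widened, both `is_prompt_supported` and `chat_completion` expose the same optional `file` keyword on every implementation and on the aggregator. A minimal end-to-end sketch using the Google client from this commit; the API key, model name, and file path are placeholders, and the `> 0` check assumes the clients' convention of returning -1 when a prompt cannot be handled (as seen in aio.py and google.py):

from pathlib import Path

from patchwork.common.client.llm.google import GoogleLlmClient

client = GoogleLlmClient(api_key="...")  # placeholder key
log_path = Path("logs/app.log")          # placeholder file

messages = [{"role": "user", "content": "Summarise the attached log file."}]
model = "gemini-2.0-flash"  # placeholder model name

# The clients return -1 when the prompt (plus file) cannot be handled.
if client.is_prompt_supported(messages=messages, model=model, file=log_path) > 0:
    response = client.chat_completion(messages=messages, model=model, file=log_path)
    print(response.choices[0].message.content)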
