Commit 4cc8d3e

ReadEmail Step (#1299)
* bump to pre release version
* remove attachment content from outputs
* Add message-id to email steps
* make most email properties optional
* fix gemini issue
* fix llm client by force
* undo scm url requirement
* bump
* Update send email
* fix test
* better way to pass files
* finalise
1 parent 29771da commit 4cc8d3e

22 files changed (+599, -520 lines)


patchwork/common/client/llm/aio.py

Lines changed: 26 additions & 17 deletions
```diff
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+from pathlib import Path
 
 from openai.types.chat import (
     ChatCompletion,
@@ -54,10 +55,11 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         for client in self.__clients:
             if client.is_model_supported(model):
-                return client.is_prompt_supported(
+                inputs = dict(
                     messages=messages,
                     model=model,
                     frequency_penalty=frequency_penalty,
@@ -74,6 +76,9 @@ def is_prompt_supported(
                     top_logprobs=top_logprobs,
                     top_p=top_p,
                 )
+                if file is not NotGiven:
+                    inputs["file"] = file
+                return client.is_prompt_supported(**inputs)
         return -1
 
     def truncate_messages(
@@ -101,27 +106,31 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         for client in self.__clients:
             if client.is_model_supported(model):
                 logger.debug(f"Using {client.__class__.__name__} for model {model}")
-                return client.chat_completion(
-                    messages,
-                    model,
-                    frequency_penalty,
-                    logit_bias,
-                    logprobs,
-                    max_tokens,
-                    n,
-                    presence_penalty,
-                    response_format,
-                    stop,
-                    temperature,
-                    tools,
-                    tool_choice,
-                    top_logprobs,
-                    top_p,
+                inputs = dict(
+                    messages=messages,
+                    model=model,
+                    frequency_penalty=frequency_penalty,
+                    logit_bias=logit_bias,
+                    logprobs=logprobs,
+                    max_tokens=max_tokens,
+                    n=n,
+                    presence_penalty=presence_penalty,
+                    response_format=response_format,
+                    stop=stop,
+                    temperature=temperature,
+                    tools=tools,
+                    tool_choice=tool_choice,
+                    top_logprobs=top_logprobs,
+                    top_p=top_p,
                 )
+                if file is not NotGiven:
+                    inputs["file"] = file
+                return client.chat_completion(**inputs)
         client_names = [client.__class__.__name__ for client in self.__original_clients]
         raise ValueError(
             f"Model {model} is not supported by {client_names} clients. "
```

patchwork/common/client/llm/anthropic.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -3,6 +3,7 @@
 import json
 import time
 from functools import lru_cache
+from pathlib import Path
 
 from anthropic import Anthropic
 from anthropic.types import Message, MessageParam, TextBlockParam
@@ -224,6 +225,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         model_limit = self.__get_model_limit(model)
         input_kwargs = self.__adapt_chat_completion_request(
@@ -273,6 +275,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         input_kwargs = self.__adapt_chat_completion_request(
             messages=messages,
```
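In the Anthropic client the commit only adds the `file` keyword to the two signatures (all three added lines are shown above), keeping the keyword surface uniform across clients without acting on the attachment yet. A small sketch of why a defaulted trailing parameter is backward compatible; the function below is hypothetical, not the repository's code.

```python
from pathlib import Path

_NOT_GIVEN = object()  # stand-in sentinel, mirroring NOT_GIVEN above


def chat_completion(messages, model, temperature=1.0, file=_NOT_GIVEN):
    # Existing callers that never mention `file` are unaffected by the new
    # parameter; newer callers may pass a Path that the backend can use
    # (or, as in the Anthropic client for now, simply accept and ignore).
    attachment = None if file is _NOT_GIVEN else Path(file)
    return {"model": model, "messages": messages, "temperature": temperature, "attachment": attachment}
```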

patchwork/common/client/llm/google.py

Lines changed: 84 additions & 42 deletions
```diff
@@ -1,17 +1,20 @@
 from __future__ import annotations
 
-import functools
 import time
-
-from google import generativeai
-from google.generativeai.types.content_types import (
-    add_object_type,
-    convert_to_nullable,
-    strip_titles,
-    unpack_defs,
+from functools import lru_cache
+from pathlib import Path
+
+import magic
+from google import genai
+from google.genai import types
+from google.genai.types import (
+    CountTokensConfig,
+    File,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    Model,
+    Part,
 )
-from google.generativeai.types.generation_types import GenerateContentResponse
-from google.generativeai.types.model_types import Model
 from openai.types import CompletionUsage
 from openai.types.chat import (
     ChatCompletionMessage,
@@ -21,17 +24,13 @@
     completion_create_params,
 )
 from openai.types.chat.chat_completion import ChatCompletion, Choice
-from typing_extensions import Any, Dict, Iterable, List, Optional, Union
+from pydantic import BaseModel
+from typing_extensions import Any, Dict, Iterable, List, Optional, Type, Union
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 from patchwork.common.client.llm.utils import json_schema_to_model
 
 
-@functools.lru_cache
-def _cached_list_model_from_google() -> list[Model]:
-    return list(generativeai.list_models())
-
-
 class GoogleLlmClient(LlmClient):
     __SAFETY_SETTINGS = [
         dict(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
@@ -43,20 +42,54 @@ class GoogleLlmClient(LlmClient):
 
     def __init__(self, api_key: str):
         self.__api_key = api_key
-        generativeai.configure(api_key=api_key)
+        self.client = genai.Client(api_key=api_key)
+
+    @lru_cache(maxsize=1)
+    def __get_models_info(self) -> list[Model]:
+        return list(self.client.models.list())
 
     def __get_model_limits(self, model: str) -> int:
-        for model_info in _cached_list_model_from_google():
-            if model_info.name == f"{self.__MODEL_PREFIX}{model}":
+        for model_info in self.__get_models_info():
+            if model_info.name == f"{self.__MODEL_PREFIX}{model}" and model_info.input_token_limit is not None:
                 return model_info.input_token_limit
         return 1_000_000
 
+    @lru_cache
     def get_models(self) -> set[str]:
-        return {model.name.removeprefix(self.__MODEL_PREFIX) for model in _cached_list_model_from_google()}
+        return {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
 
     def is_model_supported(self, model: str) -> bool:
         return model in self.get_models()
 
+    def __upload(self, file: Path | NotGiven) -> Part | File | None:
+        if file is NotGiven:
+            return None
+
+        file_bytes = file.read_bytes()
+
+        try:
+            mime_type = magic.Magic(mime=True, uncompress=True).from_buffer(file_bytes)
+            return types.Part.from_bytes(data=file_bytes, mime_type=mime_type)
+        except Exception as e:
+            pass
+
+        cleaned_name = "".join([char for char in file.name.lower() if char.isalnum()])
+        try:
+            file_ref = self.client.files.get(name=cleaned_name)
+            if file_ref.error is None:
+                return file_ref
+        except Exception as e:
+            pass
+
+        try:
+            file_ref = self.client.files.upload(file=file, config=dict(name=cleaned_name))
+            if file_ref.error is None:
+                return file_ref
+        except Exception as e:
+            pass
+
+        return None
+
     def is_prompt_supported(
         self,
         messages: Iterable[ChatCompletionMessageParam],
@@ -74,11 +107,23 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
-        system, chat = self.__openai_messages_to_google_messages(messages)
-        gen_model = generativeai.GenerativeModel(model_name=model, system_instruction=system)
+        system, contents = self.__openai_messages_to_google_messages(messages)
+
+        file_ref = self.__upload(file)
+        if file_ref is not None:
+            contents.append(file_ref)
+
         try:
-            token_count = gen_model.count_tokens(chat).total_tokens
+            token_response = self.client.models.count_tokens(
+                model=model,
+                contents=contents,
+                config=CountTokensConfig(
+                    system_instruction=system,
+                ),
+            )
+            token_count = token_response.total_tokens
         except Exception as e:
             return -1
         model_limit = self.__get_model_limits(model)
@@ -122,6 +167,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         generation_dict = dict(
             stop_sequences=[stop] if isinstance(stop, str) else stop,
@@ -141,20 +187,25 @@ def chat_completion(
         )
 
         system_content, contents = self.__openai_messages_to_google_messages(messages)
+        file_ref = self.__upload(file)
+        if file_ref is not None:
+            contents.append(file_ref)
 
-        model_client = generativeai.GenerativeModel(
-            model_name=model,
-            safety_settings=self.__SAFETY_SETTINGS,
-            generation_config=NOT_GIVEN.remove_not_given(generation_dict),
-            system_instruction=system_content,
+        response = self.client.models.generate_content(
+            model=model,
+            contents=contents,
+            config=GenerateContentConfig(
+                system_instruction=system_content,
+                safety_settings=self.__SAFETY_SETTINGS,
+                **NotGiven.remove_not_given(generation_dict),
+            ),
         )
-        response = model_client.generate_content(contents=contents)
         return self.__google_response_to_openai_response(response, model)
 
     @staticmethod
     def __google_response_to_openai_response(google_response: GenerateContentResponse, model: str) -> ChatCompletion:
         choices = []
-        for candidate in google_response.candidates:
+        for index, candidate in enumerate(google_response.candidates):
             # note that instead of system, from openai, its model, from google.
             parts = [part.text or part.inline_data for part in candidate.content.parts]
 
@@ -167,7 +218,7 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
 
             choice = Choice(
                 finish_reason=finish_reason_map.get(candidate.finish_reason, "stop"),
-                index=candidate.index,
+                index=index,
                 message=ChatCompletionMessage(
                     content="\n".join(parts),
                     role="assistant",
@@ -191,18 +242,9 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
             )
 
     @staticmethod
-    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> dict[str, Any] | None:
+    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> Type[BaseModel] | None:
         if json_schema is None:
             return None
 
         model = json_schema_to_model(json_schema)
-        parameters = model.model_json_schema()
-        defs = parameters.pop("$defs", {})
-
-        for name, value in defs.items():
-            unpack_defs(value, defs)
-        unpack_defs(parameters, defs)
-        convert_to_nullable(parameters)
-        add_object_type(parameters)
-        strip_titles(parameters)
-        return parameters
+        return model
```
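The new `__upload` helper first tries to attach the file inline as a `Part` with a MIME type sniffed by python-magic, and only falls back to the Gemini Files API (get, then upload) when that fails. Below is a self-contained sketch of that fallback chain against the google-genai SDK; the API key, model name, and file path are placeholders, and error handling is reduced to the essentials.

```python
from pathlib import Path

import magic  # python-magic, the same MIME sniffer used above
from google import genai
from google.genai import types

client = genai.Client(api_key="YOUR_API_KEY")  # placeholder key
path = Path("report.pdf")  # placeholder file

data = path.read_bytes()
try:
    # Preferred path: ship the bytes inline with a sniffed MIME type.
    mime_type = magic.Magic(mime=True, uncompress=True).from_buffer(data)
    attachment = types.Part.from_bytes(data=data, mime_type=mime_type)
except Exception:
    # Fallback: upload the file and reference it by handle instead.
    attachment = client.files.upload(file=path)

response = client.models.generate_content(
    model="gemini-2.0-flash",  # placeholder model name
    contents=["Summarise the attached document.", attachment],
)
print(response.text)
```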

patchwork/common/client/llm/openai_.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import functools
+from pathlib import Path
 
 import tiktoken
 from openai import OpenAI
@@ -82,6 +83,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         # might not implement model endpoint
         if self.__is_not_openai_url():
@@ -125,6 +127,7 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         input_kwargs = dict(
             messages=messages,
```

patchwork/common/client/llm/protocol.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from pathlib import Path
+
 from openai.types.chat import (
     ChatCompletion,
     ChatCompletionMessageParam,
@@ -51,6 +53,7 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         ...
 
@@ -135,5 +138,6 @@ def chat_completion(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> ChatCompletion:
         ...
```
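With `file` added to the `LlmClient` protocol, callers can treat attachments uniformly across backends. Below is a trimmed-down sketch of how a caller might combine the two protocol methods: the `SupportsFileChat` protocol keeps only the parameters relevant to this commit, and reading a negative `is_prompt_supported` result as a rejection is an assumption based on the -1 returns visible in the diffs above.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any, Iterable, Protocol


class SupportsFileChat(Protocol):
    # Reduced view of the real LlmClient protocol: only messages, model and
    # the new optional file parameter are kept for this sketch.
    def is_prompt_supported(self, messages: Iterable[dict], model: str, file: Path | None = None) -> int: ...

    def chat_completion(self, messages: Iterable[dict], model: str, file: Path | None = None) -> Any: ...


def summarise_attachment(client: SupportsFileChat, model: str, attachment: Path) -> Any:
    messages = [{"role": "user", "content": "Summarise the attached file."}]
    # A negative value is treated as "prompt not supported" (see the -1 returns above).
    if client.is_prompt_supported(messages=messages, model=model, file=attachment) < 0:
        raise ValueError(f"{model} cannot handle this prompt with {attachment.name} attached")
    return client.chat_completion(messages=messages, model=model, file=attachment)
```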

0 commit comments
