
Commit 3937479

bump to pre release version
1 parent 29771da commit 3937479

File tree: 13 files changed, +438 -439 lines

patchwork/common/client/llm/google.py

Lines changed: 65 additions & 39 deletions
@@ -1,17 +1,17 @@
 from __future__ import annotations
 
-import functools
 import time
-
-from google import generativeai
-from google.generativeai.types.content_types import (
-    add_object_type,
-    convert_to_nullable,
-    strip_titles,
-    unpack_defs,
+from functools import lru_cache
+from pathlib import Path
+
+from google import genai
+from google.genai.types import (
+    CountTokensConfig,
+    File,
+    GenerateContentConfig,
+    GenerateContentResponse,
+    Model,
 )
-from google.generativeai.types.generation_types import GenerateContentResponse
-from google.generativeai.types.model_types import Model
 from openai.types import CompletionUsage
 from openai.types.chat import (
     ChatCompletionMessage,
@@ -21,17 +21,13 @@
     completion_create_params,
 )
 from openai.types.chat.chat_completion import ChatCompletion, Choice
-from typing_extensions import Any, Dict, Iterable, List, Optional, Union
+from pydantic import BaseModel
+from typing_extensions import Any, Dict, Iterable, List, Optional, Type, Union
 
 from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
 from patchwork.common.client.llm.utils import json_schema_to_model
 
 
-@functools.lru_cache
-def _cached_list_model_from_google() -> list[Model]:
-    return list(generativeai.list_models())
-
-
 class GoogleLlmClient(LlmClient):
     __SAFETY_SETTINGS = [
         dict(category="HARM_CATEGORY_HATE_SPEECH", threshold="BLOCK_NONE"),
@@ -43,20 +39,45 @@ class GoogleLlmClient(LlmClient):
 
     def __init__(self, api_key: str):
         self.__api_key = api_key
-        generativeai.configure(api_key=api_key)
+        self.client = genai.Client(api_key=api_key)
+
+    @lru_cache(maxsize=1)
+    def __get_models_info(self) -> list[Model]:
+        return list(self.client.models.list())
 
     def __get_model_limits(self, model: str) -> int:
-        for model_info in _cached_list_model_from_google():
-            if model_info.name == f"{self.__MODEL_PREFIX}{model}":
+        for model_info in self.__get_models_info():
+            if model_info.name == f"{self.__MODEL_PREFIX}{model}" and model_info.input_token_limit is not None:
                 return model_info.input_token_limit
         return 1_000_000
 
+    @lru_cache
     def get_models(self) -> set[str]:
-        return {model.name.removeprefix(self.__MODEL_PREFIX) for model in _cached_list_model_from_google()}
+        return {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
 
     def is_model_supported(self, model: str) -> bool:
        return model in self.get_models()
 
+    def __upload(self, file: Path | NotGiven) -> File | None:
+        if file is NotGiven:
+            return None
+
+        try:
+            file_ref = self.client.files.get(file.name)
+            if file_ref.error is None:
+                return file_ref
+        except Exception as e:
+            pass
+
+        try:
+            file_ref = self.client.files.upload(file=file)
+            if file_ref.error is None:
+                return file_ref
+        except Exception as e:
+            pass
+
+        return None
+
     def is_prompt_supported(
         self,
         messages: Iterable[ChatCompletionMessageParam],
@@ -74,11 +95,23 @@ def is_prompt_supported(
         tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
         top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
         top_p: Optional[float] | NotGiven = NOT_GIVEN,
+        file: Path | NotGiven = NOT_GIVEN,
     ) -> int:
         system, chat = self.__openai_messages_to_google_messages(messages)
-        gen_model = generativeai.GenerativeModel(model_name=model, system_instruction=system)
+
+        file_ref = self.__upload(file)
+        if file_ref is not None:
+            chat.append(file_ref)
+
         try:
-            token_count = gen_model.count_tokens(chat).total_tokens
+            token_response = self.client.models.count_tokens(
+                model=model,
+                contents=chat,
+                config=CountTokensConfig(
+                    system_instructions=system,
+                ),
+            )
+            token_count = token_response.total_tokens
         except Exception as e:
             return -1
         model_limit = self.__get_model_limits(model)
@@ -142,13 +175,15 @@ def chat_completion(
 
         system_content, contents = self.__openai_messages_to_google_messages(messages)
 
-        model_client = generativeai.GenerativeModel(
-            model_name=model,
-            safety_settings=self.__SAFETY_SETTINGS,
-            generation_config=NOT_GIVEN.remove_not_given(generation_dict),
-            system_instruction=system_content,
+        response = self.client.models.generate_content(
+            model=model,
+            contents=contents,
+            config=GenerateContentConfig(
+                system_instruction=system_content,
+                safety_settings=self.__SAFETY_SETTINGS,
+                **generation_dict,
+            ),
         )
-        response = model_client.generate_content(contents=contents)
         return self.__google_response_to_openai_response(response, model)
 
     @staticmethod
@@ -191,18 +226,9 @@ def __google_response_to_openai_response(google_response: GenerateContentRespons
         )
 
     @staticmethod
-    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> dict[str, Any] | None:
+    def json_schema_to_google_schema(json_schema: dict[str, Any] | None) -> Type[BaseModel] | None:
        if json_schema is None:
            return None
 
        model = json_schema_to_model(json_schema)
-        parameters = model.model_json_schema()
-        defs = parameters.pop("$defs", {})
-
-        for name, value in defs.items():
-            unpack_defs(value, defs)
-        unpack_defs(parameters, defs)
-        convert_to_nullable(parameters)
-        add_object_type(parameters)
-        strip_titles(parameters)
-        return parameters
+        return model
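
Note: this file migrates from the legacy google.generativeai SDK to the newer google.genai client. For context, the calling pattern adopted above looks roughly like the sketch below; the model name and prompt are illustrative, not taken from this commit.

# Minimal sketch of the google-genai pattern adopted above; assumes the
# google-genai package is installed and the API key is valid.
from google import genai
from google.genai.types import GenerateContentConfig

client = genai.Client(api_key="YOUR_API_KEY")  # replaces generativeai.configure(...)

response = client.models.generate_content(
    model="gemini-1.5-flash",
    contents=["Say hello in one word."],
    config=GenerateContentConfig(
        system_instruction="Be terse.",  # was GenerativeModel(system_instruction=...)
    ),
)
print(response.text)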

patchwork/steps/CallLLM/CallLLM.py

Lines changed: 8 additions & 31 deletions
@@ -11,10 +11,7 @@
 from rich.markup import escape
 
 from patchwork.common.client.llm.aio import AioLlmClient
-from patchwork.common.client.llm.anthropic import AnthropicLlmClient
-from patchwork.common.client.llm.google import GoogleLlmClient
-from patchwork.common.client.llm.openai_ import OpenAiLlmClient
-from patchwork.common.constants import DEFAULT_PATCH_URL, TOKEN_URL
+from patchwork.common.constants import TOKEN_URL
 from patchwork.logger import logger
 from patchwork.step import Step, StepStatus
 from patchwork.steps.CallLLM.typed import CallLLMInputs, CallLLMOutputs
@@ -54,31 +51,9 @@ def __init__(self, inputs: dict):
         self.save_responses_to_file = inputs.get("save_responses_to_file", None)
         self.model = inputs.get("model", "gpt-4o-mini")
         self.allow_truncated = inputs.get("allow_truncated", False)
-
-        clients = []
-
-        patched_key = inputs.get("patched_api_key")
-        if patched_key is not None:
-            client = OpenAiLlmClient(patched_key, DEFAULT_PATCH_URL)
-            clients.append(client)
-
-        openai_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
-        if openai_key is not None:
-            client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
-            client = OpenAiLlmClient(openai_key, **client_args)
-            clients.append(client)
-
-        google_key = inputs.get("google_api_key")
-        if google_key is not None:
-            client = GoogleLlmClient(google_key)
-            clients.append(client)
-
-        anthropic_key = inputs.get("anthropic_api_key")
-        if anthropic_key is not None:
-            client = AnthropicLlmClient(anthropic_key)
-            clients.append(client)
-
-        if len(clients) == 0:
+        self.file = inputs.get("file", None)
+        self.client = AioLlmClient.create_aio_client(inputs)
+        if self.client is None:
             raise ValueError(
                 f"Model API key not found.\n"
                 f'Please login at: "{TOKEN_URL}",\n'
@@ -89,8 +64,6 @@ def __init__(self, inputs: dict):
             "If you are using an OpenAI API Key, please set `--openai_api_key=<token>`.\n"
         )
 
-        self.client = AioLlmClient(*clients)
-
     def __persist_to_file(self, contents):
         # Convert relative path to absolute path
         file_path = os.path.abspath(self.save_responses_to_file)
@@ -143,6 +116,10 @@ def __call(self, prompts: list[list[dict]]) -> list[_InnerCallLLMResponse]:
         # Parse model arguments
         parsed_model_args = self.__parse_model_args()
 
+        kwargs = dict(parsed_model_args)
+        if self.file is not None:
+            kwargs["file"] = Path(self.file)
+
         for prompt in prompts:
             is_input_accepted = self.client.is_prompt_supported(prompt, self.model) > 0
             if not is_input_accepted:
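
Note: the per-provider client wiring removed here collapses into a single factory call. `AioLlmClient.create_aio_client(inputs)` appears in this diff, but its body does not; the sketch below shows the shape of the new flow under that assumption, with placeholder key values.

# Sketch of the refactored flow; create_aio_client's behavior beyond
# accepting the inputs dict and returning None on a missing key is assumed.
from pathlib import Path

from patchwork.common.client.llm.aio import AioLlmClient

inputs = {"google_api_key": "YOUR_API_KEY", "model": "gemini-1.5-flash", "file": "report.pdf"}

client = AioLlmClient.create_aio_client(inputs)  # None when no API key matched
if client is None:
    raise ValueError("Model API key not found.")

kwargs = {}
if inputs.get("file") is not None:
    kwargs["file"] = Path(inputs["file"])  # forwarded alongside the parsed model args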

patchwork/steps/CallLLM/typed.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@ class CallLLMInputs(TypedDict, total=False):
     google_api_key: Annotated[
         str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
     ]
+    file: Annotated[str, StepTypeConfig(is_path=True)]
 
 
 class CallLLMOutputs(TypedDict):
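
The new `file` entry is declared with `StepTypeConfig(is_path=True)`, so callers supply a path string and `CallLLM` converts it to a `Path`. A hypothetical inputs dict (key names from this diff, values illustrative):

# Hypothetical inputs for a CallLLM run; "file" is the key added in this
# commit, the other keys follow the existing CallLLMInputs fields above.
inputs = {
    "model": "gemini-1.5-flash",
    "google_api_key": "YOUR_API_KEY",  # any one of the or_op keys suffices
    "file": "docs/report.pdf",         # is_path=True: should point to a real file
}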

patchwork/steps/LLM/LLM.py

Lines changed: 2 additions & 6 deletions
@@ -4,17 +4,13 @@
 from patchwork.steps.ExtractModelResponse.ExtractModelResponse import (
     ExtractModelResponse,
 )
-from patchwork.steps.LLM.typed import LLMInputs
+from patchwork.steps.LLM.typed import LLMInputs, LLMOutputs
 from patchwork.steps.PreparePrompt.PreparePrompt import PreparePrompt
 
 
-class LLM(Step):
+class LLM(Step, input_class=LLMInputs, output_class=LLMOutputs):
     def __init__(self, inputs):
         super().__init__(inputs)
-        missing_keys = LLMInputs.__required_keys__.difference(set(inputs.keys()))
-        if len(missing_keys) > 0:
-            raise ValueError(f'Missing required data: "{missing_keys}"')
-
         self.inputs = inputs
 
     def run(self) -> dict:
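
The hand-rolled required-key check is dropped in favor of `input_class`/`output_class` class keywords on `Step`. `Step`'s implementation is not part of this commit; one way such keywords can drive validation is via `__init_subclass__`, sketched below as an assumption rather than the project's actual code.

# Assumed sketch of how Step could consume the class keywords; not the
# actual patchwork implementation.
class Step:
    def __init_subclass__(cls, input_class=None, output_class=None, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._input_class = input_class
        cls._output_class = output_class

    def __init__(self, inputs: dict):
        input_class = getattr(type(self), "_input_class", None)
        if input_class is not None:
            missing_keys = input_class.__required_keys__ - set(inputs.keys())
            if missing_keys:
                raise ValueError(f'Missing required data: "{missing_keys}"')


from typing_extensions import TypedDict

class DemoInputs(TypedDict):
    prompt: str

class Demo(Step, input_class=DemoInputs, output_class=dict):
    pass

Demo({})  # raises ValueError (missing "prompt")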

patchwork/steps/LLM/typed.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ class LLMInputs(__LLMInputsRequired, total=False):
     google_api_key: Annotated[
         str, StepTypeConfig(is_config=True, or_op=["patched_api_key", "openai_api_key", "anthropic_api_key"])
     ]
+    file: Annotated[str, StepTypeConfig(is_path=True)]
     # ExtractModelResponseInputs
     response_partitions: Annotated[Dict[str, List[str]], StepTypeConfig(is_config=True)]
 
patchwork/steps/ReadEmail/ReadEmail.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import base64
+import os
+import quopri
+from datetime import datetime
+from pathlib import Path
+
+from eml_parser import EmlParser
+from pydantic import BaseModel, Field
+
+from patchwork.step import Step
+from patchwork.steps.ReadEmail.typed import ReadEmailInputs, ReadEmailOutputs
+
+
+class ParsedHeader(BaseModel):
+    subject: str
+    from_: str = Field(alias="from")
+    to: list[str]
+    date: datetime
+
+
+class ParsedBody(BaseModel):
+    content: str
+    content_type: str
+
+
+class AttachmentHeader(BaseModel):
+    content_disposition: list[str] = Field(alias="content-disposition")
+    content_transfer_encoding: list[str] = Field(alias="content-transfer-encoding")
+    content_type: list[str] = Field(alias="content-type")
+
+
+class ParsedAttachment(BaseModel):
+    filename: str
+    raw: bytes
+    content_header: AttachmentHeader
+
+
+class ParsedEmail(BaseModel):
+    header: ParsedHeader
+    body: list[ParsedBody]
+    attachment: list[ParsedAttachment]
+
+
+class ReadEmail(Step, input_class=ReadEmailInputs, output_class=ReadEmailOutputs):
+    def __init__(self, inputs: dict):
+        super().__init__(inputs)
+        self.file = inputs["eml_file_path"]
+        self.base_path = inputs.get("base_path", os.getcwd())
+
+    def __decode(self, content_transfer_encoding: str, content: bytes) -> bytes:
+        if content_transfer_encoding.lower() == "base64":
+            return base64.b64decode(content)
+        elif content_transfer_encoding.lower() == "quoted-printable":
+            return quopri.decodestring(content)
+
+        return content
+
+    def run(self) -> dict:
+        ep = EmlParser(
+            include_raw_body=True,
+            include_attachment_data=True,
+        )
+
+        email_data_dict = ep.decode_email(self.file)
+        email_data = ParsedEmail.model_validate(email_data_dict)
+
+        rv = {
+            "subject": email_data.header.subject,
+            "datetime": email_data.header.date,
+            "from": email_data.header.from_,
+            "attachments": [],
+            "body": "",
+        }
+
+        base_path = Path(self.base_path)
+        base_path.mkdir(parents=True, exist_ok=True)
+        for attachment in email_data.attachment:
+            file_path = base_path / attachment.filename
+            with file_path.open("wb") as f:
+                content = attachment.raw
+                for content_transfer_encoding in attachment.content_header.content_transfer_encoding:
+                    content = self.__decode(content_transfer_encoding, content)
+                f.write(content)
+            rv["attachments"].append(dict(path=file_path, content=content.decode()))
+
+        for body in email_data.body:
+            rv["body"] += body.content
+
+        return rv
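
A usage sketch for the new step; the .eml path and output directory are illustrative. Note that run() emits the key "from", while the TypedDict below declares it as from_.

from patchwork.steps.ReadEmail.ReadEmail import ReadEmail

step = ReadEmail({"eml_file_path": "invoice.eml", "base_path": "./attachments"})
outputs = step.run()

print(outputs["subject"], outputs["from"])  # "from", not "from_", at runtime
for attachment in outputs["attachments"]:
    print(attachment["path"], len(attachment["content"]))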

patchwork/steps/ReadEmail/__init__.py

Whitespace-only changes.

patchwork/steps/ReadEmail/typed.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from typing_extensions import Annotated, List, TypedDict
+
+from patchwork.common.utils.step_typing import StepTypeConfig
+
+
+class __ReadEmailRequiredInputs(TypedDict):
+    eml_file_path: Annotated[str, StepTypeConfig(is_path=True)]
+
+
+class ReadEmailInputs(__ReadEmailRequiredInputs, total=False):
+    base_path: Annotated[str, StepTypeConfig(is_path=True)]
+
+
+class Attachment(TypedDict):
+    path: str
+    content: str
+
+
+class ReadEmailOutputs(TypedDict):
+    subject: str
+    datetime: str
+    from_: str  # this is actually from instead of from_
+    body: str
+    attachments: List[Attachment]
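
On the `from_` comment above: `from` is a Python keyword, so it cannot appear as an attribute in the class-based TypedDict syntax. If exact parity with the runtime key is ever needed, the functional form can declare it:

from typing_extensions import TypedDict

# The functional TypedDict syntax permits keys that are Python keywords.
ReadEmailOutputsExact = TypedDict(
    "ReadEmailOutputsExact",
    {"subject": str, "datetime": str, "from": str, "body": str},
)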
