Skip to content

Commit 72fa88d

Browse files
committed
update
1 parent 0384115 commit 72fa88d

File tree

6 files changed

+402
-136
lines changed

6 files changed

+402
-136
lines changed

patchwork/common/client/llm/aio.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ def __init__(self, *clients: LlmClient):
3535
self.__clients.append(client)
3636
except Exception as e:
3737
logger.error(f"{client.__class__.__name__} Failed with exception: {e}")
38-
pass
3938

4039
def __get_model(self, model_settings: ModelSettings | None) -> Optional[str]:
4140
if model_settings is None:
@@ -205,8 +204,6 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
205204
clients = []
206205

207206
client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
208-
if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") == "true":
209-
client_args["is_gcp"] = True
210207

211208
patched_key = inputs.get("patched_api_key")
212209
if patched_key is not None:
@@ -219,8 +216,9 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
219216
clients.append(client)
220217

221218
google_key = inputs.get("google_api_key")
222-
if google_key is not None or "is_gcp" in client_args.keys():
223-
client = GoogleLlmClient(api_key=google_key, is_gcp=bool(client_args.get("is_gcp", False)))
219+
is_gcp = bool(client_args.get("is_gcp") or os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") or False)
220+
if google_key is not None or is_gcp:
221+
client = GoogleLlmClient(api_key=google_key, is_gcp=is_gcp)
224222
clients.append(client)
225223

226224
anthropic_key = inputs.get("anthropic_api_key")

patchwork/common/client/llm/anthropic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import time
5-
from functools import cached_property, lru_cache
5+
from functools import cached_property
66
from pathlib import Path
77

88
from anthropic import Anthropic

patchwork/common/client/llm/google_.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from __future__ import annotations
22

3+
import os
34
import time
4-
from functools import lru_cache
5+
from functools import lru_cache, partial
56
from pathlib import Path
67

78
import magic
9+
import vertexai
810
from google import genai
9-
from google.auth.credentials import Credentials
11+
from google.auth.exceptions import GoogleAuthError
1012
from google.genai import types
13+
from google.genai.errors import APIError
1114
from google.genai.types import (
1215
CountTokensConfig,
1316
File,
@@ -42,9 +45,11 @@
4245
Type,
4346
Union,
4447
)
48+
from vertexai.generative_models import GenerativeModel, SafetySetting
4549

4650
from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
4751
from patchwork.common.client.llm.utils import json_schema_to_model
52+
from patchwork.logger import logger
4853

4954

5055
class GoogleLlmClient(LlmClient):
@@ -53,6 +58,28 @@ class GoogleLlmClient(LlmClient):
5358
dict(category="HARM_CATEGORY_SEXUALLY_EXPLICIT", threshold="BLOCK_NONE"),
5459
dict(category="HARM_CATEGORY_DANGEROUS_CONTENT", threshold="BLOCK_NONE"),
5560
dict(category="HARM_CATEGORY_HARASSMENT", threshold="BLOCK_NONE"),
61+
dict(category="HARM_CATEGORY_CIVIC_INTEGRITY", threshold="BLOCK_NONE"),
62+
]
63+
__VERTEX_SAFETY_SETTINGS = [
64+
SafetySetting(
65+
category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
66+
threshold=SafetySetting.HarmBlockThreshold.OFF,
67+
),
68+
SafetySetting(
69+
category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
70+
threshold=SafetySetting.HarmBlockThreshold.OFF,
71+
),
72+
SafetySetting(
73+
category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
74+
threshold=SafetySetting.HarmBlockThreshold.OFF,
75+
),
76+
SafetySetting(
77+
category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=SafetySetting.HarmBlockThreshold.OFF
78+
),
79+
SafetySetting(
80+
category=SafetySetting.HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
81+
threshold=SafetySetting.HarmBlockThreshold.OFF,
82+
),
5683
]
5784
__MODEL_PREFIX = "models/"
5885

@@ -63,6 +90,12 @@ def __init__(self, api_key: Optional[str] = None, is_gcp: bool = False):
6390
self.client = genai.Client(api_key=api_key)
6491
else:
6592
self.client = genai.Client(api_key=api_key, vertexai=True)
93+
location = os.environ.get("GOOGLE_CLOUD_LOCATION", "global")
94+
vertexai.init(
95+
project=os.environ.get("GOOGLE_CLOUD_PROJECT"),
96+
location=location,
97+
api_endpoint=f"{location}-aiplatform.googleapis.com",
98+
)
6699

67100
@lru_cache(maxsize=1)
68101
def __get_models_info(self) -> list[Model]:
@@ -173,6 +206,8 @@ def is_prompt_supported(
173206
top_p: Optional[float] | NotGiven = NOT_GIVEN,
174207
file: Path | NotGiven = NOT_GIVEN,
175208
) -> int:
209+
if self.__is_gcp:
210+
return 1
176211
system, contents = self.__openai_messages_to_google_messages(messages)
177212

178213
file_ref = self.__upload(file)
@@ -188,7 +223,12 @@ def is_prompt_supported(
188223
),
189224
)
190225
token_count = token_response.total_tokens
226+
except GoogleAuthError:
227+
raise
228+
except APIError:
229+
raise
191230
except Exception as e:
231+
logger.debug(f"Error during token count at GoogleLlmClient: {e}")
192232
return -1
193233
model_limit = self.__get_model_limits(model)
194234
return model_limit - token_count
@@ -255,15 +295,25 @@ def chat_completion(
255295
if file_ref is not None:
256296
contents.append(file_ref)
257297

258-
response = self.client.models.generate_content(
259-
model=model,
260-
contents=contents,
261-
config=GenerateContentConfig(
262-
system_instruction=system_content,
263-
safety_settings=self.__SAFETY_SETTINGS,
264-
**NotGiven.remove_not_given(generation_dict),
265-
),
266-
)
298+
if not self.__is_gcp:
299+
generate_content_func = partial(
300+
self.client.models.generate_content,
301+
model=model,
302+
config=GenerateContentConfig(
303+
system_instruction=system_content,
304+
safety_settings=self.__SAFETY_SETTINGS,
305+
**NotGiven.remove_not_given(generation_dict),
306+
),
307+
)
308+
else:
309+
vertexai_model = GenerativeModel(model, system_instruction=system_content)
310+
generate_content_func = partial(
311+
vertexai_model.generate_content,
312+
safety_settings=self.__VERTEX_SAFETY_SETTINGS,
313+
generation_config=NotGiven.remove_not_given(generation_dict),
314+
)
315+
316+
response = generate_content_func(contents=contents)
267317
return self.__google_response_to_openai_response(response, model)
268318

269319
@staticmethod

patchwork/common/utils/step_typing.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ def validate_step_type_config_with_inputs(
7676
or f"Missing required input data because {key_name} is set: {', '.join(missing_and_keys)}",
7777
)
7878

79-
or_keys = set(step_type_config.or_op)
80-
if len(or_keys) > 0:
81-
missing_or_keys = or_keys.difference(input_keys)
82-
if not is_key_set and len(missing_or_keys) == len(or_keys):
83-
return (
84-
False,
85-
step_type_config.msg
86-
or f"Missing required input: At least one of {', '.join(sorted([key_name, *or_keys]))} has to be set",
87-
)
79+
# or_keys = set(step_type_config.or_op)
80+
# if len(or_keys) > 0:
81+
# missing_or_keys = or_keys.difference(input_keys)
82+
# if not is_key_set and len(missing_or_keys) == len(or_keys):
83+
# return (
84+
# False,
85+
# step_type_config.msg
86+
# or f"Missing required input: At least one of {', '.join(sorted([key_name, *or_keys]))} has to be set",
87+
# )
8888

8989
xor_keys = set(step_type_config.xor_op)
9090
if len(xor_keys) > 0:

0 commit comments

Comments
 (0)