
Commit 6abb672

Commit message: update
1 parent: e25d8b9

6 files changed: +46 −37 lines

patchwork/common/client/llm/aio.py

Lines changed: 10 additions & 7 deletions
@@ -31,9 +31,10 @@ def __init__(self, *clients: LlmClient):
         self.__supported_models = set()
         for client in clients:
             try:
-                self.__supported_models.update(client.get_models())
+                client.test()
                 self.__clients.append(client)
-            except Exception:
+            except Exception as e:
+                logger.error(f"{client.__class__.__name__} Failed with exception: {e}")
                 pass

     def __get_model(self, model_settings: ModelSettings | None) -> Optional[str]:
@@ -45,6 +46,9 @@ def __get_model(self, model_settings: ModelSettings | None) -> Optional[str]:

         return model_name

+    def test(self) -> None:
+        pass
+
     async def request(
         self,
         messages: list[ModelMessage],
@@ -94,9 +98,6 @@ def model_name(self) -> str:
     def system(self) -> str:
         return next(iter(self.__clients)).system

-    def get_models(self) -> set[str]:
-        return self.__supported_models
-
     def is_model_supported(self, model: str) -> bool:
         return any(client.is_model_supported(model) for client in self.__clients)

@@ -204,6 +205,8 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
     clients = []

     client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
+    if os.environ.get("GOOGLE_GENAI_USE_VERTEXAI") == "true":
+        client_args["is_gcp"] = True

     patched_key = inputs.get("patched_api_key")
     if patched_key is not None:
@@ -216,8 +219,8 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
         clients.append(client)

     google_key = inputs.get("google_api_key")
-    if google_key is not None:
-        client = GoogleLlmClient(google_key, **client_args)
+    if google_key is not None or "is_gcp" in client_args.keys():
+        client = GoogleLlmClient(api_key=google_key, is_gcp=bool(client_args.get("is_gcp", False)))
         clients.append(client)

     anthropic_key = inputs.get("anthropic_api_key")
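Taken together, the aio.py changes swap eager model enumeration for a per-client health check: __init__ calls test() on each candidate, logs any failure, and keeps only the clients that pass, while create_aio_client now builds a GoogleLlmClient whenever GOOGLE_GENAI_USE_VERTEXAI=true, even without a google_api_key. A rough sketch of the resulting flow, assuming the functions behave as shown in this diff:

import os

# Vertex AI mode: no API key required, only the env-var toggle.
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "true"

aio = create_aio_client({})  # GoogleLlmClient(is_gcp=True) is still constructed

# Inside AioLlmClient.__init__, each client.test() acts as a probe:
# a client that raises is logged via logger.error and dropped, so the
# aggregate only ever dispatches request() to clients that passed.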

patchwork/common/client/llm/anthropic.py

Lines changed: 2 additions & 3 deletions
@@ -245,9 +245,8 @@ def __adapt_chat_completion_request(

         return NotGiven.remove_not_given(input_kwargs)

-    @lru_cache(maxsize=None)
-    def get_models(self) -> set[str]:
-        return self.__definitely_allowed_models.union(set(f"{self.__allowed_model_prefix}*"))
+    def test(self):
+        return

     def is_model_supported(self, model: str) -> bool:
         return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix)
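Nothing needs probing on the Anthropic side: model support is a static allow-list plus a name prefix, so test() can return immediately and never knocks this client out of the aggregate. A sketch of the check it preserves, with placeholder values since the real allow-list is defined elsewhere in the class:

# Placeholder values; the actual set and prefix live on AnthropicLlmClient.
definitely_allowed_models = {"claude-3-5-sonnet-latest"}
allowed_model_prefix = "claude-"

def is_model_supported(model: str) -> bool:
    # Purely local check: no network call, so test() has nothing to verify.
    return model in definitely_allowed_models or model.startswith(allowed_model_prefix)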

patchwork/common/client/llm/google_.py

Lines changed: 14 additions & 8 deletions
@@ -56,17 +56,20 @@ class GoogleLlmClient(LlmClient):
     ]
     __MODEL_PREFIX = "models/"

-    def __init__(self, api_key: str, is_gcp: bool = False):
+    def __init__(self, api_key: Optional[str] = None, is_gcp: bool = False):
         self.__api_key = api_key
         self.__is_gcp = is_gcp
-        if not is_gcp:
+        if not self.__is_gcp:
             self.client = genai.Client(api_key=api_key)
         else:
-            self.client = genai.Client(api_key=api_key, vertexai=True, credentials=Credentials())
+            self.client = genai.Client(api_key=api_key, vertexai=True)

     @lru_cache(maxsize=1)
     def __get_models_info(self) -> list[Model]:
-        return list(self.client.models.list())
+        if not self.__is_gcp:
+            return list(self.client.models.list())
+        else:
+            return list()

     def __get_pydantic_model(self, model_settings: ModelSettings | None) -> PydanticAiModel:
         if model_settings is None:
@@ -112,12 +115,15 @@ def __get_model_limits(self, model: str) -> int:
             return model_info.input_token_limit
         return 1_000_000

-    @lru_cache
-    def get_models(self) -> set[str]:
-        return {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
+    def test(self):
+        return

     def is_model_supported(self, model: str) -> bool:
-        return model in self.get_models()
+        if not self.__is_gcp:
+            model_names = {model_info.name.removeprefix(self.__MODEL_PREFIX) for model_info in self.__get_models_info()}
+            return model in model_names
+        else:
+            return True

     def __upload(self, file: Path | NotGiven) -> Part | File | None:
         if isinstance(file, NotGiven):
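Under Vertex AI there is no equivalent of the public models.list() catalogue here, so the client skips enumeration (the cached __get_models_info returns an empty list) and optimistically accepts any model name, deferring validation to the actual request. A condensed sketch of the two branches, assuming the google-genai client used in the diff:

# Sketch only: mirrors the diff's branching outside the class.
def is_model_supported(client, is_gcp: bool, model: str) -> bool:
    if is_gcp:
        # Vertex AI: accept the name; a wrong model surfaces as an
        # error on the first real request instead.
        return True
    # API-key mode: names come back as e.g. "models/gemini-1.5-pro",
    # so strip the "models/" prefix before comparing.
    names = {m.name.removeprefix("models/") for m in client.models.list()}
    return model in names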

patchwork/common/client/llm/openai_.py

Lines changed: 5 additions & 4 deletions
@@ -96,17 +96,18 @@ def __is_not_openai_url(self):
         # We mainly use this to skip using the model endpoints.
         return self.__base_url is not None and self.__base_url != "https://api.openai.com/v1"

-    def get_models(self) -> set[str]:
+    def test(self):
         if self.__is_not_openai_url():
-            return set()
+            return

-        return _cached_list_models_from_openai(self.__api_key)
+        _cached_list_models_from_openai(self.__api_key)
+        return

     def is_model_supported(self, model: str) -> bool:
         # might not implement model endpoint
         if self.__is_not_openai_url():
             return True
-        return model in self.get_models()
+        return model in _cached_list_models_from_openai(self.__api_key)

     def __get_model_limits(self, model: str) -> int:
         return self.__MODEL_LIMITS.get(model, 128_000)
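For OpenAI-compatible endpoints, test() doubles as a credential check: listing models with a bad key raises, which AioLlmClient catches and logs, while non-OpenAI base URLs are waved through because they may not implement the models endpoint at all. A self-contained sketch of that filtering, using a stand-in client rather than the real class:

class StubClient:
    # Stand-in for OpenAiLlmClient; only the probe behaviour matters here.
    def __init__(self, api_key: str):
        self.api_key = api_key

    def test(self) -> None:
        # Plays the role of _cached_list_models_from_openai raising on
        # an invalid key (e.g. a 401 from the models endpoint).
        if self.api_key != "good":
            raise RuntimeError("401: invalid api key")

kept = []
for candidate in (StubClient("good"), StubClient("bad")):
    try:
        candidate.test()
        kept.append(candidate)
    except Exception as e:
        print(f"{candidate.__class__.__name__} Failed with exception: {e}")

assert len(kept) == 1  # only the authenticated client survives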

patchwork/common/client/llm/protocol.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def remove_not_given(obj: Any) -> Union[None, dict[Any, Any], list[Any], Any]:

 class LlmClient(Model):
     @abstractmethod
-    def get_models(self) -> set[str]:
+    def test(self) -> None:
         ...

     @abstractmethod
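With the protocol narrowed from get_models() -> set[str] to test() -> None, an implementation only has to prove it is usable rather than enumerate a catalogue. A minimal conforming client, under the assumption that raising from test() is the failure signal (as the aio.py changes above treat it):

# Hypothetical minimal implementation; the real LlmClient base class
# declares further abstract methods not shown in this diff.
class EchoClient:
    def test(self) -> None:
        # Raise here if the backend is unreachable; returning None
        # means the client is healthy and will be kept.
        return

    def is_model_supported(self, model: str) -> bool:
        return True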

poetry.lock

Lines changed: 14 additions & 14 deletions
Some generated files are not rendered by default.
