|
1 | 1 | """Code generated by Speakeasy (https://speakeasyapi.dev). DO NOT EDIT.""" |
2 | 2 |
|
3 | 3 | import json |
4 | | -from typing import Optional, Union |
| 4 | +from typing import Optional, Tuple, Union |
5 | 5 |
|
6 | 6 | import google.auth |
7 | 7 | import google.auth.credentials |
|
20 | 20 | from .utils.logger import Logger, NoOpLogger |
21 | 21 | from .utils.retries import RetryConfig |
22 | 22 |
|
# Models that must be addressed with the legacy "<name>@<version>"
# identifier format on the Vertex AI publisher-model endpoint.
LEGACY_MODEL_ID_FORMAT = {
    "codestral-2405": "codestral@2405",
    "mistral-large-2407": "mistral-large@2407",
    "mistral-nemo-2407": "mistral-nemo@2407",
}


def get_model_info(model: str) -> Tuple[str, str]:
    """Split a raw model string into ``(model_name, model_id)``.

    The model name is the raw string with its trailing ``-<version>``
    segment removed (used for the ``model`` field in the request body).
    The model id is the identifier used in the URL path: either the
    legacy ``@``-separated form, for models listed in
    ``LEGACY_MODEL_ID_FORMAT``, or the raw string unchanged.
    """
    # If the model requires the legacy format, use it; otherwise keep
    # the raw model string as the id.
    model_id = LEGACY_MODEL_ID_FORMAT.get(model, model)
    # Drop the trailing "-<version>" segment to obtain the bare name.
    model = "-".join(model.split("-")[:-1])
    return model, model_id
| 34 | + |
| 35 | + |
23 | 36 |
|
24 | 37 | class MistralGoogleCloud(BaseSDK): |
25 | 38 | r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it.""" |
@@ -140,28 +153,24 @@ def __init__(self, region: str, project_id: str): |
140 | 153 | def before_request( |
141 | 154 | self, hook_ctx, request: httpx.Request |
142 | 155 | ) -> Union[httpx.Request, Exception]: |
143 | | - # The goal of this function is to template in the region, project, model, and model_version into the URL path |
| 156 | +        # The goal of this function is to template the region, project, and model into the URL path |
144 | 157 | # We do this here so that the API remains more user-friendly |
145 | | - model = None |
146 | | - model_version = None |
| 158 | + model_id = None |
147 | 159 | new_content = None |
148 | 160 | if request.content: |
149 | 161 | parsed = json.loads(request.content.decode("utf-8")) |
150 | 162 | model_raw = parsed.get("model") |
151 | | - model = "-".join(model_raw.split("-")[:-1]) |
152 | | - model_version = model_raw.split("-")[-1] |
153 | | - parsed["model"] = model |
| 163 | + model_name, model_id = get_model_info(model_raw) |
| 164 | + parsed["model"] = model_name |
154 | 165 | new_content = json.dumps(parsed).encode("utf-8") |
155 | 166 |
|
156 | | - if model == "": |
| 167 | + if model_id == "": |
157 | 168 | raise models.SDKError("model must be provided") |
158 | 169 |
|
159 | | - if model_version is None: |
160 | | - raise models.SDKError("model_version must be provided") |
161 | 170 |
|
162 | 171 | stream = "streamRawPredict" in request.url.path |
163 | 172 | specifier = "streamRawPredict" if stream else "rawPredict" |
164 | | - url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model}@{model_version}:{specifier}" |
| 173 | + url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}" |
165 | 174 |
|
166 | 175 | headers = dict(request.headers) |
167 | 176 | # Delete content-length header as it will need to be recalculated |
|
0 commit comments