Commit 5fdc57c

feat: add warning log when targeting a model planned for deprecation (#111)
* add model deprecation warning
* use Optional
1 parent af48070 commit 5fdc57c
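
This change threads an optional check_model_deprecation_headers_callback argument through _request in both the synchronous and asynchronous clients. The callback is built per call by a factory on the base client: it closes over the resolved model name and, whenever a response carries the x-model-deprecation-timestamp header, logs a warning with the planned removal date. A minimal usage sketch of how the warning surfaces for SDK users (assumes a valid key in MISTRAL_API_KEY and a model the API has flagged for deprecation; the model name is illustrative only):

import os

from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# If the response carries the "x-model-deprecation-timestamp" header, the client now
# logs: "WARNING: The model <name> is deprecated and will be removed on <timestamp>. ..."
response = client.chat(
    model="open-mistral-7b",  # illustrative; any model the API marks as deprecated triggers the log
    messages=[ChatMessage(role="user", content="Hello!")],
)
print(response.choices[0].message.content)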

File tree: 7 files changed, +288 -55 lines changed


src/mistralai/async_client.py

Lines changed: 47 additions & 8 deletions
@@ -1,7 +1,7 @@
 import asyncio
 import posixpath
 from json import JSONDecodeError
-from typing import Any, AsyncGenerator, Dict, List, Optional, Union
+from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Union

 from httpx import (
     AsyncClient,
@@ -101,6 +101,7 @@ async def _request(
         stream: bool = False,
         attempt: int = 1,
         data: Optional[Dict[str, Any]] = None,
+        check_model_deprecation_headers_callback: Optional[Callable] = None,
         **kwargs: Any,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         accept_header = "text/event-stream" if stream else "application/json"
@@ -129,6 +130,8 @@ async def _request(
                     data=data,
                     **kwargs,
                 ) as response:
+                    if check_model_deprecation_headers_callback:
+                        check_model_deprecation_headers_callback(response.headers)
                     await self._check_streaming_response(response)

                     async for line in response.aiter_lines():
@@ -145,7 +148,8 @@ async def _request(
                     data=data,
                     **kwargs,
                 )
-
+                if check_model_deprecation_headers_callback:
+                    check_model_deprecation_headers_callback(response.headers)
                 yield await self._check_response(response)

         except ConnectError as e:
@@ -213,7 +217,12 @@ async def chat(
             response_format=response_format,
         )

-        single_response = self._request("post", request, "v1/chat/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         async for response in single_response:
             return ChatCompletionResponse(**response)
@@ -267,7 +276,13 @@ async def chat_stream(
             tool_choice=tool_choice,
             response_format=response_format,
         )
-        async_response = self._request("post", request, "v1/chat/completions", stream=True)
+        async_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         async for json_response in async_response:
             yield ChatCompletionStreamResponse(**json_response)
@@ -284,7 +299,12 @@ async def embeddings(self, model: str, input: Union[str, List[str]]) -> Embeddin
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        single_response = self._request("post", request, "v1/embeddings")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/embeddings",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         async for response in single_response:
             return EmbeddingResponse(**response)
@@ -341,7 +361,12 @@ async def completion(
         request = self._make_completion_request(
             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
         )
-        single_response = self._request("post", request, "v1/fim/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         async for response in single_response:
             return ChatCompletionResponse(**response)
@@ -376,9 +401,23 @@ async def completion_stream(
             Dict[str, Any]: a response object containing the generated text.
         """
         request = self._make_completion_request(
-            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+            prompt,
+            model,
+            suffix,
+            temperature,
+            max_tokens,
+            top_p,
+            random_seed,
+            stop,
+            stream=True,
+        )
+        async_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
         )
-        async_response = self._request("post", request, "v1/fim/completions", stream=True)

         async for json_response in async_response:
             yield ChatCompletionStreamResponse(**json_response)
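
The same callback is wired into every async endpoint wrapper above (chat, chat_stream, embeddings, completion, completion_stream); in the streaming paths it runs on the response headers before the first event is yielded. A hedged async sketch under the same assumptions as the example above, with an illustrative model name:

import asyncio
import os

from mistralai.async_client import MistralAsyncClient
from mistralai.models.chat_completion import ChatMessage


async def main() -> None:
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])
    # Any deprecation warning is logged as soon as the response headers arrive,
    # before the first streamed chunk is consumed.
    async for chunk in client.chat_stream(
        model="open-mistral-7b",  # illustrative model name
        messages=[ChatMessage(role="user", content="Hello!")],
    ):
        if chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")


asyncio.run(main())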

src/mistralai/client.py

Lines changed: 51 additions & 9 deletions
@@ -1,7 +1,7 @@
 import posixpath
 import time
 from json import JSONDecodeError
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Union

 from httpx import Client, ConnectError, HTTPTransport, RequestError, Response

@@ -40,7 +40,9 @@ def __init__(
         super().__init__(endpoint, api_key, max_retries, timeout)

         self._client = Client(
-            follow_redirects=True, timeout=self._timeout, transport=HTTPTransport(retries=self._max_retries)
+            follow_redirects=True,
+            timeout=self._timeout,
+            transport=HTTPTransport(retries=self._max_retries),
         )
         self.files = FilesClient(self)
         self.jobs = JobsClient(self)
@@ -94,6 +96,7 @@ def _request(
         stream: bool = False,
         attempt: int = 1,
         data: Optional[Dict[str, Any]] = None,
+        check_model_deprecation_headers_callback: Optional[Callable] = None,
         **kwargs: Any,
     ) -> Iterator[Dict[str, Any]]:
         accept_header = "text/event-stream" if stream else "application/json"
@@ -122,6 +125,8 @@ def _request(
                     data=data,
                     **kwargs,
                 ) as response:
+                    if check_model_deprecation_headers_callback:
+                        check_model_deprecation_headers_callback(response.headers)
                     self._check_streaming_response(response)

                     for line in response.iter_lines():
@@ -138,7 +143,8 @@ def _request(
                     data=data,
                     **kwargs,
                 )
-
+                if check_model_deprecation_headers_callback:
+                    check_model_deprecation_headers_callback(response.headers)
                 yield self._check_response(response)

         except ConnectError as e:
@@ -207,7 +213,12 @@ def chat(
             response_format=response_format,
         )

-        single_response = self._request("post", request, "v1/chat/completions")
+        single_response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         for response in single_response:
             return ChatCompletionResponse(**response)
@@ -261,7 +272,13 @@ def chat_stream(
             response_format=response_format,
         )

-        response = self._request("post", request, "v1/chat/completions", stream=True)
+        response = self._request(
+            "post",
+            request,
+            "v1/chat/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         for json_streamed_response in response:
             yield ChatCompletionStreamResponse(**json_streamed_response)
@@ -278,7 +295,12 @@ def embeddings(self, model: str, input: Union[str, List[str]]) -> EmbeddingRespo
             EmbeddingResponse: A response object containing the embeddings.
         """
         request = {"model": model, "input": input}
-        singleton_response = self._request("post", request, "v1/embeddings")
+        singleton_response = self._request(
+            "post",
+            request,
+            "v1/embeddings",
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         for response in singleton_response:
             return EmbeddingResponse(**response)
@@ -337,7 +359,13 @@ def completion(
             prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
         )

-        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+        single_response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=False,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         for response in single_response:
             return ChatCompletionResponse(**response)
@@ -372,10 +400,24 @@ def completion_stream(
             Iterable[Dict[str, Any]]: a generator that yields response objects containing the generated text.
         """
         request = self._make_completion_request(
-            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+            prompt,
+            model,
+            suffix,
+            temperature,
+            max_tokens,
+            top_p,
+            random_seed,
+            stop,
+            stream=True,
         )

-        response = self._request("post", request, "v1/fim/completions", stream=True)
+        response = self._request(
+            "post",
+            request,
+            "v1/fim/completions",
+            stream=True,
+            check_model_deprecation_headers_callback=self._check_model_deprecation_header_callback_factory(model),
+        )

         for json_streamed_response in response:
             yield ChatCompletionStreamResponse(**json_streamed_response)

src/mistralai/client_base.py

Lines changed: 44 additions & 19 deletions
@@ -1,14 +1,19 @@
 import logging
 import os
 from abc import ABC
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import orjson
-
-from mistralai.exceptions import (
-    MistralException,
+from httpx import Headers
+
+from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP
+from mistralai.exceptions import MistralException
+from mistralai.models.chat_completion import (
+    ChatMessage,
+    Function,
+    ResponseFormat,
+    ToolChoice,
 )
-from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice

 CLIENT_VERSION = "0.4.1"

@@ -38,6 +43,14 @@ def __init__(

         self._version = CLIENT_VERSION

+    def _get_model(self, model: Optional[str] = None) -> str:
+        if model is not None:
+            return model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            return self._default_model
+
     def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
         parsed_tools: List[Dict[str, Any]] = []
         for tool in tools:
@@ -73,6 +86,22 @@ def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:

         return parsed_messages

+    def _check_model_deprecation_header_callback_factory(self, model: Optional[str] = None) -> Callable:
+        model = self._get_model(model)
+
+        def _check_model_deprecation_header_callback(
+            headers: Headers,
+        ) -> None:
+            if HEADER_MODEL_DEPRECATION_TIMESTAMP in headers:
+                self._logger.warning(
+                    f"WARNING: The model {model} is deprecated "
+                    f"and will be removed on {headers[HEADER_MODEL_DEPRECATION_TIMESTAMP]}. "
+                    "Please refer to https://docs.mistral.ai/getting-started/models/#api-versioning "
+                    "for more information."
+                )
+
+        return _check_model_deprecation_header_callback
+
     def _make_completion_request(
         self,
         prompt: str,
@@ -95,16 +124,14 @@ def _make_completion_request(
         if stop is not None:
             request_data["stop"] = stop

-        if model is not None:
-            request_data["model"] = model
-        else:
-            if self._default_model is None:
-                raise MistralException(message="model must be provided")
-            request_data["model"] = self._default_model
+        request_data["model"] = self._get_model(model)

         request_data.update(
             self._build_sampling_params(
-                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                random_seed=random_seed,
             )
         )
@@ -148,16 +175,14 @@ def _make_chat_request(
             "messages": self._parse_messages(messages),
         }

-        if model is not None:
-            request_data["model"] = model
-        else:
-            if self._default_model is None:
-                raise MistralException(message="model must be provided")
-            request_data["model"] = self._default_model
+        request_data["model"] = self._get_model(model)

         request_data.update(
             self._build_sampling_params(
-                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+                temperature=temperature,
+                max_tokens=max_tokens,
+                top_p=top_p,
+                random_seed=random_seed,
             )
         )
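
The factory above is what both clients call before each request. To preview the log line without hitting the API, the returned callback can be exercised directly against fabricated headers; note this pokes at a private helper purely for illustration, and the timestamp value is made up:

from httpx import Headers

from mistralai.client import MistralClient
from mistralai.constants import HEADER_MODEL_DEPRECATION_TIMESTAMP

client = MistralClient(api_key="dummy-key")  # constructing the client makes no network call
callback = client._check_model_deprecation_header_callback_factory("open-mistral-7b")

# Simulate a response whose headers announce a deprecation date.
callback(Headers({HEADER_MODEL_DEPRECATION_TIMESTAMP: "2025-03-30"}))
# Logs: WARNING: The model open-mistral-7b is deprecated and will be removed on 2025-03-30. ...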

src/mistralai/constants.py

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 RETRY_STATUS_CODES = {429, 500, 502, 503, 504}

 ENDPOINT = "https://api.mistral.ai"
+
+HEADER_MODEL_DEPRECATION_TIMESTAMP = "x-model-deprecation-timestamp"
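
Because the warning is emitted through the client's internal logger at WARNING level (self._logger.warning in client_base.py) rather than the warnings module, it can be redirected or silenced with standard logging configuration. A sketch, with the caveat that the logger name is an internal detail not shown in this diff and is assumed here:

import logging

# Assumption: the base client's logger follows the getLogger(__name__) convention,
# i.e. "mistralai.client_base". Verify against the installed version before relying on it.
logging.getLogger("mistralai.client_base").setLevel(logging.ERROR)  # hide the deprecation warning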
