
Commit aa010d2

release 0.3.0: add support for completion
1 parent 32ec8b6 commit aa010d2

File tree: 12 files changed, +442 −27 lines


examples/async_completion.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
#!/usr/bin/env python

import asyncio
import os

from mistralai.async_client import MistralAsyncClient


async def main():
    api_key = os.environ["MISTRAL_API_KEY"]

    client = MistralAsyncClient(api_key=api_key)

    prompt = "def fibonacci(n: int):"
    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

    response = await client.completion(
        model="codestral-latest",
        prompt=prompt,
        suffix=suffix,
    )

    print(
        f"""
{prompt}
{response.choices[0].message.content}
{suffix}
"""
    )


if __name__ == "__main__":
    asyncio.run(main())

examples/chatbot_with_streaming.py

Lines changed: 5 additions & 4 deletions
@@ -12,11 +12,12 @@
 from mistralai.models.chat_completion import ChatMessage

 MODEL_LIST = [
-    "mistral-tiny",
-    "mistral-small",
-    "mistral-medium",
+    "mistral-small-latest",
+    "mistral-medium-latest",
+    "mistral-large-latest",
+    "codestral-latest",
 ]
-DEFAULT_MODEL = "mistral-small"
+DEFAULT_MODEL = "mistral-small-latest"
 DEFAULT_TEMPERATURE = 0.7
 LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
 # A dictionary of all commands and their arguments, used for tab completion.

examples/code_completion.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
#!/usr/bin/env python

import asyncio
import os

from mistralai.client import MistralClient


async def main():
    api_key = os.environ["MISTRAL_API_KEY"]

    client = MistralClient(api_key=api_key)

    prompt = "def fibonacci(n: int):"
    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

    response = client.completion(
        model="codestral-latest",
        prompt=prompt,
        suffix=suffix,
    )

    print(
        f"""
{prompt}
{response.choices[0].message.content}
{suffix}
"""
    )


if __name__ == "__main__":
    asyncio.run(main())
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
#!/usr/bin/env python

import asyncio
import os

from mistralai.client import MistralClient


async def main():
    api_key = os.environ["MISTRAL_API_KEY"]

    client = MistralClient(api_key=api_key)

    prompt = "def fibonacci(n: int):"
    suffix = "n = int(input('Enter a number: '))\nprint(fibonacci(n))"

    print(prompt)
    for chunk in client.completion_stream(
        model="codestral-latest",
        prompt=prompt,
        suffix=suffix,
    ):
        if chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="")
    print(suffix)


if __name__ == "__main__":
    asyncio.run(main())

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistralai"
-version = "0.2.0"
+version = "0.3.0"
 description = ""
 authors = ["Bam4d <[email protected]>"]
 readme = "README.md"

src/mistralai/async_client.py

Lines changed: 72 additions & 1 deletion
@@ -92,7 +92,7 @@ async def _check_response(self, response: Response) -> Dict[str, Any]:
     async def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -291,3 +291,74 @@ async def list_models(self) -> ModelList:
             return ModelList(**response)

         raise MistralException("No response received")
+
+    async def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """An asynchronous completion endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9. Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
+
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+        single_response = self._request("post", request, "v1/fim/completions")
+
+        async for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    async def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> AsyncGenerator[ChatCompletionStreamResponse, None]:
+        """An asynchronous completion endpoint that returns a streaming response.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9. Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
+
+        Returns:
+            AsyncGenerator[ChatCompletionStreamResponse, None]: a generator that yields response objects containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+        async_response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        async for json_response in async_response:
+            yield ChatCompletionStreamResponse(**json_response)
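Note: the commit's examples only cover the non-streaming async call. The following sketch (not part of the commit) shows how the new completion_stream async generator could be consumed, using only names that appear in this diff; the prompt and suffix values are illustrative.

#!/usr/bin/env python

import asyncio
import os

from mistralai.async_client import MistralAsyncClient


async def main():
    # Assumes MISTRAL_API_KEY is set, mirroring the bundled examples.
    client = MistralAsyncClient(api_key=os.environ["MISTRAL_API_KEY"])

    # completion_stream is an async generator, so it is consumed with `async for`.
    async for chunk in client.completion_stream(
        model="codestral-latest",
        prompt="def fibonacci(n: int):",
        suffix="print(fibonacci(10))",
    ):
        if chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="")


if __name__ == "__main__":
    asyncio.run(main())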

src/mistralai/client.py

Lines changed: 75 additions & 1 deletion
@@ -85,7 +85,7 @@ def _check_response(self, response: Response) -> Dict[str, Any]:
     def _request(
         self,
         method: str,
-        json: Dict[str, Any],
+        json: Optional[Dict[str, Any]],
         path: str,
         stream: bool = False,
         attempt: int = 1,
@@ -285,3 +285,77 @@ def list_models(self) -> ModelList:
             return ModelList(**response)

         raise MistralException("No response received")
+
+    def completion(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> ChatCompletionResponse:
+        """A completion endpoint that returns a single response.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9. Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
+
+        Returns:
+            ChatCompletionResponse: a response object containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop
+        )
+
+        single_response = self._request("post", request, "v1/fim/completions", stream=False)
+
+        for response in single_response:
+            return ChatCompletionResponse(**response)
+
+        raise MistralException("No response received")
+
+    def completion_stream(
+        self,
+        model: str,
+        prompt: str,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+    ) -> Iterable[ChatCompletionStreamResponse]:
+        """A completion endpoint that streams responses synchronously.
+
+        Args:
+            model (str): the name of the model to get completions with, e.g. codestral-latest
+            prompt (str): the prompt to complete
+            suffix (Optional[str]): the suffix to append to the prompt for fill-in-the-middle completion
+            temperature (Optional[float], optional): the temperature to use for sampling, e.g. 0.5. Defaults to None.
+            max_tokens (Optional[int], optional): the maximum number of tokens to generate, e.g. 100. Defaults to None.
+            top_p (Optional[float], optional): the cumulative probability of tokens to generate, e.g. 0.9. Defaults to None.
+            random_seed (Optional[int], optional): the random seed to use for sampling, e.g. 42. Defaults to None.
+            stop (Optional[List[str]], optional): a list of tokens to stop generation at, e.g. ['\n\n']. Defaults to None.
+
+        Returns:
+            Iterable[ChatCompletionStreamResponse]: a generator that yields response objects containing the generated text.
+        """
+        request = self._make_completion_request(
+            prompt, model, suffix, temperature, max_tokens, top_p, random_seed, stop, stream=True
+        )
+
+        response = self._request("post", request, "v1/fim/completions", stream=True)
+
+        for json_streamed_response in response:
+            yield ChatCompletionStreamResponse(**json_streamed_response)
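None of the bundled examples exercise the optional sampling parameters the new endpoints accept. A minimal sketch of passing them through the synchronous client; the parameter values are illustrative only and not taken from the commit.

import os

from mistralai.client import MistralClient

client = MistralClient(api_key=os.environ["MISTRAL_API_KEY"])

# Optional arguments are forwarded via _build_sampling_params and simply
# omitted from the request body when left as None.
response = client.completion(
    model="codestral-latest",
    prompt="def is_prime(n: int) -> bool:",
    temperature=0.2,
    max_tokens=64,
    random_seed=42,
    stop=["\n\n"],
)
print(response.choices[0].message.content)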

src/mistralai/client_base.py

Lines changed: 63 additions & 8 deletions
@@ -73,6 +73,63 @@ def _parse_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
 
         return parsed_messages
 
+    def _make_completion_request(
+        self,
+        prompt: str,
+        model: Optional[str] = None,
+        suffix: Optional[str] = None,
+        temperature: Optional[float] = None,
+        max_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        random_seed: Optional[int] = None,
+        stop: Optional[List[str]] = None,
+        stream: Optional[bool] = False,
+    ) -> Dict[str, Any]:
+        request_data: Dict[str, Any] = {
+            "prompt": prompt,
+            "suffix": suffix,
+            "model": model,
+            "stream": stream,
+        }
+
+        if stop is not None:
+            request_data["stop"] = stop
+
+        if model is not None:
+            request_data["model"] = model
+        else:
+            if self._default_model is None:
+                raise MistralException(message="model must be provided")
+            request_data["model"] = self._default_model
+
+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
+        self._logger.debug(f"Completion request: {request_data}")
+
+        return request_data
+
+    def _build_sampling_params(
+        self,
+        max_tokens: Optional[int],
+        random_seed: Optional[int],
+        temperature: Optional[float],
+        top_p: Optional[float],
+    ) -> Dict[str, Any]:
+        params = {}
+        if temperature is not None:
+            params["temperature"] = temperature
+        if max_tokens is not None:
+            params["max_tokens"] = max_tokens
+        if top_p is not None:
+            params["top_p"] = top_p
+        if random_seed is not None:
+            params["random_seed"] = random_seed
+        return params
+
     def _make_chat_request(
         self,
         messages: List[Any],
@@ -99,16 +156,14 @@ def _make_chat_request(
                 raise MistralException(message="model must be provided")
             request_data["model"] = self._default_model
 
+        request_data.update(
+            self._build_sampling_params(
+                temperature=temperature, max_tokens=max_tokens, top_p=top_p, random_seed=random_seed
+            )
+        )
+
         if tools is not None:
             request_data["tools"] = self._parse_tools(tools)
-        if temperature is not None:
-            request_data["temperature"] = temperature
-        if max_tokens is not None:
-            request_data["max_tokens"] = max_tokens
-        if top_p is not None:
-            request_data["top_p"] = top_p
-        if random_seed is not None:
-            request_data["random_seed"] = random_seed
         if stream is not None:
             request_data["stream"] = stream
