Skip to content

Commit 32ec8b6

Browse files
authored
Release 0.2.0 (#94)
1 parent 374fd01 commit 32ec8b6

19 files changed

+189
-201
lines changed
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
name: Lint / Test / Publish
2+
3+
on:
4+
push:
5+
branches: ["main"]
6+
7+
# We only deploy on tags and main branch
8+
tags:
9+
# Only run on tags that match the following regex
10+
# This will match tags like 1.0.0, 1.0.1, etc.
11+
- "[0-9]+.[0-9]+.[0-9]+"
12+
13+
# Lint and test on pull requests
14+
pull_request:
15+
16+
jobs:
17+
lint_and_test:
18+
runs-on: ubuntu-latest
19+
strategy:
20+
matrix:
21+
python-version: ["3.9", "3.10", "3.11", "3.12"]
22+
steps:
23+
# Checkout the repository
24+
- name: Checkout
25+
uses: actions/checkout@v4
26+
27+
# Set python version to 3.11
28+
- name: set python version
29+
uses: actions/setup-python@v4
30+
with:
31+
python-version: ${{ matrix.python-version }}
32+
33+
# Install Build stuff
34+
- name: Install Dependencies
35+
run: |
36+
pip install poetry \
37+
&& poetry config virtualenvs.create false \
38+
&& poetry install
39+
40+
# Ruff
41+
- name: Ruff check
42+
run: |
43+
poetry run ruff check .
44+
45+
- name: Ruff check
46+
run: |
47+
poetry run ruff format . --check
48+
49+
# Mypy
50+
- name: Mypy Check
51+
run: |
52+
poetry run mypy .
53+
54+
# Tests
55+
- name: Run Tests
56+
run: |
57+
poetry run pytest .
58+
59+
publish:
60+
if: startsWith(github.ref, 'refs/tags')
61+
runs-on: ubuntu-latest
62+
needs: lint_and_test
63+
steps:
64+
# Checkout the repository
65+
- name: Checkout
66+
uses: actions/checkout@v4
67+
68+
# Set python version to 3.11
69+
- name: set python version
70+
uses: actions/setup-python@v4
71+
with:
72+
python-version: 3.11
73+
74+
# Install Build stuff
75+
- name: Install Dependencies
76+
run: |
77+
pip install poetry \
78+
&& poetry config virtualenvs.create false \
79+
&& poetry install
80+
81+
# build package using poetry
82+
- name: Build Package
83+
run: |
84+
poetry build
85+
86+
# Publish to PyPi
87+
- name: Pypi publish
88+
run: |
89+
poetry config pypi-token.pypi ${{ secrets.PYPI_TOKEN }}
90+
poetry publish

examples/chatbot_with_streaming.py

Lines changed: 9 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def completer(text, state):
6363

6464

6565
class ChatBot:
66-
def __init__(
67-
self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE
68-
):
66+
def __init__(self, api_key, model, system_message=None, temperature=DEFAULT_TEMPERATURE):
6967
if not api_key:
7068
raise ValueError("An API key must be provided to use the Mistral API.")
7169
self.client = MistralClient(api_key=api_key)
@@ -89,15 +87,11 @@ def opening_instructions(self):
8987

9088
def new_chat(self):
9189
print("")
92-
print(
93-
f"Starting new chat with model: {self.model}, temperature: {self.temperature}"
94-
)
90+
print(f"Starting new chat with model: {self.model}, temperature: {self.temperature}")
9591
print("")
9692
self.messages = []
9793
if self.system_message:
98-
self.messages.append(
99-
ChatMessage(role="system", content=self.system_message)
100-
)
94+
self.messages.append(ChatMessage(role="system", content=self.system_message))
10195

10296
def switch_model(self, input):
10397
model = self.get_arguments(input)
@@ -146,13 +140,9 @@ def run_inference(self, content):
146140
self.messages.append(ChatMessage(role="user", content=content))
147141

148142
assistant_response = ""
149-
logger.debug(
150-
f"Running inference with model: {self.model}, temperature: {self.temperature}"
151-
)
143+
logger.debug(f"Running inference with model: {self.model}, temperature: {self.temperature}")
152144
logger.debug(f"Sending messages: {self.messages}")
153-
for chunk in self.client.chat_stream(
154-
model=self.model, temperature=self.temperature, messages=self.messages
155-
):
145+
for chunk in self.client.chat_stream(model=self.model, temperature=self.temperature, messages=self.messages):
156146
response = chunk.choices[0].delta.content
157147
if response is not None:
158148
print(response, end="", flush=True)
@@ -161,9 +151,7 @@ def run_inference(self, content):
161151
print("", flush=True)
162152

163153
if assistant_response:
164-
self.messages.append(
165-
ChatMessage(role="assistant", content=assistant_response)
166-
)
154+
self.messages.append(ChatMessage(role="assistant", content=assistant_response))
167155
logger.debug(f"Current messages: {self.messages}")
168156

169157
def get_command(self, input):
@@ -215,9 +203,7 @@ def exit(self):
215203

216204

217205
if __name__ == "__main__":
218-
parser = argparse.ArgumentParser(
219-
description="A simple chatbot using the Mistral API"
220-
)
206+
parser = argparse.ArgumentParser(description="A simple chatbot using the Mistral API")
221207
parser.add_argument(
222208
"--api-key",
223209
default=os.environ.get("MISTRAL_API_KEY"),
@@ -230,19 +216,15 @@ def exit(self):
230216
default=DEFAULT_MODEL,
231217
help="Model for chat inference. Choices are %(choices)s. Defaults to %(default)s",
232218
)
233-
parser.add_argument(
234-
"-s", "--system-message", help="Optional system message to prepend."
235-
)
219+
parser.add_argument("-s", "--system-message", help="Optional system message to prepend.")
236220
parser.add_argument(
237221
"-t",
238222
"--temperature",
239223
type=float,
240224
default=DEFAULT_TEMPERATURE,
241225
help="Optional temperature for chat inference. Defaults to %(default)s",
242226
)
243-
parser.add_argument(
244-
"-d", "--debug", action="store_true", help="Enable debug logging"
245-
)
227+
parser.add_argument("-d", "--debug", action="store_true", help="Enable debug logging")
246228

247229
args = parser.parse_args()
248230

examples/function_calling.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,26 @@
1515
"payment_status": ["Paid", "Unpaid", "Paid", "Paid", "Pending"],
1616
}
1717

18-
def retrieve_payment_status(data: Dict[str,List], transaction_id: str) -> str:
18+
19+
def retrieve_payment_status(data: Dict[str, List], transaction_id: str) -> str:
1920
for i, r in enumerate(data["transaction_id"]):
2021
if r == transaction_id:
2122
return json.dumps({"status": data["payment_status"][i]})
2223
else:
2324
return json.dumps({"status": "Error - transaction id not found"})
2425

26+
2527
def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
2628
for i, r in enumerate(data["transaction_id"]):
2729
if r == transaction_id:
2830
return json.dumps({"date": data["payment_date"][i]})
2931
else:
3032
return json.dumps({"status": "Error - transaction id not found"})
3133

34+
3235
names_to_functions = {
33-
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
34-
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data)
36+
"retrieve_payment_status": functools.partial(retrieve_payment_status, data=data),
37+
"retrieve_payment_date": functools.partial(retrieve_payment_date, data=data),
3538
}
3639

3740
tools = [
@@ -75,9 +78,7 @@ def retrieve_payment_date(data: Dict[str, List], transaction_id: str) -> str:
7578
messages.append(ChatMessage(role="assistant", content=response.choices[0].message.content))
7679
messages.append(ChatMessage(role="user", content="My transaction ID is T1001."))
7780

78-
response = client.chat(
79-
model=model, messages=messages, tools=tools
80-
)
81+
response = client.chat(model=model, messages=messages, tools=tools)
8182

8283
tool_call = response.choices[0].message.tool_calls[0]
8384
function_name = tool_call.function.name

examples/json_format.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def main():
1616
model=model,
1717
response_format={"type": "json_object"},
1818
messages=[ChatMessage(role="user", content="What is the best French cheese? Answer shortly in JSON.")],
19-
2019
)
2120
print(chat_response.choices[0].message.content)
2221

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "mistralai"
3-
version = "0.0.1"
3+
version = "0.2.0"
44
description = ""
55
authors = ["Bam4d <[email protected]>"]
66
readme = "README.md"

src/mistralai/async_client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import os
32
import posixpath
43
from json import JSONDecodeError
54
from typing import Any, AsyncGenerator, Dict, List, Optional, Union

src/mistralai/client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import posixpath
32
import time
43
from json import JSONDecodeError

src/mistralai/client_base.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
)
1111
from mistralai.models.chat_completion import ChatMessage, Function, ResponseFormat, ToolChoice
1212

13+
CLIENT_VERSION = "0.2.0"
14+
1315

1416
class ClientBase(ABC):
1517
def __init__(
@@ -25,9 +27,7 @@ def __init__(
2527
if api_key is None:
2628
api_key = os.environ.get("MISTRAL_API_KEY")
2729
if api_key is None:
28-
raise MistralException(
29-
message="API key not provided. Please set MISTRAL_API_KEY environment variable."
30-
)
30+
raise MistralException(message="API key not provided. Please set MISTRAL_API_KEY environment variable.")
3131
self._api_key = api_key
3232
self._endpoint = endpoint
3333
self._logger = logging.getLogger(__name__)
@@ -36,8 +36,7 @@ def __init__(
3636
if "inference.azure.com" in self._endpoint:
3737
self._default_model = "mistral"
3838

39-
# This should be automatically updated by the deploy script
40-
self._version = "0.0.1"
39+
self._version = CLIENT_VERSION
4140

4241
def _parse_tools(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
4342
parsed_tools: List[Dict[str, Any]] = []

src/mistralai/constants.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
2-
31
RETRY_STATUS_CODES = {429, 500, 502, 503, 504}
42

53
ENDPOINT = "https://api.mistral.ai"

src/mistralai/exceptions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ def __init__(
3535
self.headers = headers or {}
3636

3737
@classmethod
38-
def from_response(
39-
cls, response: Response, message: Optional[str] = None
40-
) -> MistralAPIException:
38+
def from_response(cls, response: Response, message: Optional[str] = None) -> MistralAPIException:
4139
return cls(
4240
message=message or response.text,
4341
http_status=response.status_code,
@@ -47,8 +45,10 @@ def from_response(
4745
def __repr__(self) -> str:
4846
return f"{self.__class__.__name__}(message={str(self)}, http_status={self.http_status})"
4947

48+
5049
class MistralAPIStatusException(MistralAPIException):
5150
"""Returned when we receive a non-200 response from the API that we should retry"""
5251

52+
5353
class MistralConnectionException(MistralException):
5454
"""Returned when the SDK can not reach the API server for any reason"""

0 commit comments

Comments
 (0)