
Commit 0279af6

langchain-mistralai[patch]: Add ruff bandit rules to linter, formatting (#31803)
- Add ruff bandit rules
- Address an S101 error
- Formatting
1 parent 425ee52 commit 0279af6

3 files changed: +31 -18 lines changed

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 4 additions & 2 deletions

@@ -139,7 +139,8 @@ def _convert_mistral_chat_message_to_message(
     _message: dict,
 ) -> BaseMessage:
     role = _message["role"]
-    assert role == "assistant", f"Expected role to be 'assistant', got {role}"
+    if role != "assistant":
+        raise ValueError(f"Expected role to be 'assistant', got {role}")
     content = cast(str, _message["content"])

     additional_kwargs: dict = {}
@@ -398,7 +399,8 @@ class ChatMistralAI(BaseChatModel):
     max_tokens: Optional[int] = None
     top_p: float = 1
     """Decode using nucleus sampling: consider the smallest set of tokens whose
-    probability sum is at least top_p. Must be in the closed interval [0.0, 1.0]."""
+    probability sum is at least ``top_p``. Must be in the closed interval
+    ``[0.0, 1.0]``."""
     random_seed: Optional[int] = None
     safe_mode: Optional[bool] = None
     streaming: bool = False
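
The S101 change above replaces an `assert` with an explicit `raise`. A minimal standalone sketch of why that matters (illustration only, not the library's code): assert statements are stripped when Python runs with the -O flag, so an assert-based check can silently disappear, while the explicit raise always executes.

# sketch.py -- illustration only, not part of this commit
def check_with_assert(role: str) -> None:
    # Stripped entirely under `python -O` (assertions disabled), so no check runs.
    assert role == "assistant", f"Expected role to be 'assistant', got {role}"


def check_with_raise(role: str) -> None:
    # Always executes, regardless of interpreter flags.
    if role != "assistant":
        raise ValueError(f"Expected role to be 'assistant', got {role}")


check_with_raise("assistant")  # passes silently
try:
    check_with_raise("user")
except ValueError as exc:
    print(exc)  # Expected role to be 'assistant', got user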

libs/partners/mistralai/langchain_mistralai/embeddings.py

Lines changed: 20 additions & 15 deletions

@@ -54,21 +54,23 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
             Name of MistralAI model to use.

     Key init args — client params:
-        api_key: Optional[SecretStr]
-            The API key for the MistralAI API. If not provided, it will be read from the
-            environment variable `MISTRAL_API_KEY`.
-        max_retries: int
-            The number of times to retry a request if it fails.
-        timeout: int
-            The number of seconds to wait for a response before timing out.
-        wait_time: int
-            The number of seconds to wait before retrying a request in case of 429 error.
-        max_concurrent_requests: int
-            The maximum number of concurrent requests to make to the Mistral API.
+        api_key: Optional[SecretStr]
+            The API key for the MistralAI API. If not provided, it will be read from the
+            environment variable ``MISTRAL_API_KEY``.
+        max_retries: int
+            The number of times to retry a request if it fails.
+        timeout: int
+            The number of seconds to wait for a response before timing out.
+        wait_time: int
+            The number of seconds to wait before retrying a request in case of 429
+            error.
+        max_concurrent_requests: int
+            The maximum number of concurrent requests to make to the Mistral API.

     See full list of supported init args and their descriptions in the params section.

     Instantiate:
+
         .. code-block:: python

             from __module_name__ import MistralAIEmbeddings
@@ -80,6 +82,7 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
             )

     Embed single text:
+
         .. code-block:: python

             input_text = "The meaning of life is 42"
@@ -91,9 +94,10 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
             [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

     Embed multiple text:
+
         .. code-block:: python

-            input_texts = ["Document 1...", "Document 2..."]
+            input_texts = ["Document 1...", "Document 2..."]
             vectors = embed.embed_documents(input_texts)
             print(len(vectors))
             # The first 3 coordinates for the first vector
@@ -105,10 +109,11 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
             [-0.024603435769677162, -0.007543657906353474, 0.0039630369283258915]

     Async:
+
         .. code-block:: python

             vector = await embed.aembed_query(input_text)
-            print(vector[:3])
+            print(vector[:3])

             # multiple:
             # await embed.aembed_documents(input_texts)
@@ -188,8 +193,8 @@ def validate_environment(self) -> Self:
         return self

     def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
-        """Split a list of texts into batches of less than 16k tokens
-        for Mistral API."""
+        """Split a list of texts into batches of less than 16k tokens for Mistral
+        API."""
         batch: list[str] = []
         batch_tokens = 0

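
The `_get_batches` docstring above describes splitting inputs into batches that stay under a 16k-token budget. A rough sketch of that general idea, not the library's implementation (the `count_tokens` helper and the exact budget are assumptions):

# batching_sketch.py -- illustration only, not part of this commit
from collections.abc import Iterable

MAX_BATCH_TOKENS = 16_000  # assumed budget, mirroring the "16k tokens" in the docstring


def count_tokens(text: str) -> int:
    # Hypothetical stand-in for a real tokenizer: roughly one token per four characters.
    return max(1, len(text) // 4)


def get_batches(texts: list[str]) -> Iterable[list[str]]:
    # Greedily pack texts until the next one would push the batch over the budget.
    batch: list[str] = []
    batch_tokens = 0
    for text in texts:
        tokens = count_tokens(text)
        if batch and batch_tokens + tokens > MAX_BATCH_TOKENS:
            yield batch
            batch, batch_tokens = [], 0
        batch.append(text)
        batch_tokens += tokens
    if batch:
        yield batch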

libs/partners/mistralai/pyproject.toml

Lines changed: 7 additions & 1 deletion

@@ -48,7 +48,7 @@ disallow_untyped_defs = "True"
 target-version = "py39"

 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP"]
+select = ["E", "F", "I", "T201", "UP", "S"]
 ignore = [ "UP007", ]

 [tool.coverage.run]
@@ -61,3 +61,9 @@ markers = [
     "compile: mark placeholder test used to compile integration tests without running them",
 ]
 asyncio_mode = "auto"
+
+[tool.ruff.lint.extend-per-file-ignores]
+"tests/**/*.py" = [
+    "S101", # Tests need assertions
+    "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
+]
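
For context on the two per-file ignores, a hedged illustration using a hypothetical test file (not part of this commit): inside tests/, asserts are the normal pytest idiom (S101) and the random module is fine for throwaway fixtures (S311), which is why both rules are silenced under tests/**/*.py while staying active for library code.

# tests/test_fixtures.py -- hypothetical example, not from this repository
import random


def make_fake_embedding(dim: int = 8) -> list[float]:
    # S311 would flag this use of `random` in library code; here it only
    # builds a throwaway test fixture, so the per-file ignore applies.
    return [random.uniform(-1.0, 1.0) for _ in range(dim)]


def test_fake_embedding_dimension() -> None:
    # S101 would flag this assert in library code; asserts are the normal
    # pytest idiom, hence the tests/**/*.py ignore.
    assert len(make_fake_embedding()) == 8

Running `ruff check .` from the package root then applies the S rules everywhere except the ignored test paths.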
