
Commit cbb418b

mistralai[patch]: ruff fixes and rules (#31918)
* bump ruff deps
* add more thorough ruff rules
* fix said rules
1 parent ae210c1 commit cbb418b

10 files changed (+216 -145 lines)

libs/partners/mistralai/langchain_mistralai/chat_models.py

Lines changed: 81 additions & 68 deletions
Large diffs are not rendered by default.

libs/partners/mistralai/langchain_mistralai/embeddings.py

Lines changed: 23 additions & 19 deletions
@@ -17,20 +17,20 @@
     model_validator,
 )
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
-from tokenizers import Tokenizer  # type: ignore
+from tokenizers import Tokenizer  # type: ignore[import]
 from typing_extensions import Self
 
 logger = logging.getLogger(__name__)
 
 MAX_TOKENS = 16_000
 """A batching parameter for the Mistral API. This is NOT the maximum number of tokens
-accepted by the embedding model for each document/chunk, but rather the maximum number 
+accepted by the embedding model for each document/chunk, but rather the maximum number
 of tokens that can be sent in a single request to the Mistral API (across multiple
 documents/chunks)"""
 
 
 class DummyTokenizer:
-    """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)"""
+    """Dummy tokenizer for when tokenizer cannot be accessed (e.g., via Huggingface)."""
 
     @staticmethod
     def encode_batch(texts: list[str]) -> list[list[str]]:
@@ -126,9 +126,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     # The type for client and async_client is ignored because the type is not
     # an Optional after the model is initialized and the model_validator
     # is run.
-    client: httpx.Client = Field(default=None)  # type: ignore # : :meta private:
+    client: httpx.Client = Field(default=None)  # type: ignore[assignment] # :meta private:
 
-    async_client: httpx.AsyncClient = Field(  # type: ignore # : meta private:
+    async_client: httpx.AsyncClient = Field(  # type: ignore[assignment] # :meta private:
         default=None
     )
     mistral_api_key: SecretStr = Field(
@@ -153,7 +153,6 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
     @model_validator(mode="after")
     def validate_environment(self) -> Self:
         """Validate configuration."""
-
         api_key_str = self.mistral_api_key.get_secret_value()
         # todo: handle retries
         if not self.client:
@@ -187,14 +186,14 @@ def validate_environment(self) -> Self:
                 "Could not download mistral tokenizer from Huggingface for "
                 "calculating batch sizes. Set a Huggingface token via the "
                 "HF_TOKEN environment variable to download the real tokenizer. "
-                "Falling back to a dummy tokenizer that uses `len()`."
+                "Falling back to a dummy tokenizer that uses `len()`.",
+                stacklevel=2,
             )
             self.tokenizer = DummyTokenizer()
         return self
 
     def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
-        """Split a list of texts into batches of less than 16k tokens for Mistral
-        API."""
+        """Split list of texts into batches of less than 16k tokens for Mistral API."""
         batch: list[str] = []
         batch_tokens = 0
 
@@ -224,6 +223,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
 
         Returns:
             List of embeddings, one for each text.
+
         """
         try:
             batch_responses = []
@@ -238,16 +238,17 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
             def _embed_batch(batch: list[str]) -> Response:
                 response = self.client.post(
                     url="/embeddings",
-                    json=dict(
-                        model=self.model,
-                        input=batch,
-                    ),
+                    json={
+                        "model": self.model,
+                        "input": batch,
+                    },
                 )
                 response.raise_for_status()
                 return response
 
-            for batch in self._get_batches(texts):
-                batch_responses.append(_embed_batch(batch))
+            batch_responses = [
+                _embed_batch(batch) for batch in self._get_batches(texts)
+            ]
             return [
                 list(map(float, embedding_obj["embedding"]))
                 for response in batch_responses
@@ -265,16 +266,17 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
 
         Returns:
             List of embeddings, one for each text.
+
         """
         try:
             batch_responses = await asyncio.gather(
                 *[
                     self.async_client.post(
                         url="/embeddings",
-                        json=dict(
-                            model=self.model,
-                            input=batch,
-                        ),
+                        json={
+                            "model": self.model,
+                            "input": batch,
+                        },
                     )
                     for batch in self._get_batches(texts)
                 ]
@@ -296,6 +298,7 @@ def embed_query(self, text: str) -> list[float]:
 
         Returns:
             Embedding for the text.
+
         """
         return self.embed_documents([text])[0]
 
@@ -307,5 +310,6 @@ async def aembed_query(self, text: str) -> list[float]:
 
         Returns:
             Embedding for the text.
+
         """
         return (await self.aembed_documents([text]))[0]
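The `stacklevel=2` argument added to `warnings.warn` above satisfies ruff's B028 (no-explicit-stacklevel) from flake8-bugbear: it attributes the warning to the code that triggered the fallback rather than to the `warn()` call itself. A minimal sketch of the effect, with illustrative function names that are not from this commit:

import warnings

def load_tokenizer() -> None:
    # stacklevel=1 (the default) would report the warning at this line;
    # stacklevel=2 reports it at load_tokenizer()'s caller instead, which
    # is the more useful location for a library warning.
    warnings.warn("falling back to a dummy tokenizer", stacklevel=2)

def user_code() -> None:
    load_tokenizer()  # the warning is attributed to this call site

user_code()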

libs/partners/mistralai/pyproject.toml

Lines changed: 56 additions & 2 deletions
@@ -48,8 +48,62 @@ disallow_untyped_defs = "True"
 target-version = "py39"
 
 [tool.ruff.lint]
-select = ["E", "F", "I", "T201", "UP", "S"]
-ignore = [ "UP007", ]
+select = [
+    "A",      # flake8-builtins
+    "B",      # flake8-bugbear
+    "ASYNC",  # flake8-async
+    "C4",     # flake8-comprehensions
+    "COM",    # flake8-commas
+    "D",      # pydocstyle
+    "DOC",    # pydoclint
+    "E",      # pycodestyle error
+    "EM",     # flake8-errmsg
+    "F",      # pyflakes
+    "FA",     # flake8-future-annotations
+    "FBT",    # flake8-boolean-trap
+    "FLY",    # flake8-flynt
+    "I",      # isort
+    "ICN",    # flake8-import-conventions
+    "INT",    # flake8-gettext
+    "ISC",    # flake8-implicit-str-concat
+    "PGH",    # pygrep-hooks
+    "PIE",    # flake8-pie
+    "PERF",   # Perflint
+    "PYI",    # flake8-pyi
+    "Q",      # flake8-quotes
+    "RET",    # flake8-return
+    "RSE",    # flake8-raise
+    "RUF",    # ruff
+    "S",      # flake8-bandit
+    "SLF",    # flake8-self
+    "SLOT",   # flake8-slots
+    "SIM",    # flake8-simplify
+    "T10",    # flake8-debugger
+    "T20",    # flake8-print
+    "TID",    # flake8-tidy-imports
+    "UP",     # pyupgrade
+    "W",      # pycodestyle warning
+    "YTT",    # flake8-2020
+]
+ignore = [
+    "D100",     # pydocstyle: Missing docstring in public module
+    "D101",     # pydocstyle: Missing docstring in public class
+    "D102",     # pydocstyle: Missing docstring in public method
+    "D103",     # pydocstyle: Missing docstring in public function
+    "D104",     # pydocstyle: Missing docstring in public package
+    "D105",     # pydocstyle: Missing docstring in magic method
+    "D107",     # pydocstyle: Missing docstring in __init__
+    "D203",     # Messes with the formatter
+    "D407",     # pydocstyle: Missing dashed underline after section
+    "COM812",   # Messes with the formatter
+    "ISC001",   # Messes with the formatter
+    "PERF203",  # Rarely useful
+    "S112",     # Rarely useful
+    "RUF012",   # Doesn't play well with Pydantic
+    "SLF001",   # Private member access
+    "UP007",    # pyupgrade: non-pep604-annotation-union
+    "UP045",    # pyupgrade: non-pep604-annotation-optional
+]
 
 [tool.coverage.run]
 omit = ["tests/*"]
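Most of the churn in the remaining files follows mechanically from these new rule families. A hedged before/after sketch of two of them, EM (flake8-errmsg) and C4 (flake8-comprehensions); the snippet is illustrative and not taken from the diff:

def check_status(status: int) -> dict:
    if status != 200:
        # EM101: assign the message to a variable instead of passing a string
        # literal straight to the exception, so tracebacks stay readable.
        msg = "unexpected response status"
        raise ValueError(msg)
    # C408: write a dict literal rather than a dict(...) call.
    return {"model": "mistral-embed", "input": ["hello"]}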

libs/partners/mistralai/tests/integration_tests/test_chat_models.py

Lines changed: 12 additions & 11 deletions
@@ -1,5 +1,7 @@
 """Test ChatMistral chat model."""
 
+from __future__ import annotations
+
 import json
 import logging
 import time
@@ -43,11 +45,12 @@ async def test_astream() -> None:
         if token.response_metadata:
             chunks_with_response_metadata += 1
     if chunks_with_token_counts != 1 or chunks_with_response_metadata != 1:
-        raise AssertionError(
+        msg = (
             "Expected exactly one chunk with token counts or response_metadata. "
             "AIMessageChunk aggregation adds / appends counts and metadata. Check that "
             "this is behaving properly."
         )
+        raise AssertionError(msg)
     assert isinstance(full, AIMessageChunk)
     assert full.usage_metadata is not None
     assert full.usage_metadata["input_tokens"] > 0
@@ -61,7 +64,7 @@
 
 
 async def test_abatch() -> None:
-    """Test streaming tokens from ChatMistralAI"""
+    """Test streaming tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
     result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -70,7 +73,7 @@ async def test_abatch() -> None:
 
 
 async def test_abatch_tags() -> None:
-    """Test batch tokens from ChatMistralAI"""
+    """Test batch tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
     result = await llm.abatch(
@@ -81,7 +84,7 @@ async def test_abatch_tags() -> None:
 
 
 def test_batch() -> None:
-    """Test batch tokens from ChatMistralAI"""
+    """Test batch tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
     result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
@@ -90,7 +93,7 @@
 
 
 async def test_ainvoke() -> None:
-    """Test invoke tokens from ChatMistralAI"""
+    """Test invoke tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
     result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
@@ -99,10 +102,10 @@ async def test_ainvoke() -> None:
 
 
 def test_invoke() -> None:
-    """Test invoke tokens from ChatMistralAI"""
+    """Test invoke tokens from ChatMistralAI."""
     llm = ChatMistralAI()
 
-    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    result = llm.invoke("I'm Pickle Rick", config={"tags": ["foo"]})
     assert isinstance(result.content, str)
 
 
@@ -178,13 +181,11 @@ class Person(BaseModel):
 
     structured_llm = llm.with_structured_output(Person)
     strm = structured_llm.stream("Erick, 27 years old")
-    chunk_num = 0
-    for chunk in strm:
+    for chunk_num, chunk in enumerate(strm):
         assert chunk_num == 0, "should only have one chunk with model"
         assert isinstance(chunk, Person)
         assert chunk.name == "Erick"
         assert chunk.age == 27
-        chunk_num += 1
 
 
 class Book(BaseModel):
@@ -201,7 +202,7 @@ def _check_parsed_result(result: Any, schema: Any) -> None:
     if schema == Book:
         assert isinstance(result, Book)
     else:
-        assert all(key in ["name", "authors"] for key in result.keys())
+        assert all(key in ["name", "authors"] for key in result)
 
 
 @pytest.mark.parametrize("schema", [Book, BookDict, Book.model_json_schema()])
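The counter rewrite and the dropped `.keys()` above correspond to ruff's SIM113 (enumerate-for-loop) and SIM118 (in-dict-keys) from flake8-simplify. A standalone sketch of both idioms, using illustrative data:

book = {"name": "The Hobbit", "authors": ["J. R. R. Tolkien"]}

# SIM118: membership tests work on the mapping directly; .keys() is redundant.
assert all(key in ["name", "authors"] for key in book)

# SIM113: enumerate() replaces a manually incremented counter variable.
for index, key in enumerate(book):
    assert index < len(book)
    assert key in book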

libs/partners/mistralai/tests/integration_tests/test_compile.py

Lines changed: 0 additions & 1 deletion
@@ -4,4 +4,3 @@
 @pytest.mark.compile
 def test_placeholder() -> None:
     """Used for compiling integration tests without running any real tests."""
-    pass
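Dropping the trailing `pass` follows ruff's PIE790 (unnecessary-placeholder): a docstring already gives the function a body, so a sketch of the minimal valid form is simply:

def test_placeholder() -> None:
    """A docstring is a sufficient function body; no pass statement is needed."""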

libs/partners/mistralai/tests/integration_tests/test_embeddings.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""Test MistralAI Embedding"""
+"""Test MistralAI Embedding."""
 
 from langchain_mistralai import MistralAIEmbeddings
 
libs/partners/mistralai/tests/integration_tests/test_standard.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""
 
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.integration_tests import (  # type: ignore[import-not-found]
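The trailing periods added to one-line docstrings here and in the files above satisfy pydocstyle's first-line rules (D400/D415), now enabled through the D family. A one-line illustrative sketch:

def test_example() -> None:
    """Docstring first lines end with a period under D400."""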

libs/partners/mistralai/tests/unit_tests/test_chat_models.py

Lines changed: 18 additions & 18 deletions
@@ -84,23 +84,23 @@ def test_mistralai_initialization_baseurl_env(env_var_name: str) -> None:
     [
         (
             SystemMessage(content="Hello"),
-            dict(role="system", content="Hello"),
+            {"role": "system", "content": "Hello"},
         ),
         (
             HumanMessage(content="Hello"),
-            dict(role="user", content="Hello"),
+            {"role": "user", "content": "Hello"},
         ),
         (
             AIMessage(content="Hello"),
-            dict(role="assistant", content="Hello"),
+            {"role": "assistant", "content": "Hello"},
         ),
         (
             AIMessage(content="{", additional_kwargs={"prefix": True}),
-            dict(role="assistant", content="{", prefix=True),
+            {"role": "assistant", "content": "{", "prefix": True},
         ),
         (
             ChatMessage(role="assistant", content="Hello"),
-            dict(role="assistant", content="Hello"),
+            {"role": "assistant", "content": "Hello"},
         ),
     ],
 )
@@ -112,17 +112,17 @@ def test_convert_message_to_mistral_chat_message(
 
 
 def _make_completion_response_from_token(token: str) -> dict:
-    return dict(
-        id="abc123",
-        model="fake_model",
-        choices=[
-            dict(
-                index=0,
-                delta=dict(content=token),
-                finish_reason=None,
-            )
+    return {
+        "id": "abc123",
+        "model": "fake_model",
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": token},
+                "finish_reason": None,
+            }
         ],
-    )
+    }
 
 
 def mock_chat_stream(*args: Any, **kwargs: Any) -> Generator:
@@ -275,8 +275,7 @@ def test_extra_kwargs() -> None:
 
 
 def test_retry_with_failure_then_success() -> None:
-    """Test that retry mechanism works correctly when
-    first request fails and second succeeds."""
+    """Test retry mechanism works correctly when first request fails, second succeeds."""
     # Create a real ChatMistralAI instance
     chat = ChatMistralAI(max_retries=3)
 
@@ -289,7 +288,8 @@ def mock_post(*args: Any, **kwargs: Any) -> MagicMock:
         call_count += 1
 
         if call_count == 1:
-            raise httpx.RequestError("Connection error", request=MagicMock())
+            msg = "Connection error"
+            raise httpx.RequestError(msg, request=MagicMock())
 
         mock_response = MagicMock()
         mock_response.status_code = 200

libs/partners/mistralai/tests/unit_tests/test_standard.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-"""Standard LangChain interface tests"""
+"""Standard LangChain interface tests."""
 
 from langchain_core.language_models import BaseChatModel
 from langchain_tests.unit_tests import (  # type: ignore[import-not-found]
