
Commit 1f829aa

ollama[patch]: ruff fixes and rules (#31924)
* bump ruff deps
* add more thorough ruff rules
* fix said rules
1 parent 4d9eefe commit 1f829aa

File tree

11 files changed: 179 additions & 108 deletions

libs/partners/ollama/Makefile

Lines changed: 3 additions & 2 deletions
@@ -41,13 +41,14 @@ lint_tests: PYTHON_FILES=tests
 lint_tests: MYPY_CACHE=.mypy_cache_test
 
 lint lint_diff lint_package lint_tests:
-	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check $(PYTHON_FILES)
 	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES) --diff
 	[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && uv run --all-groups mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
 
 format format_diff:
 	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff format $(PYTHON_FILES)
-	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff --fix $(PYTHON_FILES)
+	[ "$(PYTHON_FILES)" = "" ] || uv run --all-groups ruff check --fix $(PYTHON_FILES)
+
 spell_check:
 	uv run --all-groups codespell --toml pyproject.toml
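Note on the Makefile change: current ruff releases put linting behind an explicit subcommand, so the bare "ruff $(PYTHON_FILES)" and "ruff --fix $(PYTHON_FILES)" invocations become "ruff check $(PYTHON_FILES)" and "ruff check --fix $(PYTHON_FILES)". The "ruff format" target already used the subcommand form and is unchanged.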

libs/partners/ollama/langchain_ollama/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 __all__ = [
     "ChatOllama",
-    "OllamaLLM",
     "OllamaEmbeddings",
+    "OllamaLLM",
     "__version__",
 ]
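This reordering puts __all__ in case-sensitive ASCII order, the sort that ruff's RUF022 rule applies (the rule code is inferred from the commit's "more thorough ruff rules", not stated in it). An underscore sorts after uppercase ASCII letters, which is why "__version__" lands last; a quick check in Python:

    print(sorted(["ChatOllama", "OllamaLLM", "OllamaEmbeddings", "__version__"]))
    # -> ['ChatOllama', 'OllamaEmbeddings', 'OllamaLLM', '__version__']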

libs/partners/ollama/langchain_ollama/_utils.py

Lines changed: 8 additions & 5 deletions
@@ -20,18 +20,21 @@ def validate_model(client: Client, model_name: str) -> None:
         if not any(
             model_name == m or m.startswith(f"{model_name}:") for m in model_names
         ):
-            raise ValueError(
+            msg = (
                 f"Model `{model_name}` not found in Ollama. Please pull the "
                 f"model (using `ollama pull {model_name}`) or specify a valid "
                 f"model name. Available local models: {', '.join(model_names)}"
             )
+            raise ValueError(msg)
     except ConnectError as e:
-        raise ValueError(
+        msg = (
             "Connection to Ollama failed. Please make sure Ollama is running "
             f"and accessible at {client._client.base_url}. "
-        ) from e
+        )
+        raise ValueError(msg) from e
     except ResponseError as e:
-        raise ValueError(
+        msg = (
             "Received an error from the Ollama API. "
             "Please check your Ollama server logs."
-        ) from e
+        )
+        raise ValueError(msg) from e
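Every rewrite in this file is the same pattern: bind the exception message to a variable, then raise it. That is the shape ruff's EM101/EM102 rules ask for (rule codes inferred from the pattern), and it keeps long f-strings off the raise line shown in tracebacks. A minimal sketch, with an illustrative function name:

    def require_positive(value: int) -> None:
        if value <= 0:
            # EM102 style: assign the f-string, then raise the variable
            msg = f"Expected a positive value, got {value}"
            raise ValueError(msg)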

libs/partners/ollama/langchain_ollama/chat_models.py

Lines changed: 52 additions & 36 deletions
@@ -1,5 +1,7 @@
 """Ollama chat models."""
 
+from __future__ import annotations
+
 import json
 from collections.abc import AsyncIterator, Iterator, Mapping, Sequence
 from operator import itemgetter
@@ -74,7 +76,9 @@ def _get_usage_metadata_from_generation_info(
 
 
 def _parse_json_string(
-    json_string: str, raw_tool_call: dict[str, Any], skip: bool
+    json_string: str,
+    raw_tool_call: dict[str, Any],
+    skip: bool,  # noqa: FBT001
 ) -> Any:
     """Attempt to parse a JSON string for tool calling.
 
@@ -148,16 +152,19 @@ def _get_tool_calls_from_response(
 ) -> list[ToolCall]:
     """Get tool calls from ollama response."""
     tool_calls = []
-    if "message" in response:
-        if raw_tool_calls := response["message"].get("tool_calls"):
-            for tc in raw_tool_calls:
-                tool_calls.append(
-                    tool_call(
-                        id=str(uuid4()),
-                        name=tc["function"]["name"],
-                        args=_parse_arguments_from_tool_call(tc) or {},
-                    )
+    if "message" in response and (
+        raw_tool_calls := response["message"].get("tool_calls")
+    ):
+        tool_calls.extend(
+            [
+                tool_call(
+                    id=str(uuid4()),
+                    name=tc["function"]["name"],
+                    args=_parse_arguments_from_tool_call(tc) or {},
                 )
+                for tc in raw_tool_calls
+            ]
+        )
     return tool_calls
 
 
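An aside on the _get_tool_calls_from_response rewrite above: the nested ifs collapse into one condition via the walrus operator, and the per-item append loop becomes a single extend over a comprehension, the shapes ruff's SIM102 and PERF401 rules push toward (rule codes inferred, not named in the commit). A reduced, runnable sketch with toy data:

    response = {"message": {"tool_calls": [{"name": "add"}, {"name": "mul"}]}}
    names: list[str] = []
    # One combined condition (the walrus binds raw for the body), one bulk extend
    if "message" in response and (raw := response["message"].get("tool_calls")):
        names.extend([tc["name"] for tc in raw])
    print(names)  # -> ['add', 'mul']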
@@ -178,14 +185,12 @@ def _get_image_from_data_content_block(block: dict) -> str:
     if block["type"] == "image":
         if block["source_type"] == "base64":
             return block["data"]
-        else:
-            error_message = "Image data only supported through in-line base64 format."
-            raise ValueError(error_message)
-
-    else:
-        error_message = f"Blocks of type {block['type']} not supported."
+        error_message = "Image data only supported through in-line base64 format."
         raise ValueError(error_message)
 
+    error_message = f"Blocks of type {block['type']} not supported."
+    raise ValueError(error_message)
+
 
 def _is_pydantic_class(obj: Any) -> bool:
     return isinstance(obj, type) and is_basemodel_subclass(obj)
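The _get_image_from_data_content_block hunk drops else blocks that exist only to raise, because the branch above them always returns; this is the style ruff's RET505/RET506 rules enforce (again an inference). In miniature, with an illustrative helper name:

    def base64_data(block: dict) -> str:
        if block.get("source_type") == "base64":
            return block["data"]
        # No else needed: the branch above already returned
        msg = "Image data only supported through in-line base64 format."
        raise ValueError(msg)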
@@ -459,7 +464,7 @@ class Multiply(BaseModel):
     """Base url the model is hosted under."""
 
     client_kwargs: Optional[dict] = {}
-    """Additional kwargs to pass to the httpx clients. 
+    """Additional kwargs to pass to the httpx clients.
     These arguments are passed to both synchronous and async clients.
     Use sync_client_kwargs and async_client_kwargs to pass different arguments
     to synchronous and asynchronous clients.
@@ -496,7 +501,8 @@ def _chat_params(
         ollama_messages = self._convert_messages_to_ollama_messages(messages)
 
         if self.stop is not None and stop is not None:
-            raise ValueError("`stop` found in both the input and default params.")
+            msg = "`stop` found in both the input and default params."
+            raise ValueError(msg)
         if self.stop is not None:
             stop = self.stop
 
@@ -584,7 +590,8 @@ def _convert_messages_to_ollama_messages(
             role = "tool"
             tool_call_id = message.tool_call_id
         else:
-            raise ValueError("Received unsupported message type for Ollama.")
+            msg = "Received unsupported message type for Ollama."
+            raise ValueError(msg)
 
         content = ""
         images = []
@@ -608,10 +615,11 @@
                 ):
                     image_url = temp_image_url["url"]
                 else:
-                    raise ValueError(
+                    msg = (
                         "Only string image_url or dict with string 'url' "
                         "inside content parts are supported."
                     )
+                    raise ValueError(msg)
 
                 image_url_components = image_url.split(",")
                 # Support data:image/jpeg;base64,<image> format
@@ -624,22 +632,24 @@
                     image = _get_image_from_data_content_block(content_part)
                     images.append(image)
                 else:
-                    raise ValueError(
+                    msg = (
                         "Unsupported message content type. "
                         "Must either have type 'text' or type 'image_url' "
                         "with a string 'image_url' field."
                     )
-            # Should convert to ollama.Message once role includes tool, and tool_call_id is in Message  # noqa: E501
-            msg: dict = {
+                    raise ValueError(msg)
+            # Should convert to ollama.Message once role includes tool,
+            # and tool_call_id is in Message
+            msg_: dict = {
                 "role": role,
                 "content": content,
                 "images": images,
             }
             if tool_calls:
-                msg["tool_calls"] = tool_calls
+                msg_["tool_calls"] = tool_calls
             if tool_call_id:
-                msg["tool_call_id"] = tool_call_id
-            ollama_messages.append(msg)
+                msg_["tool_call_id"] = tool_call_id
+            ollama_messages.append(msg_)
 
         return ollama_messages
 
@@ -677,7 +687,7 @@ def _chat_stream_with_aggregation(
         messages: list[BaseMessage],
         stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        verbose: bool = False,
+        verbose: bool = False,  # noqa: FBT001, FBT002
         **kwargs: Any,
     ) -> ChatGenerationChunk:
         final_chunk = None
@@ -693,7 +703,8 @@
                 verbose=verbose,
             )
         if final_chunk is None:
-            raise ValueError("No data received from Ollama stream.")
+            msg = "No data received from Ollama stream."
+            raise ValueError(msg)
 
         return final_chunk
 
@@ -702,7 +713,7 @@ async def _achat_stream_with_aggregation(
         messages: list[BaseMessage],
         stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        verbose: bool = False,
+        verbose: bool = False,  # noqa: FBT001, FBT002
        **kwargs: Any,
     ) -> ChatGenerationChunk:
         final_chunk = None
@@ -718,7 +729,8 @@
                 verbose=verbose,
             )
         if final_chunk is None:
-            raise ValueError("No data received from Ollama stream.")
+            msg = "No data received from Ollama stream."
+            raise ValueError(msg)
 
         return final_chunk
 
@@ -908,7 +920,7 @@ def bind_tools(
         self,
         tools: Sequence[Union[dict[str, Any], type, Callable, BaseTool]],
         *,
-        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,
+        tool_choice: Optional[Union[dict, str, Literal["auto", "any"], bool]] = None,  # noqa: PYI051
         **kwargs: Any,
     ) -> Runnable[LanguageModelInput, BaseMessage]:
         """Bind tool-like objects to this chat model.
@@ -923,7 +935,7 @@
             is currently ignored as it is not supported by Ollama.**
             kwargs: Any additional parameters are passed directly to
                 ``self.bind(**kwargs)``.
-        """  # noqa: E501
+        """
         formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
         return super().bind(tools=formatted_tools, **kwargs)
 
@@ -1180,14 +1192,16 @@ class AnswerWithJustification(BaseModel):
         """  # noqa: E501, D301
         _ = kwargs.pop("strict", None)
        if kwargs:
-            raise ValueError(f"Received unsupported arguments {kwargs}")
+            msg = f"Received unsupported arguments {kwargs}"
+            raise ValueError(msg)
         is_pydantic_schema = _is_pydantic_class(schema)
         if method == "function_calling":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             formatted_tool = convert_to_openai_tool(schema)
             tool_name = formatted_tool["function"]["name"]
             llm = self.bind_tools(
@@ -1222,10 +1236,11 @@ class AnswerWithJustification(BaseModel):
             )
         elif method == "json_schema":
             if schema is None:
-                raise ValueError(
+                msg = (
                     "schema must be specified when method is not 'json_mode'. "
                     "Received None."
                 )
+                raise ValueError(msg)
             if is_pydantic_schema:
                 schema = cast(TypeBaseModel, schema)
                 if issubclass(schema, BaseModelV1):
@@ -1259,10 +1274,11 @@ class AnswerWithJustification(BaseModel):
             )
             output_parser = JsonOutputParser()
         else:
-            raise ValueError(
+            msg = (
                 f"Unrecognized method argument. Expected one of 'function_calling', "
                 f"'json_schema', or 'json_mode'. Received: '{method}'"
             )
+            raise ValueError(msg)
 
         if include_raw:
             parser_assign = RunnablePassthrough.assign(
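Two rule families in this file are suppressed rather than refactored: FBT001/FBT002 flag boolean positional parameters (the "boolean trap"), and PYI051 flags Literal["auto", "any"] unioned with str, since the broader str already covers the literals; the noqa comments keep the public signatures stable. The client_kwargs docstring hunk only strips trailing whitespace, which is invisible in plain text. What FBT objects to, sketched with invented names:

    def render(text: str, upper: bool = False) -> str:  # FBT001/FBT002 would flag this
        return text.upper() if upper else text

    def render_kw(text: str, *, upper: bool = False) -> str:  # keyword-only: satisfies the rule
        return text.upper() if upper else text

    print(render_kw("hi", upper=True))  # -> HI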

libs/partners/ollama/langchain_ollama/embeddings.py

Lines changed: 9 additions & 7 deletions
@@ -1,5 +1,7 @@
 """Ollama embeddings models."""
 
+from __future__ import annotations
+
 from typing import Any, Optional
 
 from langchain_core.embeddings import Embeddings
@@ -132,7 +134,7 @@ class OllamaEmbeddings(BaseModel, Embeddings):
     """Base url the model is hosted under."""
 
     client_kwargs: Optional[dict] = {}
-    """Additional kwargs to pass to the httpx clients. 
+    """Additional kwargs to pass to the httpx clients.
     These arguments are passed to both synchronous and async clients.
     Use sync_client_kwargs and async_client_kwargs to pass different arguments
     to synchronous and asynchronous clients.
@@ -271,14 +273,14 @@ def _set_clients(self) -> Self:
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs."""
         if not self._client:
-            raise ValueError(
+            msg = (
                 "Ollama client is not initialized. "
                 "Please ensure Ollama is running and the model is loaded."
             )
-        embedded_docs = self._client.embed(
+            raise ValueError(msg)
+        return self._client.embed(
             self.model, texts, options=self._default_params, keep_alive=self.keep_alive
         )["embeddings"]
-        return embedded_docs
 
     def embed_query(self, text: str) -> list[float]:
         """Embed query text."""
@@ -287,16 +289,16 @@ def embed_query(self, text: str) -> list[float]:
     async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
         """Embed search docs."""
         if not self._async_client:
-            raise ValueError(
+            msg = (
                 "Ollama client is not initialized. "
                 "Please ensure Ollama is running and the model is loaded."
             )
-        embedded_docs = (
+            raise ValueError(msg)
+        return (
             await self._async_client.embed(
                 self.model, texts, keep_alive=self.keep_alive
             )
         )["embeddings"]
-        return embedded_docs
 
     async def aembed_query(self, text: str) -> list[float]:
         """Embed query text."""
