Commit a9eda18

refactor(ollama): clean up tests (#33198)

1 parent a89c549

File tree

11 files changed: +382 −431 lines

libs/partners/ollama/langchain_ollama/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -14,14 +14,20 @@
 """

 from importlib import metadata
+from importlib.metadata import PackageNotFoundError

 from langchain_ollama.chat_models import ChatOllama
 from langchain_ollama.embeddings import OllamaEmbeddings
 from langchain_ollama.llms import OllamaLLM

+
+def _raise_package_not_found_error() -> None:
+    raise PackageNotFoundError
+
+
 try:
     if __package__ is None:
-        raise metadata.PackageNotFoundError
+        _raise_package_not_found_error()
     __version__ = metadata.version(__package__)
 except metadata.PackageNotFoundError:
     # Case where package metadata is not available.
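The new helper follows the pattern that ruff's TRY301 rule ("abstract raise to an inner function") asks for, and TRY301 disappears from the pyproject.toml TODO ignore list below. A minimal generic illustration of what the rule flags and the fix pattern, using a hypothetical function rather than repo code:

def _raise_empty_input_error() -> None:
    raise ValueError("empty input")


def parse_count(text: str) -> int:
    try:
        # TRY301 flags a bare `raise` inside a `try` body; moving the raise into
        # a small helper (as this commit does for PackageNotFoundError) satisfies it.
        if not text:
            _raise_empty_input_error()
        return int(text)
    except ValueError:
        return 0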

libs/partners/ollama/langchain_ollama/chat_models.py

Lines changed: 2 additions & 3 deletions
@@ -132,7 +132,6 @@ def _parse_arguments_from_tool_call(
     Should be removed/changed if fixed upstream.

     See https://github.com/ollama/ollama/issues/6155
-
     """
     if "function" not in raw_tool_call:
         return None

@@ -834,7 +833,7 @@ def _chat_stream_with_aggregation(
         messages: list[BaseMessage],
         stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        verbose: bool = False,  # noqa: FBT001, FBT002
+        verbose: bool = False,  # noqa: FBT002
         **kwargs: Any,
     ) -> ChatGenerationChunk:
         final_chunk = None

@@ -860,7 +859,7 @@ async def _achat_stream_with_aggregation(
         messages: list[BaseMessage],
         stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        verbose: bool = False,  # noqa: FBT001, FBT002
+        verbose: bool = False,  # noqa: FBT002
         **kwargs: Any,
     ) -> ChatGenerationChunk:
         final_chunk = None

libs/partners/ollama/langchain_ollama/llms.py

Lines changed: 2 additions & 2 deletions
@@ -368,7 +368,7 @@ async def _astream_with_aggregation(
         prompt: str,
         stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        verbose: bool = False,  # noqa: FBT001, FBT002
+        verbose: bool = False,  # noqa: FBT002
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk = None

@@ -410,7 +410,7 @@ def _stream_with_aggregation(
         prompt: str,
         stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        verbose: bool = False,  # noqa: FBT001, FBT002
+        verbose: bool = False,  # noqa: FBT002
         **kwargs: Any,
     ) -> GenerationChunk:
         final_chunk = None
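Dropping `FBT001` from these per-line suppressions lines up with the pyproject.toml change below, which adds FBT001 (boolean positional parameter) to the project-wide ignore list, while FBT002 (boolean default value) stays suppressed per signature. A minimal illustration of what the two rules flag, using a hypothetical function rather than repo code:

def render(text: str, verbose: bool = False) -> str:
    # FBT001: `verbose` is a boolean-typed positional parameter.
    # FBT002: `verbose` also has a boolean default value.
    return f"{text} (verbose)" if verbose else text


# Passing the flag as a keyword avoids the ambiguity these rules target.
print(render("hello", verbose=True))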

libs/partners/ollama/pyproject.toml

Lines changed: 14 additions & 17 deletions
@@ -60,15 +60,17 @@ ignore = [
     "SLF001", # Private member access
     "UP007", # pyupgrade: non-pep604-annotation-union
     "UP045", # pyupgrade: non-pep604-annotation-optional
-    "PLR0912",
-    "C901",
-    "PLR0915",
+    "FIX002", # TODOs
+    "TD002", # TODO authors
+    "TC002", # Incorrect type-checking block
+    "TC003", # Incorrect type-checking block
+    "PLR0912", # Too many branches
+    "PLR0915", # Too many statements
+    "C901", # Function too complex
+    "FBT001", # Boolean function param

-    # TODO:
+    # TODO
     "ANN401",
-    "TC002",
-    "TC003",
-    "TRY301",
 ]
 unfixable = ["B028"] # People should intentionally tune the stacklevel

@@ -91,16 +93,11 @@ asyncio_mode = "auto"

 [tool.ruff.lint.extend-per-file-ignores]
 "tests/**/*.py" = [
-    "S101", # Tests need assertions
-    "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
-    "PLR2004",
-
-    # TODO
-    "ANN401",
-    "ARG001",
-    "PT011",
-    "FIX",
-    "TD",
+    "S101", # Tests need assertions
+    "S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
+    "ARG001", # Unused function arguments in tests (e.g. kwargs)
+    "PLR2004", # Magic value in comparisons
+    "PT011", # `pytest.raises()` is too broad
 ]
 "scripts/*.py" = [
     "INP001", # Not a package

libs/partners/ollama/tests/integration_tests/chat_models/test_chat_models.py

Lines changed: 139 additions & 1 deletion
@@ -3,16 +3,60 @@
 from __future__ import annotations

 from typing import Annotated, Optional
+from unittest.mock import MagicMock, patch

 import pytest
-from pydantic import BaseModel, Field
+from httpx import ConnectError
+from langchain_core.messages.ai import AIMessageChunk
+from langchain_core.messages.human import HumanMessage
+from langchain_core.messages.tool import ToolCallChunk
+from langchain_core.tools import tool
+from ollama import ResponseError
+from pydantic import BaseModel, Field, ValidationError
 from typing_extensions import TypedDict

 from langchain_ollama import ChatOllama

 DEFAULT_MODEL_NAME = "llama3.1"


+@tool
+def get_current_weather(location: str) -> dict:
+    """Gets the current weather in a given location."""
+    if "boston" in location.lower():
+        return {"temperature": "15°F", "conditions": "snow"}
+    return {"temperature": "unknown", "conditions": "unknown"}
+
+
+@patch("langchain_ollama.chat_models.Client.list")
+def test_init_model_not_found(mock_list: MagicMock) -> None:
+    """Test that a ValueError is raised when the model is not found."""
+    mock_list.side_effect = ValueError("Test model not found")
+    with pytest.raises(ValueError) as excinfo:
+        ChatOllama(model="non-existent-model", validate_model_on_init=True)
+    assert "Test model not found" in str(excinfo.value)
+
+
+@patch("langchain_ollama.chat_models.Client.list")
+def test_init_connection_error(mock_list: MagicMock) -> None:
+    """Test that a ValidationError is raised on connect failure during init."""
+    mock_list.side_effect = ConnectError("Test connection error")
+
+    with pytest.raises(ValidationError) as excinfo:
+        ChatOllama(model="any-model", validate_model_on_init=True)
+    assert "Failed to connect to Ollama" in str(excinfo.value)
+
+
+@patch("langchain_ollama.chat_models.Client.list")
+def test_init_response_error(mock_list: MagicMock) -> None:
+    """Test that a ResponseError is raised."""
+    mock_list.side_effect = ResponseError("Test response error")
+
+    with pytest.raises(ValidationError) as excinfo:
+        ChatOllama(model="any-model", validate_model_on_init=True)
+    assert "Received an error from the Ollama API" in str(excinfo.value)
+
+
 @pytest.mark.parametrize(("method"), [("function_calling"), ("json_schema")])
 def test_structured_output(method: str) -> None:
     """Test to verify structured output via tool calling and `format` parameter."""
@@ -98,3 +142,97 @@ class Data(BaseModel):

     for chunk in chat.stream(text):
         assert isinstance(chunk, Data)
+
+
+@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
+def test_tool_streaming(model: str) -> None:
+    """Test that the model can stream tool calls."""
+    llm = ChatOllama(model=model)
+    chat_model_with_tools = llm.bind_tools([get_current_weather])
+
+    prompt = [HumanMessage("What is the weather today in Boston?")]
+
+    # Flags and collectors for validation
+    tool_chunk_found = False
+    final_tool_calls = []
+    collected_tool_chunks: list[ToolCallChunk] = []
+
+    # Stream the response and inspect the chunks
+    for chunk in chat_model_with_tools.stream(prompt):
+        assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
+
+        if chunk.tool_call_chunks:
+            tool_chunk_found = True
+            collected_tool_chunks.extend(chunk.tool_call_chunks)
+
+        if chunk.tool_calls:
+            final_tool_calls.extend(chunk.tool_calls)
+
+    assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
+    assert len(final_tool_calls) == 1, (
+        f"Expected 1 final tool call, but got {len(final_tool_calls)}"
+    )
+
+    final_tool_call = final_tool_calls[0]
+    assert final_tool_call["name"] == "get_current_weather"
+    assert final_tool_call["args"] == {"location": "Boston"}
+
+    assert len(collected_tool_chunks) > 0
+    assert collected_tool_chunks[0]["name"] == "get_current_weather"
+
+    # The ID should be consistent across chunks that have it
+    tool_call_id = collected_tool_chunks[0].get("id")
+    assert tool_call_id is not None
+    assert all(
+        chunk.get("id") == tool_call_id
+        for chunk in collected_tool_chunks
+        if chunk.get("id")
+    )
+    assert final_tool_call["id"] == tool_call_id
+
+
+@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
+async def test_tool_astreaming(model: str) -> None:
+    """Test that the model can stream tool calls."""
+    llm = ChatOllama(model=model)
+    chat_model_with_tools = llm.bind_tools([get_current_weather])
+
+    prompt = [HumanMessage("What is the weather today in Boston?")]
+
+    # Flags and collectors for validation
+    tool_chunk_found = False
+    final_tool_calls = []
+    collected_tool_chunks: list[ToolCallChunk] = []
+
+    # Stream the response and inspect the chunks
+    async for chunk in chat_model_with_tools.astream(prompt):
+        assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
+
+        if chunk.tool_call_chunks:
+            tool_chunk_found = True
+            collected_tool_chunks.extend(chunk.tool_call_chunks)
+
+        if chunk.tool_calls:
+            final_tool_calls.extend(chunk.tool_calls)
+
+    assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
+    assert len(final_tool_calls) == 1, (
+        f"Expected 1 final tool call, but got {len(final_tool_calls)}"
+    )
+
+    final_tool_call = final_tool_calls[0]
+    assert final_tool_call["name"] == "get_current_weather"
+    assert final_tool_call["args"] == {"location": "Boston"}
+
+    assert len(collected_tool_chunks) > 0
+    assert collected_tool_chunks[0]["name"] == "get_current_weather"
+
+    # The ID should be consistent across chunks that have it
+    tool_call_id = collected_tool_chunks[0].get("id")
+    assert tool_call_id is not None
+    assert all(
+        chunk.get("id") == tool_call_id
+        for chunk in collected_tool_chunks
+        if chunk.get("id")
+    )
+    assert final_tool_call["id"] == tool_call_id
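These tests collect `tool_call_chunks` manually; a common alternative when consuming a stream (not part of this commit) is to add the chunks together, since `AIMessageChunk` supports `+` and merges partial tool calls into complete `tool_calls` on the aggregate. A minimal sketch reusing the `get_current_weather` tool above, assuming a running Ollama server with `llama3.1`:

from typing import Optional

from langchain_core.messages.ai import AIMessageChunk
from langchain_ollama import ChatOllama

llm = ChatOllama(model="llama3.1").bind_tools([get_current_weather])

aggregate: Optional[AIMessageChunk] = None
for chunk in llm.stream("What is the weather today in Boston?"):
    # Adding chunks concatenates text content and merges tool_call_chunks.
    aggregate = chunk if aggregate is None else aggregate + chunk

if aggregate is not None:
    print(aggregate.tool_calls)  # expected: one get_current_weather call for Boston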
