 from __future__ import annotations

 from typing import Annotated, Optional
+from unittest.mock import MagicMock, patch

 import pytest

-from pydantic import BaseModel, Field
+from httpx import ConnectError
+from langchain_core.messages.ai import AIMessageChunk
+from langchain_core.messages.human import HumanMessage
+from langchain_core.messages.tool import ToolCallChunk
+from langchain_core.tools import tool
+from ollama import ResponseError
+from pydantic import BaseModel, Field, ValidationError
 from typing_extensions import TypedDict

 from langchain_ollama import ChatOllama

 DEFAULT_MODEL_NAME = "llama3.1"


+@tool
+def get_current_weather(location: str) -> dict:
+    """Gets the current weather in a given location."""
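+    # Canned response keyed on the location string, so assertions stay deterministic.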
+    if "boston" in location.lower():
+        return {"temperature": "15°F", "conditions": "snow"}
+    return {"temperature": "unknown", "conditions": "unknown"}
+
+
+@patch("langchain_ollama.chat_models.Client.list")
+def test_init_model_not_found(mock_list: MagicMock) -> None:
+    """Test that a ValueError is raised when the model is not found."""
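+    # The patched `Client.list` stands in for the model-existence check at init time.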
+    mock_list.side_effect = ValueError("Test model not found")
+    with pytest.raises(ValueError) as excinfo:
+        ChatOllama(model="non-existent-model", validate_model_on_init=True)
+    assert "Test model not found" in str(excinfo.value)
+
+
+@patch("langchain_ollama.chat_models.Client.list")
+def test_init_connection_error(mock_list: MagicMock) -> None:
+    """Test that a ValidationError is raised on connect failure during init."""
+    mock_list.side_effect = ConnectError("Test connection error")
+
+    with pytest.raises(ValidationError) as excinfo:
+        ChatOllama(model="any-model", validate_model_on_init=True)
+    assert "Failed to connect to Ollama" in str(excinfo.value)
+
+
+@patch("langchain_ollama.chat_models.Client.list")
+def test_init_response_error(mock_list: MagicMock) -> None:
+    """Test that a ValidationError is raised when the Ollama API returns an error."""
+    mock_list.side_effect = ResponseError("Test response error")
+
+    with pytest.raises(ValidationError) as excinfo:
+        ChatOllama(model="any-model", validate_model_on_init=True)
+    assert "Received an error from the Ollama API" in str(excinfo.value)
+
+
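Reviewer note: the three init tests above are unit-level, since the Ollama client's `list` call is mocked, so no server is needed. For context, a minimal sketch of the behavior they pin down, assuming a locally running Ollama server (model name illustrative):

    from langchain_ollama import ChatOllama

    # With validation enabled, a missing model or an unreachable server fails
    # here, at construction, rather than on the first generate/stream call.
    llm = ChatOllama(model="llama3.1", validate_model_on_init=True)
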
 @pytest.mark.parametrize(("method"), [("function_calling"), ("json_schema")])
 def test_structured_output(method: str) -> None:
     """Test to verify structured output via tool calling and `format` parameter."""
@@ -98,3 +142,97 @@ class Data(BaseModel):

     for chunk in chat.stream(text):
         assert isinstance(chunk, Data)
+
+
+@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
+def test_tool_streaming(model: str) -> None:
+    """Test that the model can stream tool calls."""
+    llm = ChatOllama(model=model)
+    chat_model_with_tools = llm.bind_tools([get_current_weather])
+
+    prompt = [HumanMessage("What is the weather today in Boston?")]
+
+    # Flags and collectors for validation
+    tool_chunk_found = False
+    final_tool_calls = []
+    collected_tool_chunks: list[ToolCallChunk] = []
+
+    # Stream the response and inspect the chunks
+    for chunk in chat_model_with_tools.stream(prompt):
+        assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
+
+        if chunk.tool_call_chunks:
+            tool_chunk_found = True
+            collected_tool_chunks.extend(chunk.tool_call_chunks)
+
+        if chunk.tool_calls:
+            final_tool_calls.extend(chunk.tool_calls)
+
+    assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
+    assert len(final_tool_calls) == 1, (
+        f"Expected 1 final tool call, but got {len(final_tool_calls)}"
+    )
+
+    final_tool_call = final_tool_calls[0]
+    assert final_tool_call["name"] == "get_current_weather"
+    assert final_tool_call["args"] == {"location": "Boston"}
+
+    assert len(collected_tool_chunks) > 0
+    assert collected_tool_chunks[0]["name"] == "get_current_weather"
+
+    # The ID should be consistent across chunks that have it
+    tool_call_id = collected_tool_chunks[0].get("id")
+    assert tool_call_id is not None
+    assert all(
+        chunk.get("id") == tool_call_id
+        for chunk in collected_tool_chunks
+        if chunk.get("id")
+    )
+    assert final_tool_call["id"] == tool_call_id
+
+
+@pytest.mark.parametrize(("model"), [(DEFAULT_MODEL_NAME)])
+async def test_tool_astreaming(model: str) -> None:
+    """Test that the model can stream tool calls via the async API."""
+    llm = ChatOllama(model=model)
+    chat_model_with_tools = llm.bind_tools([get_current_weather])
+
+    prompt = [HumanMessage("What is the weather today in Boston?")]
+
+    # Flags and collectors for validation
+    tool_chunk_found = False
+    final_tool_calls = []
+    collected_tool_chunks: list[ToolCallChunk] = []
+
+    # Stream the response and inspect the chunks
+    async for chunk in chat_model_with_tools.astream(prompt):
+        assert isinstance(chunk, AIMessageChunk), "Expected AIMessageChunk type"
+
+        if chunk.tool_call_chunks:
+            tool_chunk_found = True
+            collected_tool_chunks.extend(chunk.tool_call_chunks)
+
+        if chunk.tool_calls:
+            final_tool_calls.extend(chunk.tool_calls)
+
+    assert tool_chunk_found, "Tool streaming did not produce any tool_call_chunks."
+    assert len(final_tool_calls) == 1, (
+        f"Expected 1 final tool call, but got {len(final_tool_calls)}"
+    )
+
+    final_tool_call = final_tool_calls[0]
+    assert final_tool_call["name"] == "get_current_weather"
+    assert final_tool_call["args"] == {"location": "Boston"}
+
+    assert len(collected_tool_chunks) > 0
+    assert collected_tool_chunks[0]["name"] == "get_current_weather"
+
+    # The ID should be consistent across chunks that have it
+    tool_call_id = collected_tool_chunks[0].get("id")
+    assert tool_call_id is not None
+    assert all(
+        chunk.get("id") == tool_call_id
+        for chunk in collected_tool_chunks
+        if chunk.get("id")
+    )
+    assert final_tool_call["id"] == tool_call_id
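
Reviewer note: both streaming tests assert chunk-level invariants; the end-to-end contract they rely on is that partial `tool_call_chunks` merge into a complete tool call. A minimal sketch of that accumulation, assuming a running Ollama server with `llama3.1` pulled and reusing the names from the tests above:

    full = None
    for chunk in chat_model_with_tools.stream(prompt):
        # AIMessageChunk implements __add__; tool_call_chunks merge incrementally.
        full = chunk if full is None else full + chunk
    assert full is not None and full.tool_calls  # complete call(s) once the stream ends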