|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 | from typing import Any, Optional |
16 | | -from unittest.mock import MagicMock, Mock, patch |
| 16 | +from unittest.mock import AsyncMock, MagicMock, Mock, patch |
17 | 17 | from typing import List |
18 | 18 |
|
19 | 19 | import httpx |
20 | 20 | import pytest |
21 | 21 | from neo4j_graphrag.exceptions import LLMGenerationError |
22 | 22 | from neo4j_graphrag.llm import LLMResponse, MistralAILLM |
23 | 23 | from neo4j_graphrag.types import LLMMessage |
| 24 | +from neo4j_graphrag.utils.rate_limit import NoOpRateLimitHandler |
24 | 25 | from pydantic import BaseModel, ConfigDict |
25 | 26 |
|
26 | 27 |
|
@@ -439,3 +440,46 @@ class TestModel(BaseModel): |
439 | 440 | assert "MistralAILLM does not currently support structured output" in str( |
440 | 441 | exc_info.value |
441 | 442 | ) |
| 443 | + |
| 444 | + |
| 445 | +@patch("neo4j_graphrag.llm.mistralai_llm.SDKError", MockSDKError) |
| 446 | +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") |
| 447 | +def test_mistralai_invoke_v2_rate_limit_handler_called( |
| 448 | + mock_mistral: Mock, |
| 449 | +) -> None: |
| 450 | + """Test that the rate limit handler is invoked on the V2 (List[LLMMessage]) path.""" |
| 451 | + messages: List[LLMMessage] = [{"role": "user", "content": "Hello"}] |
| 452 | + mock_mistral_instance = mock_mistral.return_value |
| 453 | + chat_response_mock = MagicMock() |
| 454 | + chat_response_mock.choices = [MagicMock(message=MagicMock(content="Hi there!"))] |
| 455 | + mock_mistral_instance.chat.complete.return_value = chat_response_mock |
| 456 | + |
| 457 | + spy_handler = MagicMock(wraps=NoOpRateLimitHandler()) |
| 458 | + llm = MistralAILLM(model_name="mistral-model", rate_limit_handler=spy_handler) |
| 459 | + response = llm.invoke(messages) |
| 460 | + |
| 461 | + assert response.content == "Hi there!" |
| 462 | + spy_handler.handle_sync.assert_called_once() |
| 463 | + |
| 464 | + |
| 465 | +@pytest.mark.asyncio |
| 466 | +@patch("neo4j_graphrag.llm.mistralai_llm.SDKError", MockSDKError) |
| 467 | +@patch("neo4j_graphrag.llm.mistralai_llm.Mistral") |
| 468 | +async def test_mistralai_ainvoke_v2_rate_limit_handler_called( |
| 469 | + mock_mistral: Mock, |
| 470 | +) -> None: |
| 471 | + """Test that the rate limit handler is invoked on the async V2 (List[LLMMessage]) path.""" |
| 472 | + messages: List[LLMMessage] = [{"role": "user", "content": "Hello"}] |
| 473 | + mock_mistral_instance = mock_mistral.return_value |
| 474 | + chat_response_mock = MagicMock() |
| 475 | + chat_response_mock.choices = [MagicMock(message=MagicMock(content="Hi there!"))] |
| 476 | + mock_mistral_instance.chat.complete_async = AsyncMock( |
| 477 | + return_value=chat_response_mock |
| 478 | + ) |
| 479 | + |
| 480 | + spy_handler = MagicMock(wraps=NoOpRateLimitHandler()) |
| 481 | + llm = MistralAILLM(model_name="mistral-model", rate_limit_handler=spy_handler) |
| 482 | + response = await llm.ainvoke(messages) |
| 483 | + |
| 484 | + assert response.content == "Hi there!" |
| 485 | + spy_handler.handle_async.assert_called_once() |
0 commit comments