Skip to content

Commit 469ed2e

Browse files
authored
Use rate limit handler on v2 invoke methods (#495)
1 parent 1402bc5 commit 469ed2e

File tree

5 files changed

+93
-1
lines changed

5 files changed

+93
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,10 @@
22

33
## Next
44

5+
### Fixed
6+
7+
- Fixed a bug where the rate limit handler was not being called on the `VertexAILLM` and `MistralAILLM` `__invoke_v2` and `__ainvoke_v2` methods.
8+
59
### Added
610

711
- `NodeType` and `RelationshipType` now reject labels and types that start or end with double underscores (`__`), e.g. `__Person__`. This convention is reserved for internal Neo4j GraphRAG labels. A `ValidationError` is raised on construction.

src/neo4j_graphrag/llm/mistralai_llm.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -200,6 +200,7 @@ def __invoke_v1(
200200
except SDKError as e:
201201
raise LLMGenerationError(e)
202202

203+
@rate_limit_handler_decorator
203204
def __invoke_v2(
204205
self,
205206
input: List[LLMMessage],
@@ -277,6 +278,7 @@ async def __ainvoke_v1(
277278
except SDKError as e:
278279
raise LLMGenerationError(e)
279280

281+
@async_rate_limit_handler_decorator
280282
async def __ainvoke_v2(
281283
self,
282284
input: List[LLMMessage],

src/neo4j_graphrag/llm/vertexai_llm.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -286,6 +286,7 @@ def __invoke_v1(
286286
except ResponseValidationError as e:
287287
raise LLMGenerationError("Error calling VertexAILLM") from e
288288

289+
@rate_limit_handler_decorator
289290
def __invoke_v2(
290291
self,
291292
input: List[LLMMessage],
@@ -348,6 +349,7 @@ async def __ainvoke_v1(
348349
except ResponseValidationError as e:
349350
raise LLMGenerationError("Error calling VertexAILLM") from e
350351

352+
@async_rate_limit_handler_decorator
351353
async def __ainvoke_v2(
352354
self,
353355
input: list[LLMMessage],

tests/unit/llm/test_mistralai_llm.py

Lines changed: 45 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,14 +13,15 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from typing import Any, Optional
16-
from unittest.mock import MagicMock, Mock, patch
16+
from unittest.mock import AsyncMock, MagicMock, Mock, patch
1717
from typing import List
1818

1919
import httpx
2020
import pytest
2121
from neo4j_graphrag.exceptions import LLMGenerationError
2222
from neo4j_graphrag.llm import LLMResponse, MistralAILLM
2323
from neo4j_graphrag.types import LLMMessage
24+
from neo4j_graphrag.utils.rate_limit import NoOpRateLimitHandler
2425
from pydantic import BaseModel, ConfigDict
2526

2627

@@ -439,3 +440,46 @@ class TestModel(BaseModel):
439440
assert "MistralAILLM does not currently support structured output" in str(
440441
exc_info.value
441442
)
443+
444+
445+
@patch("neo4j_graphrag.llm.mistralai_llm.SDKError", MockSDKError)
446+
@patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
447+
def test_mistralai_invoke_v2_rate_limit_handler_called(
448+
mock_mistral: Mock,
449+
) -> None:
450+
"""Test that the rate limit handler is invoked on the V2 (List[LLMMessage]) path."""
451+
messages: List[LLMMessage] = [{"role": "user", "content": "Hello"}]
452+
mock_mistral_instance = mock_mistral.return_value
453+
chat_response_mock = MagicMock()
454+
chat_response_mock.choices = [MagicMock(message=MagicMock(content="Hi there!"))]
455+
mock_mistral_instance.chat.complete.return_value = chat_response_mock
456+
457+
spy_handler = MagicMock(wraps=NoOpRateLimitHandler())
458+
llm = MistralAILLM(model_name="mistral-model", rate_limit_handler=spy_handler)
459+
response = llm.invoke(messages)
460+
461+
assert response.content == "Hi there!"
462+
spy_handler.handle_sync.assert_called_once()
463+
464+
465+
@pytest.mark.asyncio
466+
@patch("neo4j_graphrag.llm.mistralai_llm.SDKError", MockSDKError)
467+
@patch("neo4j_graphrag.llm.mistralai_llm.Mistral")
468+
async def test_mistralai_ainvoke_v2_rate_limit_handler_called(
469+
mock_mistral: Mock,
470+
) -> None:
471+
"""Test that the rate limit handler is invoked on the async V2 (List[LLMMessage]) path."""
472+
messages: List[LLMMessage] = [{"role": "user", "content": "Hello"}]
473+
mock_mistral_instance = mock_mistral.return_value
474+
chat_response_mock = MagicMock()
475+
chat_response_mock.choices = [MagicMock(message=MagicMock(content="Hi there!"))]
476+
mock_mistral_instance.chat.complete_async = AsyncMock(
477+
return_value=chat_response_mock
478+
)
479+
480+
spy_handler = MagicMock(wraps=NoOpRateLimitHandler())
481+
llm = MistralAILLM(model_name="mistral-model", rate_limit_handler=spy_handler)
482+
response = await llm.ainvoke(messages)
483+
484+
assert response.content == "Hi there!"
485+
spy_handler.handle_async.assert_called_once()

tests/unit/llm/test_vertexai_llm.py

Lines changed: 40 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
3030
from neo4j_graphrag.tool import Tool
3131
from neo4j_graphrag.types import LLMMessage
32+
from neo4j_graphrag.utils.rate_limit import NoOpRateLimitHandler
3233

3334
from pydantic import BaseModel, ConfigDict
3435

@@ -598,3 +599,42 @@ async def test_vertexai_ainvoke_v2_with_json_schema_response_format(
598599
# Verify generation_config has response_schema
599600
call_args = mock_model.generate_content_async.call_args.kwargs
600601
assert "generation_config" in call_args
602+
603+
604+
@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
605+
def test_vertexai_invoke_v2_rate_limit_handler_called(
606+
GenerativeModelMock: MagicMock,
607+
) -> None:
608+
"""Test that the rate limit handler is invoked on the V2 (List[LLMMessage]) path."""
609+
messages: List[LLMMessage] = [{"role": "user", "content": "Hello"}]
610+
mock_response = Mock()
611+
mock_response.text = "Hi there!"
612+
mock_model = GenerativeModelMock.return_value
613+
mock_model.generate_content.return_value = mock_response
614+
615+
spy_handler = MagicMock(wraps=NoOpRateLimitHandler())
616+
llm = VertexAILLM(model_name="gemini-1.5-flash-001", rate_limit_handler=spy_handler)
617+
response = llm.invoke(messages)
618+
619+
assert response.content == "Hi there!"
620+
spy_handler.handle_sync.assert_called_once()
621+
622+
623+
@pytest.mark.asyncio
624+
@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
625+
async def test_vertexai_ainvoke_v2_rate_limit_handler_called(
626+
GenerativeModelMock: MagicMock,
627+
) -> None:
628+
"""Test that the rate limit handler is invoked on the async V2 (List[LLMMessage]) path."""
629+
messages: List[LLMMessage] = [{"role": "user", "content": "Hello"}]
630+
mock_response = AsyncMock()
631+
mock_response.text = "Hi there!"
632+
mock_model = GenerativeModelMock.return_value
633+
mock_model.generate_content_async = AsyncMock(return_value=mock_response)
634+
635+
spy_handler = MagicMock(wraps=NoOpRateLimitHandler())
636+
llm = VertexAILLM(model_name="gemini-1.5-flash-001", rate_limit_handler=spy_handler)
637+
response = await llm.ainvoke(messages)
638+
639+
assert response.content == "Hi there!"
640+
spy_handler.handle_async.assert_called_once()

0 commit comments

Comments (0)