|
1 | 1 | from unittest import mock |
2 | 2 |
|
| 3 | +try: |
| 4 | + from unittest.mock import AsyncMock |
| 5 | +except ImportError: |
| 6 | + |
| 7 | + class AsyncMock(mock.MagicMock): |
| 8 | + async def __call__(self, *args, **kwargs): |
| 9 | + return super(AsyncMock, self).__call__(*args, **kwargs) |
| 10 | + |
| 11 | + |
3 | 12 | import pytest |
4 | 13 | from anthropic import AsyncAnthropic, Anthropic, AnthropicError, AsyncStream, Stream |
5 | 14 | from anthropic.types import MessageDeltaUsage, TextDelta, Usage |
@@ -140,7 +149,7 @@ async def test_nonstreaming_create_message_async( |
140 | 149 | ) |
141 | 150 | events = capture_events() |
142 | 151 | client = AsyncAnthropic(api_key="z") |
143 | | - client.messages._post = mock.AsyncMock(return_value=EXAMPLE_MESSAGE) |
| 152 | + client.messages._post = AsyncMock(return_value=EXAMPLE_MESSAGE) |
144 | 153 |
|
145 | 154 | messages = [ |
146 | 155 | { |
@@ -344,7 +353,7 @@ async def test_streaming_create_message_async( |
344 | 353 | send_default_pii=send_default_pii, |
345 | 354 | ) |
346 | 355 | events = capture_events() |
347 | | - client.messages._post = mock.AsyncMock(return_value=returned_stream) |
| 356 | + client.messages._post = AsyncMock(return_value=returned_stream) |
348 | 357 |
|
349 | 358 | messages = [ |
350 | 359 | { |
@@ -611,7 +620,7 @@ async def test_streaming_create_message_with_input_json_delta_async( |
611 | 620 | send_default_pii=send_default_pii, |
612 | 621 | ) |
613 | 622 | events = capture_events() |
614 | | - client.messages._post = mock.AsyncMock(return_value=returned_stream) |
| 623 | + client.messages._post = AsyncMock(return_value=returned_stream) |
615 | 624 |
|
616 | 625 | messages = [ |
617 | 626 | { |
@@ -683,7 +692,7 @@ async def test_exception_message_create_async(sentry_init, capture_events): |
683 | 692 | events = capture_events() |
684 | 693 |
|
685 | 694 | client = AsyncAnthropic(api_key="z") |
686 | | - client.messages._post = mock.AsyncMock( |
| 695 | + client.messages._post = AsyncMock( |
687 | 696 | side_effect=AnthropicError("API rate limit reached") |
688 | 697 | ) |
689 | 698 | with pytest.raises(AnthropicError): |
@@ -732,7 +741,7 @@ async def test_span_origin_async(sentry_init, capture_events): |
732 | 741 | events = capture_events() |
733 | 742 |
|
734 | 743 | client = AsyncAnthropic(api_key="z") |
735 | | - client.messages._post = mock.AsyncMock(return_value=EXAMPLE_MESSAGE) |
| 744 | + client.messages._post = AsyncMock(return_value=EXAMPLE_MESSAGE) |
736 | 745 |
|
737 | 746 | messages = [ |
738 | 747 | { |
|
0 commit comments