Skip to content

Commit 8c79dec

Browse files
committed
Mock async iterators properly
By default AsyncMock doesn't provide any proper mocking of async async iterators, and if we just assign a `AsyncMock` to stub calls like `ReceivePublicTradesStream`, what happens is that `async for x in ReceivePublicTradesStream()` is used instead. What Python does with this is something like: ```py stream = ReceivePublicTradesStream() it = stream.__aiter__() # Gets an anync iterator, but this call is sync while v := await it.__anext(): ... ``` Because `ReceivePublicTradesStream` is an `AsyncMock`, getting any attributes (like `__aiter__()`) results in returning a new `AsyncMock`. But since `__aiter__()` is called without an `await`, then we get the following exceptions/warnings: * TypeError("'async for' requires an object with __aiter__ method, got coroutine") * RuntimeWarning: coroutine 'AsyncMockMixin._execute_mock_call' was never awaited To fix this, we just crate a proper fake async iterator we can return to the async mock. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent d236cd5 commit 8c79dec

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

tests/test_client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from datetime import datetime, timedelta, timezone
88
from decimal import Decimal
9-
from unittest.mock import AsyncMock
9+
from unittest.mock import AsyncMock, Mock
1010

1111
import pytest
1212

@@ -36,6 +36,16 @@
3636
)
3737

3838

39+
class _FakeAsyncIterable:
40+
def __aiter__(self) -> Any:
41+
"""Iterate over the fake async iterable."""
42+
return self
43+
44+
async def __anext__(self) -> Any:
45+
"""Return the next item in the async iterable."""
46+
raise StopAsyncIteration
47+
48+
3949
@dataclass
4050
class SetupParams: # pylint: disable=too-many-instance-attributes
4151
"""Parameters for the setup of the test suite."""
@@ -125,6 +135,10 @@ def set_up_order_detail_response(
125135

126136
async def test_stream_gridpool_orders(set_up: SetupParams) -> None:
127137
"""Test the method streaming gridpool orders."""
138+
set_up.mock_stub.ReceiveGridpoolOrdersStream = Mock(
139+
return_value=_FakeAsyncIterable()
140+
)
141+
128142
set_up.client.gridpool_orders_stream(set_up.gridpool_id)
129143
await asyncio.sleep(0)
130144

@@ -135,6 +149,10 @@ async def test_stream_gridpool_orders(set_up: SetupParams) -> None:
135149

136150
async def test_stream_gridpool_orders_with_optional_inputs(set_up: SetupParams) -> None:
137151
"""Test the method streaming gridpool orders with some fields to filter for."""
152+
set_up.mock_stub.ReceiveGridpoolOrdersStream = Mock(
153+
return_value=_FakeAsyncIterable()
154+
)
155+
138156
# Fields to filter for
139157
order_states = [OrderState.ACTIVE]
140158

@@ -153,6 +171,10 @@ async def test_stream_gridpool_trades(
153171
set_up: SetupParams,
154172
) -> None:
155173
"""Test the method streaming gridpool trades."""
174+
set_up.mock_stub.ReceiveGridpoolTradesStream = Mock(
175+
return_value=_FakeAsyncIterable()
176+
)
177+
156178
set_up.client.gridpool_trades_stream(
157179
gridpool_id=set_up.gridpool_id, market_side=set_up.side
158180
)
@@ -168,6 +190,8 @@ async def test_receive_public_trades(
168190
set_up: SetupParams,
169191
) -> None:
170192
"""Test the method receiving public trades."""
193+
set_up.mock_stub.ReceivePublicTradesStream = Mock(return_value=_FakeAsyncIterable())
194+
171195
# Fields to filter for
172196
trade_states = [TradeState.ACTIVE]
173197

0 commit comments

Comments
 (0)