Skip to content

Commit d33f72e

Browse files
committed
Mock async iterators properly
By default AsyncMock doesn't provide any proper mocking of async async iterators, and if we just assign a `AsyncMock` to stub calls like `ReceivePublicTradesStream`, what happens is that `async for x in ReceivePublicTradesStream()` is used instead. What Python does with this is something like: ```py stream = ReceivePublicTradesStream() it = stream.__aiter__() # Gets an anync iterator, but this call is sync while v := await it.__anext(): ... ``` Because `ReceivePublicTradesStream` is an `AsyncMock`, getting any attributes (like `__aiter__()`) results in returning a new `AsyncMock`. But since `__aiter__()` is called without an `await`, then we get the following exceptions/warnings: * TypeError("'async for' requires an object with __aiter__ method, got coroutine") * RuntimeWarning: coroutine 'AsyncMockMixin._execute_mock_call' was never awaited To fix this, we just crate a proper fake async iterator we can return to the async mock. Signed-off-by: Leandro Lucarella <[email protected]>
1 parent 1e0170f commit d33f72e

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/test_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@
3636
)
3737

3838

39+
class _FakeAsyncIterable:
40+
def __aiter__(self) -> Any:
41+
"""Iterate over the fake async iterable."""
42+
return self
43+
44+
async def __anext__(self) -> Any:
45+
"""Return the next item in the async iterable."""
46+
raise StopAsyncIteration
47+
48+
3949
@dataclass
4050
class SetupParams: # pylint: disable=too-many-instance-attributes
4151
"""Parameters for the setup of the test suite."""
@@ -133,6 +143,10 @@ def set_up_order_detail_response(
133143

134144
async def test_stream_gridpool_orders(set_up: SetupParams) -> None:
135145
"""Test the method streaming gridpool orders."""
146+
set_up.mock_stub.ReceiveGridpoolOrdersStream = AsyncMock(
147+
return_value=_FakeAsyncIterable()
148+
)
149+
136150
set_up.client.gridpool_orders_stream(set_up.gridpool_id)
137151
await asyncio.sleep(0)
138152

@@ -143,6 +157,10 @@ async def test_stream_gridpool_orders(set_up: SetupParams) -> None:
143157

144158
async def test_stream_gridpool_orders_with_optional_inputs(set_up: SetupParams) -> None:
145159
"""Test the method streaming gridpool orders with some fields to filter for."""
160+
set_up.mock_stub.ReceiveGridpoolOrdersStream = AsyncMock(
161+
return_value=_FakeAsyncIterable()
162+
)
163+
146164
# Fields to filter for
147165
order_states = [OrderState.ACTIVE]
148166

@@ -162,6 +180,10 @@ async def test_stream_gridpool_trades(
162180
set_up: SetupParams,
163181
) -> None:
164182
"""Test the method streaming gridpool trades."""
183+
set_up.mock_stub.ReceiveGridpoolTradesStream = AsyncMock(
184+
return_value=_FakeAsyncIterable()
185+
)
186+
165187
set_up.client.gridpool_trades_stream(
166188
gridpool_id=set_up.gridpool_id, market_side=set_up.side
167189
)
@@ -177,6 +199,10 @@ async def test_receive_public_trades(
177199
set_up: SetupParams,
178200
) -> None:
179201
"""Test the method receiving public trades."""
202+
set_up.mock_stub.ReceivePublicTradesStream = AsyncMock(
203+
return_value=_FakeAsyncIterable()
204+
)
205+
180206
# Fields to filter for
181207
trade_states = [TradeState.ACTIVE]
182208

0 commit comments

Comments
 (0)