|
1 | 1 | import asyncio |
2 | 2 | from copy import copy |
3 | | -from unittest.mock import Mock |
| 3 | +from unittest.mock import Mock, patch |
4 | 4 |
|
5 | 5 | import pytest |
6 | 6 | from mode import label |
@@ -236,6 +236,30 @@ async def test_stream_filter_acks_filtered_out_messages(app, event_loop): |
236 | 236 | assert len(app.consumer.unacked) == 0 |
237 | 237 |
|
238 | 238 |
|
| 239 | +@pytest.mark.asyncio |
| 240 | +async def test_acks_filtered_out_messages_when_using_take(app, event_loop): |
| 241 | + """ |
| 242 | + Test the filter function acknowledges the filtered out messages when using take(). |
| 243 | + """ |
| 244 | + initial_values = [1000, 999, 3000, 99, 5000, 3, 9999] |
| 245 | + expected_values = [v for v in initial_values if v > 1000] |
| 246 | + original_function = app.create_event |
| 247 | + # using patch to intercept message objects, to check if they are acked later |
| 248 | + with patch("faust.app.base.App.create_event") as create_event_mock: |
| 249 | + create_event_mock.side_effect = original_function |
| 250 | + async with new_stream(app) as stream: |
| 251 | + for value in initial_values: |
| 252 | + await stream.channel.send(value=value) |
| 253 | + async for values in stream.filter(lambda x: x > 1000).take( |
| 254 | + len(expected_values), within=5 |
| 255 | + ): |
| 256 | + assert values == expected_values |
| 257 | + break |
| 258 | + messages = [call[0][3] for call in create_event_mock.call_args_list] |
| 259 | + acked = [m.acked for m in messages if m.acked] |
| 260 | + assert len(acked) == len(initial_values) |
| 261 | + |
| 262 | + |
239 | 263 | @pytest.mark.asyncio |
240 | 264 | async def test_events(app): |
241 | 265 | async with new_stream(app) as stream: |
|
0 commit comments