Skip to content

Commit f741fd9

Browse files
Ack filtered messages when using .take() (#488)
Co-authored-by: Artem Ilin <a.ilin@arammeem.com>
1 parent 99ab4f6 commit f741fd9

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

faust/_cython/streams.pyx

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ cdef class StreamIterator:
109109
object consumer
110110
consumer = self.consumer
111111
last_stream_to_ack = False
112-
# if do_ack and event is not None:
113-
if event is not None and (do_ack or event.value is self._skipped_value):
112+
if do_ack and event is not None:
114113
message = event.message
115114
if not message.acked:
116115
refcount = message.refcount

faust/streams.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,7 @@ async def _c_aiter(self) -> AsyncIterator[T_co]: # pragma: no cover
10651065
yield value
10661066
finally:
10671067
event, self.current_event = self.current_event, None
1068-
it.after(event, do_ack, sensor_state)
1068+
it.after(event, do_ack or value is skipped_value, sensor_state)
10691069
except StopAsyncIteration:
10701070
# We are not allowed to propagate StopAsyncIteration in __aiter__
10711071
# (if we do, it'll be converted to RuntimeError by CPython).

tests/functional/test_streams.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
from copy import copy
3-
from unittest.mock import Mock
3+
from unittest.mock import Mock, patch
44

55
import pytest
66
from mode import label
@@ -236,6 +236,30 @@ async def test_stream_filter_acks_filtered_out_messages(app, event_loop):
236236
assert len(app.consumer.unacked) == 0
237237

238238

239+
@pytest.mark.asyncio
240+
async def test_acks_filtered_out_messages_when_using_take(app, event_loop):
241+
"""
242+
Test the filter function acknowledges the filtered out messages when using take().
243+
"""
244+
initial_values = [1000, 999, 3000, 99, 5000, 3, 9999]
245+
expected_values = [v for v in initial_values if v > 1000]
246+
original_function = app.create_event
247+
# using patch to intercept message objects, to check if they are acked later
248+
with patch("faust.app.base.App.create_event") as create_event_mock:
249+
create_event_mock.side_effect = original_function
250+
async with new_stream(app) as stream:
251+
for value in initial_values:
252+
await stream.channel.send(value=value)
253+
async for values in stream.filter(lambda x: x > 1000).take(
254+
len(expected_values), within=5
255+
):
256+
assert values == expected_values
257+
break
258+
messages = [call[0][3] for call in create_event_mock.call_args_list]
259+
acked = [m.acked for m in messages if m.acked]
260+
assert len(acked) == len(initial_values)
261+
262+
239263
@pytest.mark.asyncio
240264
async def test_events(app):
241265
async with new_stream(app) as stream:

0 commit comments

Comments
 (0)