Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 128 additions & 34 deletions faststream/redis/subscriber/usecases/stream_subscriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,52 @@ async def _consume(self, *args: Any, start_signal: "Event") -> None:
start_signal.set()
await super()._consume(*args, start_signal=start_signal)

async def _create_group(self, reset_counter: bool = False) -> None:
    """Create the consumer group for this stream, tolerating pre-existing groups.

    When ``reset_counter`` is True the group is created at the beginning of
    the stream (``"0"``); otherwise it starts from the subscriber's
    configured position (``"$"`` when the position is the special ``">"``).
    """
    if reset_counter:
        start_id = "0"
    elif self.last_id == ">":
        start_id = "$"
    else:
        start_id = self.last_id

    try:
        await self._client.xgroup_create(
            name=self.stream_sub.name,
            id=start_id,
            groupname=self.stream_sub.group,
            mkstream=True,
        )
    except ResponseError as err:
        # BUSYGROUP ("... already exists") just means another consumer
        # created the group first — that is not an error for us.
        if "already exists" not in str(err):
            raise

def _protect_read_from_group_removal(
    self,
    read_func: Callable[[], Awaitable[ReadResponse]],
    stream: "StreamSub",
) -> Callable[[], Awaitable[ReadResponse | Any]]:
    """Wrap *read_func* so known group/position errors are repaired and retried once.

    Returns a zero-argument coroutine function that calls *read_func* and,
    on a recognized ``ResponseError``, fixes the consumer group or the
    stream position in place and retries a single time.
    """

    async def _guarded_read() -> ReadResponse:
        try:
            return await read_func()
        except ResponseError as err:
            message = str(err)
            recoverable = False

            if "NOGROUP" in message:
                # Most likely Redis was flushed: re-create the group from
                # scratch and rewind our internal position as well.
                await self._create_group(reset_counter=True)
                stream.last_id = ">"
                recoverable = True

            if (
                "smaller than the first available entry" in message
                or "greater than the maximum id" in message
            ):
                # A third party modified the group; reset our position
                # to an id that is guaranteed to exist.
                stream.last_id = "$"
                recoverable = True

            if not recoverable:
                raise
            return await read_func()

    return _guarded_read

@override
async def start(self) -> None:
client = self._client
Expand Down Expand Up @@ -113,9 +159,7 @@ async def start(self) -> None:

if stream.min_idle_time is None:

def read(
_: str,
) -> Awaitable[ReadResponse]:
def _xreadgroup_call() -> Awaitable[ReadResponse]:
return client.xreadgroup(
groupname=stream.group,
consumername=stream.consumer,
Expand All @@ -125,18 +169,35 @@ def read(
noack=stream.no_ack,
)

protected_read_func = self._protect_read_from_group_removal(
read_func=_xreadgroup_call,
stream=stream,
)

async def read(
_: str,
) -> ReadResponse:
return await protected_read_func()

else:

async def read(_: str) -> ReadResponse:
stream_message = await client.xautoclaim(
def _xautoclaim_call() -> Awaitable[Any]:
return client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
)
stream_name = self.stream_sub.name.encode()

protected_autoclaim = self._protect_read_from_group_removal(
read_func=_xautoclaim_call,
stream=stream,
)

async def read(_: str) -> ReadResponse:
stream_message = await protected_autoclaim()
(next_id, messages, _) = stream_message

# Update start_id for next call
Expand All @@ -146,6 +207,7 @@ async def read(_: str) -> ReadResponse:
await asyncio.sleep(stream.polling_interval / 1000) # ms to s
return ()

stream_name = self.stream_sub.name.encode()
return ((stream_name, messages),)

else:
Expand All @@ -161,6 +223,31 @@ def read(

await super().start(read)

async def _get_one_message(self, timeout: float) -> ReadResponse | None:
    """Fetch at most one raw entry from the stream, blocking up to *timeout* seconds.

    Uses ``XREADGROUP`` (guarded against group removal) when a consumer
    group is configured, plain ``XREAD`` otherwise.
    """
    block_ms = math.ceil(timeout * 1000)  # redis expects milliseconds

    if not (self.stream_sub.group and self.stream_sub.consumer):
        # No consumer group configured: plain stream read.
        return await self._client.xread(
            {self.stream_sub.name: self.last_id},
            block=block_ms,
            count=1,
        )

    def _read_group() -> Awaitable[ReadResponse]:
        return self._client.xreadgroup(
            groupname=self.stream_sub.group,
            consumername=self.stream_sub.consumer,
            streams={self.stream_sub.name: self.last_id},
            block=block_ms,
            count=1,
        )

    # Transparently recover if the group vanished (e.g. FLUSHALL) mid-read.
    guarded_read = self._protect_read_from_group_removal(
        read_func=_read_group,
        stream=self.stream_sub,
    )
    return await guarded_read()

@override
async def get_one(
self,
Expand All @@ -172,26 +259,30 @@ async def get_one(
)
if self.stream_sub.group and self.stream_sub.consumer:
if self.min_idle_time is None:
stream_message = await self._client.xreadgroup(
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
streams={self.stream_sub.name: self.last_id},
block=math.ceil(timeout * 1000),
count=1,
)
stream_message = await self._get_one_message(timeout)
if not stream_message:
return None

((stream_name, ((message_id, raw_message),)),) = stream_message
else:
stream_message = await self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,

def _autoclaim_call() -> Awaitable[Any]:
return self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
)

protected_autoclaim = self._protect_read_from_group_removal(
read_func=_autoclaim_call,
stream=self.stream_sub,
)
stream_message = await protected_autoclaim()
if not stream_message:
return None
(next_id, messages, _) = stream_message
# Update start_id for next call
self.autoclaim_start_id = next_id
Expand Down Expand Up @@ -246,26 +337,29 @@ async def __aiter__(self) -> AsyncIterator["RedisStreamMessage"]: # type: ignor
while True:
if self.stream_sub.group and self.stream_sub.consumer:
if self.min_idle_time is None:
stream_message = await self._client.xreadgroup(
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
streams={self.stream_sub.name: self.last_id},
block=math.ceil(timeout * 1000),
count=1,
)
stream_message = await self._get_one_message(timeout)
if not stream_message:
continue

((stream_name, ((message_id, raw_message),)),) = stream_message
else:
stream_message = await self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,

def _autoclaim_call() -> Awaitable[Any]:
return self._client.xautoclaim(
name=self.stream_sub.name,
groupname=self.stream_sub.group,
consumername=self.stream_sub.consumer,
min_idle_time=self.min_idle_time,
start_id=self.autoclaim_start_id,
count=1,
)

protected_autoclaim = self._protect_read_from_group_removal(
read_func=_autoclaim_call,
stream=self.stream_sub,
)
stream_message = await protected_autoclaim()

(next_id, messages, _) = stream_message
# Update start_id for next call
self.autoclaim_start_id = next_id
Expand Down
53 changes: 53 additions & 0 deletions tests/brokers/redis/test_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,59 @@ async def handler(msg: RedisStreamMessage) -> None:
f"Redis stream must be empty here, found {queue_len} messages"
)

async def test_consume_from_group(
    self,
    queue: str,
) -> None:
    # Regression test: a consumer-group subscriber must survive a FLUSHALL.
    # The lost group should be transparently re-created (NOGROUP recovery
    # path) and consumption should resume afterwards.
    event = asyncio.Event()

    consume_broker = self.get_broker(apply_types=True)

    @consume_broker.subscriber(
        stream=StreamSub(queue, group="group", consumer=queue),
    )
    async def handler(msg: RedisMessage) -> None:
        event.set()

    async with self.patch_broker(consume_broker) as br:
        await br.start()
        redis_client = br._connection
        # Spy on the client calls; note the patch starts AFTER br.start(),
        # so the initial group creation is not counted below.
        with (
            patch.object(
                redis_client, "xreadgroup", spy_decorator(redis_client.xreadgroup)
            ) as m_readgroup,
            patch.object(
                redis_client,
                "xgroup_create",
                spy_decorator(redis_client.xgroup_create),
            ) as m_group_create,
        ):
            # First publish: normal delivery through the consumer group.
            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("hello", stream=queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=3,
            )
            await asyncio.sleep(0.1)
            m_readgroup.mock.assert_called_once()
            assert event.is_set()
            # Wipe Redis entirely: stream and consumer group both disappear.
            await redis_client.flushall()
            event.clear()
            await asyncio.sleep(0.1)
            # Second publish: must still reach the handler after the
            # subscriber re-creates the missing group.
            await asyncio.wait(
                (
                    asyncio.create_task(br.publish("hello again", stream=queue)),
                    asyncio.create_task(event.wait()),
                ),
                timeout=3,
            )

            await asyncio.sleep(0.1)
            # Exactly one XGROUP CREATE during the patched window — the
            # recovery path re-created the group once.
            m_group_create.mock.assert_called_once()

            assert event.is_set()

async def test_get_one(
self,
queue: str,
Expand Down
Loading
Loading