diff --git a/resumable_stream/runtime.py b/resumable_stream/runtime.py index 32b8380..d0736cd 100644 --- a/resumable_stream/runtime.py +++ b/resumable_stream/runtime.py @@ -500,6 +500,12 @@ async def stream_generator(): message_handler_task.cancel() await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}") raise + finally: + if message_handler_task: + message_handler_task.cancel() + await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}") + await pubsub.aclose() + stream = stream_generator() if not start: @@ -515,6 +521,7 @@ async def stream_generator(): if message_handler_task: message_handler_task.cancel() await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}") + await pubsub.aclose() raise @@ -560,14 +567,50 @@ async def stream_generator(): """ try: debug_log("STARTING STREAM", stream_id, listener_id) - start = asyncio.get_event_loop().time() + first_message_received = False + + async def get_first_message(): + """Get the first actual message from pubsub""" + async for message in pubsub.listen(): + if message["type"] == "message": + return message + + # Race the first message against a timeout + get_message_task = asyncio.create_task(get_first_message()) timeout_task = asyncio.create_task(asyncio.sleep(1.0)) - + try: + done, pending = await asyncio.wait( + [get_message_task, timeout_task], + return_when=asyncio.FIRST_COMPLETED + ) + + # Cancel pending tasks + for task in pending: + task.cancel() + + if timeout_task in done: + # Timeout occurred + raise TimeoutError("Timeout waiting for ack") + + # We got the first message + first_message = get_message_task.result() + if first_message is None: + raise TimeoutError("Timeout waiting for ack") + + debug_log("Received message", first_message["data"]) + + if first_message["data"] == DONE_MESSAGE: + await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}") + return + + yield first_message["data"] + first_message_received = True + + # Continue listening for remaining messages (no timeout needed) async for message in pubsub.listen(): if message["type"] == "message": debug_log("Received message", message["data"]) - timeout_task.cancel() if message["data"] == DONE_MESSAGE: await pubsub.unsubscribe( @@ -576,17 +619,23 @@ async def stream_generator(): return yield message["data"] + except asyncio.CancelledError: val = await ctx.redis.get(f"{ctx.key_prefix}:sentinel:{stream_id}") if val == DONE_VALUE: return - if asyncio.get_event_loop().time() - start > 1.0: + if not first_message_received: raise TimeoutError("Timeout waiting for ack") + raise finally: await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}") + except Exception as e: debug_log("Error in resume_stream", e) raise + finally: + await pubsub.aclose() + # Start the stream and send the request stream = stream_generator() @@ -602,7 +651,7 @@ async def stream_generator(): return stream except Exception: - await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}") + await pubsub.aclose() raise @@ -663,7 +712,7 @@ async def incr_or_done(publisher: Redis, key: str) -> Union[str, int]: return await publisher.incr(key) except Exception as reason: error_string = str(reason) - if "ERR value is not an integer or out of range" in error_string: + if "value is not an integer or out of range" in error_string: return DONE_VALUE raise diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 0e578e4..3de73e7 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -6,6 +6,7 @@ create_resumable_stream_context, ResumableStreamContext, ) +from resumable_stream.runtime import incr_or_done, DONE_VALUE from typing import AsyncGenerator, List, Any @@ -30,6 +31,115 @@ async def async_generator(items: List[str]) -> AsyncGenerator[str, None]: yield item +@pytest.mark.asyncio +async def test_incr_or_done_new_key(redis: Redis) -> None: + """Test incr_or_done with a new key that doesn't exist.""" + key = "test-incr-new" + + # Ensure key doesn't exist + await redis.delete(key) + + result = await incr_or_done(redis, key) + assert result == 1 + + # Verify the key was actually set + value = await redis.get(key) + assert value == "1" + + # Clean up + await redis.delete(key) + + +@pytest.mark.asyncio +async def test_incr_or_done_existing_integer(redis: Redis) -> None: + """Test incr_or_done with an existing integer key.""" + key = "test-incr-existing" + + # Set initial value + await redis.set(key, "5") + + result = await incr_or_done(redis, key) + assert result == 6 + + # Verify the key was incremented + value = await redis.get(key) + assert value == "6" + + # Test incrementing again + result = await incr_or_done(redis, key) + assert result == 7 + + # Clean up + await redis.delete(key) + + +@pytest.mark.asyncio +async def test_incr_or_done_with_done_value(redis: Redis) -> None: + """Test incr_or_done with a key containing DONE_VALUE.""" + key = "test-incr-done" + + # Set key to DONE_VALUE + await redis.set(key, DONE_VALUE) + + result = await incr_or_done(redis, key) + assert result == DONE_VALUE + + # Verify the key value is unchanged + value = await redis.get(key) + assert value == DONE_VALUE + + # Clean up + await redis.delete(key) + + +@pytest.mark.asyncio +async def test_incr_or_done_with_non_integer_string(redis: Redis) -> None: + """Test incr_or_done with a key containing a non-integer string.""" + key = "test-incr-string" + + # Set key to a non-integer string + await redis.set(key, "not-a-number") + + result = await incr_or_done(redis, key) + assert result == DONE_VALUE + + # Verify the key value is unchanged + value = await redis.get(key) + assert value == "not-a-number" + + # Clean up + await redis.delete(key) + + +@pytest.mark.asyncio +async def test_incr_or_done_multiple_increments(redis: Redis) -> None: + """Test multiple increments to verify the function works consistently.""" + key = "test-incr-multiple" + + # Clean start + await redis.delete(key) + + # First increment (key doesn't exist) + result1 = await incr_or_done(redis, key) + assert result1 == 1 + + # Second increment + result2 = await incr_or_done(redis, key) + assert result2 == 2 + + # Third increment + result3 = await incr_or_done(redis, key) + assert result3 == 3 + + # Now set it to DONE and verify behavior changes + await redis.set(key, DONE_VALUE) + result4 = await incr_or_done(redis, key) + assert result4 == DONE_VALUE + + # Clean up + await redis.delete(key) + + @pytest.mark.asyncio @pytest.mark.timeout(1) async def test_create_new_stream(stream_context: ResumableStreamContext) -> None: @@ -257,3 +367,42 @@ async def test_resume_existing_stream_with_start( received_chunks.append(chunk) assert "".join(received_chunks) == "".join(test_data) + + +@pytest.mark.asyncio +@pytest.mark.timeout(5) +async def test_timeout_and_connection_closure(stream_context: ResumableStreamContext, redis: Redis) -> None: + """Test that pubsub connections are properly cleaned up when timeout occurs during stream resumption.""" + + stream_id = "test-timeout-stream" + + # Set up a stream state that exists but has no active publisher + # This simulates a scenario where a stream was created but the publisher died + await redis.set(f"test-resumable-stream:rs:sentinel:{stream_id}", "2", ex=24*60*60) + + # Try to resume the stream - this should timeout because no publisher is responding + # The internal timeout in resume_stream is 1 second + with pytest.raises(TimeoutError, match="Timeout waiting for ack"): + resumed_stream = await stream_context.resume_existing_stream(stream_id) + if resumed_stream: + # Try to consume the stream - this should trigger the timeout + chunks = [] + async for chunk in resumed_stream: + chunks.append(chunk) + + # After the timeout, verify that the Redis state is still intact + # (timeout shouldn't corrupt the stream state) + state = await redis.get(f"test-resumable-stream:rs:sentinel:{stream_id}") + assert state == "2" # Should still be "2", not "DONE" + + # Verify that no pubsub channels are leaked by checking active channels + # This tests the resource cleanup aspect + pubsub_channels = await redis.execute_command("PUBSUB", "CHANNELS", "test-resumable-stream:rs:*") + + # There should be no active channels for our test stream after timeout cleanup + if pubsub_channels: + timeout_related_channels = [ch for ch in pubsub_channels if stream_id in str(ch)] + assert len(timeout_related_channels) == 0, f"Found leaked channels: {timeout_related_channels}" + + # Clean up + await redis.delete(f"test-resumable-stream:rs:sentinel:{stream_id}")