Skip to content

fix(runtime): redis connection leak #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions resumable_stream/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,12 @@ async def stream_generator():
message_handler_task.cancel()
await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}")
raise
finally:
if message_handler_task:
message_handler_task.cancel()
await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}")
await pubsub.aclose()


stream = stream_generator()
if not start:
Expand All @@ -515,6 +521,7 @@ async def stream_generator():
if message_handler_task:
message_handler_task.cancel()
await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}")
await pubsub.aclose()
raise


Expand Down Expand Up @@ -560,14 +567,50 @@ async def stream_generator():
"""
try:
debug_log("STARTING STREAM", stream_id, listener_id)
start = asyncio.get_event_loop().time()
first_message_received = False

async def get_first_message():
"""Get the first actual message from pubsub"""
async for message in pubsub.listen():
if message["type"] == "message":
return message

# Race the first message against a timeout
get_message_task = asyncio.create_task(get_first_message())
timeout_task = asyncio.create_task(asyncio.sleep(1.0))

try:
done, pending = await asyncio.wait(
[get_message_task, timeout_task],
return_when=asyncio.FIRST_COMPLETED
)

# Cancel pending tasks
for task in pending:
task.cancel()

if timeout_task in done:
# Timeout occurred
raise TimeoutError("Timeout waiting for ack")

# We got the first message
first_message = get_message_task.result()
if first_message is None:
raise TimeoutError("Timeout waiting for ack")

debug_log("Received message", first_message["data"])

if first_message["data"] == DONE_MESSAGE:
await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}")
return

yield first_message["data"]
first_message_received = True

# Continue listening for remaining messages (no timeout needed)
async for message in pubsub.listen():
if message["type"] == "message":
debug_log("Received message", message["data"])
timeout_task.cancel()

if message["data"] == DONE_MESSAGE:
await pubsub.unsubscribe(
Expand All @@ -576,17 +619,23 @@ async def stream_generator():
return

yield message["data"]

except asyncio.CancelledError:
val = await ctx.redis.get(f"{ctx.key_prefix}:sentinel:{stream_id}")
if val == DONE_VALUE:
return
if asyncio.get_event_loop().time() - start > 1.0:
if not first_message_received:
raise TimeoutError("Timeout waiting for ack")
raise
finally:
await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}")

except Exception as e:
debug_log("Error in resume_stream", e)
raise
finally:
await pubsub.aclose()


# Start the stream and send the request
stream = stream_generator()
Expand All @@ -602,7 +651,7 @@ async def stream_generator():

return stream
except Exception:
await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}")
await pubsub.aclose()
raise


Expand Down Expand Up @@ -663,7 +712,7 @@ async def incr_or_done(publisher: Redis, key: str) -> Union[str, int]:
return await publisher.incr(key)
except Exception as reason:
error_string = str(reason)
if "ERR value is not an integer or out of range" in error_string:
if "value is not an integer or out of range" in error_string:
return DONE_VALUE
raise

Expand Down
149 changes: 149 additions & 0 deletions tests/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
create_resumable_stream_context,
ResumableStreamContext,
)
from resumable_stream.runtime import incr_or_done, DONE_VALUE
from typing import AsyncGenerator, List, Any


Expand All @@ -30,6 +31,115 @@ async def async_generator(items: List[str]) -> AsyncGenerator[str, None]:
yield item


@pytest.mark.asyncio
async def test_incr_or_done_new_key(redis: Redis) -> None:
    """incr_or_done on an absent key should create it and return 1."""
    key = "test-incr-new"

    # Start from a clean slate so the INCR path creates the key.
    await redis.delete(key)

    assert await incr_or_done(redis, key) == 1

    # The stored value must reflect the increment.
    assert await redis.get(key) == "1"

    # Clean up
    await redis.delete(key)


@pytest.mark.asyncio
async def test_incr_or_done_existing_integer(redis: Redis) -> None:
    """incr_or_done should increment a key that already holds an integer."""
    key = "test-incr-existing"

    # Seed the key with a known integer value.
    await redis.set(key, "5")

    assert await incr_or_done(redis, key) == 6

    # The stored value must reflect the increment.
    assert await redis.get(key) == "6"

    # A second call keeps incrementing.
    assert await incr_or_done(redis, key) == 7

    # Clean up
    await redis.delete(key)


@pytest.mark.asyncio
async def test_incr_or_done_with_done_value(redis: Redis) -> None:
    """incr_or_done should short-circuit to DONE_VALUE when the key holds it."""
    key = "test-incr-done"

    # Seed the key with the sentinel value.
    await redis.set(key, DONE_VALUE)

    assert await incr_or_done(redis, key) == DONE_VALUE

    # The sentinel value must be left untouched.
    assert await redis.get(key) == DONE_VALUE

    # Clean up
    await redis.delete(key)


@pytest.mark.asyncio
async def test_incr_or_done_with_non_integer_string(redis: Redis) -> None:
    """A non-integer value should be treated like DONE (Redis INCR error path)."""
    key = "test-incr-string"

    # Seed the key with something INCR cannot parse.
    await redis.set(key, "not-a-number")

    assert await incr_or_done(redis, key) == DONE_VALUE

    # The stored value must be left untouched.
    assert await redis.get(key) == "not-a-number"

    # Clean up
    await redis.delete(key)


@pytest.mark.asyncio
async def test_incr_or_done_multiple_increments(redis: Redis) -> None:
    """Repeated calls keep incrementing until the key is marked DONE."""
    key = "test-incr-multiple"

    # Clean start: the first call must create the key.
    await redis.delete(key)

    # Three consecutive increments starting from a missing key.
    for expected in (1, 2, 3):
        assert await incr_or_done(redis, key) == expected

    # Once the key is DONE, the function stops incrementing.
    await redis.set(key, DONE_VALUE)
    assert await incr_or_done(redis, key) == DONE_VALUE

    # Clean up
    await redis.delete(key)


@pytest.mark.asyncio
@pytest.mark.timeout(1)
async def test_create_new_stream(stream_context: ResumableStreamContext) -> None:
Expand Down Expand Up @@ -257,3 +367,42 @@ async def test_resume_existing_stream_with_start(
received_chunks.append(chunk)

assert "".join(received_chunks) == "".join(test_data)


@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_timeout_and_connection_closure(stream_context: ResumableStreamContext, redis: Redis) -> None:
    """Test that pubsub connections are properly cleaned up when timeout occurs during stream resumption."""

    stream_id = "test-timeout-stream"
    sentinel_key = f"test-resumable-stream:rs:sentinel:{stream_id}"

    # Simulate a stream whose publisher died: state exists, but nobody
    # will ever answer the resume request.
    await redis.set(sentinel_key, "2", ex=24 * 60 * 60)

    # Resuming must hit the internal 1-second ack timeout.
    with pytest.raises(TimeoutError, match="Timeout waiting for ack"):
        stream = await stream_context.resume_existing_stream(stream_id)
        if stream:
            # Consuming the stream is what actually triggers the timeout.
            consumed = []
            async for chunk in stream:
                consumed.append(chunk)

    # The timeout must not have corrupted the persisted stream state.
    assert await redis.get(sentinel_key) == "2"  # Should still be "2", not "DONE"

    # Resource-cleanup check: no pubsub channel for this stream may remain
    # subscribed after the timeout path ran.
    channels = await redis.execute_command("PUBSUB", "CHANNELS", "test-resumable-stream:rs:*")
    if channels:
        leaked = [ch for ch in channels if stream_id in str(ch)]
        assert len(leaked) == 0, f"Found leaked channels: {leaked}"

    # Clean up
    await redis.delete(sentinel_key)