Skip to content

Commit 234d9ff

Browse files
authored
Merge pull request #1 from lilac/fix/connection-leak
fix(runtime): redis connection leak
2 parents 6fc86de + 8b85e39 commit 234d9ff

File tree

2 files changed

+204
-6
lines changed

2 files changed

+204
-6
lines changed

resumable_stream/runtime.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,12 @@ async def stream_generator():
500500
message_handler_task.cancel()
501501
await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}")
502502
raise
503+
finally:
504+
if message_handler_task:
505+
message_handler_task.cancel()
506+
await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}")
507+
await pubsub.aclose()
508+
503509

504510
stream = stream_generator()
505511
if not start:
@@ -515,6 +521,7 @@ async def stream_generator():
515521
if message_handler_task:
516522
message_handler_task.cancel()
517523
await pubsub.unsubscribe(f"{ctx.key_prefix}:request:{stream_id}")
524+
await pubsub.aclose()
518525
raise
519526

520527

@@ -560,14 +567,50 @@ async def stream_generator():
560567
"""
561568
try:
562569
debug_log("STARTING STREAM", stream_id, listener_id)
563-
start = asyncio.get_event_loop().time()
570+
first_message_received = False
571+
572+
async def get_first_message():
573+
"""Get the first actual message from pubsub"""
574+
async for message in pubsub.listen():
575+
if message["type"] == "message":
576+
return message
577+
578+
# Race the first message against a timeout
579+
get_message_task = asyncio.create_task(get_first_message())
564580
timeout_task = asyncio.create_task(asyncio.sleep(1.0))
565-
581+
566582
try:
583+
done, pending = await asyncio.wait(
584+
[get_message_task, timeout_task],
585+
return_when=asyncio.FIRST_COMPLETED
586+
)
587+
588+
# Cancel pending tasks
589+
for task in pending:
590+
task.cancel()
591+
592+
if timeout_task in done:
593+
# Timeout occurred
594+
raise TimeoutError("Timeout waiting for ack")
595+
596+
# We got the first message
597+
first_message = get_message_task.result()
598+
if first_message is None:
599+
raise TimeoutError("Timeout waiting for ack")
600+
601+
debug_log("Received message", first_message["data"])
602+
603+
if first_message["data"] == DONE_MESSAGE:
604+
await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}")
605+
return
606+
607+
yield first_message["data"]
608+
first_message_received = True
609+
610+
# Continue listening for remaining messages (no timeout needed)
567611
async for message in pubsub.listen():
568612
if message["type"] == "message":
569613
debug_log("Received message", message["data"])
570-
timeout_task.cancel()
571614

572615
if message["data"] == DONE_MESSAGE:
573616
await pubsub.unsubscribe(
@@ -576,17 +619,23 @@ async def stream_generator():
576619
return
577620

578621
yield message["data"]
622+
579623
except asyncio.CancelledError:
580624
val = await ctx.redis.get(f"{ctx.key_prefix}:sentinel:{stream_id}")
581625
if val == DONE_VALUE:
582626
return
583-
if asyncio.get_event_loop().time() - start > 1.0:
627+
if not first_message_received:
584628
raise TimeoutError("Timeout waiting for ack")
629+
raise
585630
finally:
586631
await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}")
632+
587633
except Exception as e:
588634
debug_log("Error in resume_stream", e)
589635
raise
636+
finally:
637+
await pubsub.aclose()
638+
590639

591640
# Start the stream and send the request
592641
stream = stream_generator()
@@ -602,7 +651,7 @@ async def stream_generator():
602651

603652
return stream
604653
except Exception:
605-
await pubsub.unsubscribe(f"{ctx.key_prefix}:chunk:{listener_id}")
654+
await pubsub.aclose()
606655
raise
607656

608657

@@ -663,7 +712,7 @@ async def incr_or_done(publisher: Redis, key: str) -> Union[str, int]:
663712
return await publisher.incr(key)
664713
except Exception as reason:
665714
error_string = str(reason)
666-
if "ERR value is not an integer or out of range" in error_string:
715+
if "value is not an integer or out of range" in error_string:
667716
return DONE_VALUE
668717
raise
669718

tests/test_runtime.py

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
create_resumable_stream_context,
77
ResumableStreamContext,
88
)
9+
from resumable_stream.runtime import incr_or_done, DONE_VALUE
910
from typing import AsyncGenerator, List, Any
1011

1112

@@ -30,6 +31,115 @@ async def async_generator(items: List[str]) -> AsyncGenerator[str, None]:
3031
yield item
3132

3233

34+
@pytest.mark.asyncio
async def test_incr_or_done_new_key(redis: Redis) -> None:
    """A key that does not exist yet should be created and incremented to 1."""
    key = "test-incr-new"
    # Guarantee a clean slate before the call under test.
    await redis.delete(key)
    assert await incr_or_done(redis, key) == 1
    # The counter must have been persisted by the increment.
    assert await redis.get(key) == "1"
    # Tidy up.
    await redis.delete(key)
51+
52+
53+
@pytest.mark.asyncio
async def test_incr_or_done_existing_integer(redis: Redis) -> None:
    """An existing integer-valued key is incremented normally."""
    key = "test-incr-existing"
    # Seed the counter with a known integer value.
    await redis.set(key, "5")
    # First increment: 5 -> 6, both in the return value and in Redis.
    assert await incr_or_done(redis, key) == 6
    assert await redis.get(key) == "6"
    # A second increment keeps counting: 6 -> 7.
    assert await incr_or_done(redis, key) == 7
    # Tidy up.
    await redis.delete(key)
await redis.delete(key)
74+
75+
76+
@pytest.mark.asyncio
async def test_incr_or_done_with_done_value(redis: Redis) -> None:
    """A key already holding DONE_VALUE is returned as-is, not incremented."""
    key = "test-incr-done"
    # Mark the counter as finished up front.
    await redis.set(key, DONE_VALUE)
    # The helper must surface DONE_VALUE instead of raising.
    assert await incr_or_done(redis, key) == DONE_VALUE
    # The stored value must be left untouched.
    assert await redis.get(key) == DONE_VALUE
    # Tidy up.
    await redis.delete(key)
await redis.delete(key)
93+
94+
95+
@pytest.mark.asyncio
async def test_incr_or_done_with_non_integer_string(redis: Redis) -> None:
    """Any non-integer string value makes incr_or_done report DONE_VALUE."""
    key = "test-incr-string"
    # Store something INCR cannot handle.
    await redis.set(key, "not-a-number")
    assert await incr_or_done(redis, key) == DONE_VALUE
    # The original string must survive the failed increment.
    assert await redis.get(key) == "not-a-number"
    # Tidy up.
    await redis.delete(key)
await redis.delete(key)
112+
113+
114+
@pytest.mark.asyncio
async def test_incr_or_done_multiple_increments(redis: Redis) -> None:
    """Repeated calls count 1, 2, 3, ... until the key is marked DONE."""
    key = "test-incr-multiple"
    # Make sure no stale counter is left over from a previous run.
    await redis.delete(key)
    # Three consecutive increments yield 1, 2 and 3 in order.
    for expected in (1, 2, 3):
        assert await incr_or_done(redis, key) == expected
    # Once the key is flagged DONE the helper stops counting.
    await redis.set(key, DONE_VALUE)
    assert await incr_or_done(redis, key) == DONE_VALUE
    # Tidy up.
    await redis.delete(key)
await redis.delete(key)
141+
142+
33143
@pytest.mark.asyncio
34144
@pytest.mark.timeout(1)
35145
async def test_create_new_stream(stream_context: ResumableStreamContext) -> None:
@@ -257,3 +367,42 @@ async def test_resume_existing_stream_with_start(
257367
received_chunks.append(chunk)
258368

259369
assert "".join(received_chunks) == "".join(test_data)
370+
371+
372+
@pytest.mark.asyncio
@pytest.mark.timeout(5)
async def test_timeout_and_connection_closure(stream_context: ResumableStreamContext, redis: Redis) -> None:
    """Resuming a stream with no live publisher times out without leaking pubsub channels."""
    stream_id = "test-timeout-stream"
    sentinel_key = f"test-resumable-stream:rs:sentinel:{stream_id}"

    # Fabricate a stream whose publisher has died: the sentinel exists,
    # but nobody will ever answer on the request channel.
    await redis.set(sentinel_key, "2", ex=24*60*60)

    # resume_stream waits at most 1 second for an ack, so this must raise.
    with pytest.raises(TimeoutError, match="Timeout waiting for ack"):
        resumed_stream = await stream_context.resume_existing_stream(stream_id)
        if resumed_stream:
            # Draining the stream is what actually triggers the timeout.
            chunks = []
            async for chunk in resumed_stream:
                chunks.append(chunk)

    # The timeout must not have corrupted the persisted stream state:
    # the sentinel should still read "2", not "DONE".
    assert await redis.get(sentinel_key) == "2"

    # Resource-cleanup check: no pubsub channel belonging to this
    # stream may remain subscribed after the failed resume.
    pubsub_channels = await redis.execute_command("PUBSUB", "CHANNELS", "test-resumable-stream:rs:*")
    if pubsub_channels:
        timeout_related_channels = [ch for ch in pubsub_channels if stream_id in str(ch)]
        assert len(timeout_related_channels) == 0, f"Found leaked channels: {timeout_related_channels}"

    # Tidy up.
    await redis.delete(sentinel_key)
await redis.delete(f"test-resumable-stream:rs:sentinel:{stream_id}")

0 commit comments

Comments
 (0)