Skip to content

Commit 8ff62d8

Browse files
authored
tidyup/review (#64)
final bits
1 parent 181cab0 commit 8ff62d8

File tree

12 files changed

+263
-90
lines changed

12 files changed

+263
-90
lines changed

.coverage

-52 KB
Binary file not shown.

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ temp.py
1111
.env
1212
.mypy_cache
1313
.tox
14+
.coverage
1415
release.sh
1516
CLAUDE.md

README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,21 @@ If you wish to avoid excessive throttling or have multiple producers on a stream
305305
```python
306306
from kinesis import Consumer
307307

308+
# One-shot: consume until idle_timeout (default 2s) with no new records
308309
async with Consumer(stream_name="test") as consumer:
309310
async for item in consumer:
310311
print(item)
311-
# Consumer continues to wait for new messages after catching up
312+
313+
# Continuous: wrap in while True to keep consuming across idle gaps
314+
async with Consumer(stream_name="test") as consumer:
315+
while True:
316+
async for item in consumer:
317+
print(item)
312318
```
313319

320+
> **Note**: `async for` ends after `idle_timeout` seconds of queue inactivity (default 2.0s).
321+
> For continuous consumption, wrap the `async for` in a `while True` loop.
322+
314323

315324
Options:
316325

@@ -324,7 +333,7 @@ Options:
324333
| max_queue_size | 10000 | the fetch() task shard will block when queue is at max |
325334
| max_shard_consumers | None | Max number of shards to use. None = all |
326335
| record_limit | 10000 | Number of records to fetch with get_records() |
327-
| sleep_time_no_records | 2 | No of seconds to sleep when caught up |
336+
| sleep_time_no_records | 2 | Seconds to sleep per shard when no new records are returned by `get_records` |
328337
| iterator_type | TRIM_HORIZON | Default shard iterator type for new/unknown shards (i.e. start from the beginning of the stream). Alternatives are "LATEST" (i.e. end of stream) and "AT_TIMESTAMP" (i.e. a particular point in time; requires defining the `timestamp` arg) |
329338
| shard_fetch_rate | 1 | Number of fetches per second (max = 5). 1 is recommended, as it allows multiple consumers without hitting the max limit. |
330339
| checkpointer | MemoryCheckPointer() | Checkpointer to use |
@@ -337,6 +346,7 @@ Options:
337346
| create_stream | False | Creates a Kinesis stream named by the `stream_name` keyword argument. If the stream already exists, this option is ignored |
338347
| create_stream_shards | 1 | Sets the number of shards for the newly created stream. If the stream already exists, this option is ignored |
339348
| describe_timeout | 60 | Timeout in seconds for waiting for stream to become ACTIVE during startup. Increase for slow backends (e.g. LocalStack) |
349+
| idle_timeout | 2.0 | Seconds to wait for new records before ending iteration. Controls how long `async for` blocks on an empty queue before raising `StopAsyncIteration` |
340350
| timestamp | None | Timestamp to start reading the stream from. Used with iterator type "AT_TIMESTAMP" |
341351

342352
## Shard Management

benchmark.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ async def test_producer(
251251
) as consumer:
252252

253253
# Ensure consumer is set up before producing
254-
await consumer.start_consumer(wait_iterations=0)
254+
consumer._start_fetch_task()
255255

256256
# Add small delay to ensure consumer is ready
257257
await asyncio.sleep(1)
@@ -349,7 +349,7 @@ async def run_benchmark(args):
349349

350350
if not args.dry_run:
351351
# Create the stream
352-
async with StreamManager(stream_name, args.shards) as stream:
352+
async with StreamManager(stream_name, args.shards):
353353

354354
for iteration in range(args.iterations):
355355
if args.iterations > 1:
@@ -512,7 +512,7 @@ def main():
512512
def cleanup_handler(signum=None, frame=None):
513513
"""Handle cleanup on exit"""
514514
try:
515-
loop = asyncio.get_running_loop()
515+
asyncio.get_running_loop()
516516
asyncio.create_task(cleanup_all_streams())
517517
except RuntimeError:
518518
# No running loop, create one

kinesis/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class Base:
1717
def __init__(
1818
self,
1919
stream_name: str,
20+
*,
2021
session: Optional[AioSession] = None,
2122
endpoint_url: Optional[str] = None,
2223
region_name: Optional[str] = None,

kinesis/cli/stream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
class _ClientHelper(Base):
2121
"""Minimal Base subclass for lightweight Kinesis API access (describe/list only)."""
2222

23-
def __init__(self, stream_name="", endpoint_url=None, region_name=None):
23+
def __init__(self, stream_name="", *, endpoint_url=None, region_name=None):
2424
super().__init__(
2525
stream_name=stream_name,
2626
endpoint_url=endpoint_url,

kinesis/consumer.py

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import logging
33
from asyncio import TimeoutError
4-
from asyncio.queues import QueueEmpty
54
from datetime import datetime, timezone
65
from typing import Any, AsyncIterator, Dict, Optional
76

@@ -36,6 +35,7 @@ class Consumer(Base):
3635
def __init__(
3736
self,
3837
stream_name: str,
38+
*,
3939
session: Optional[AioSession] = None,
4040
endpoint_url: Optional[str] = None,
4141
region_name: Optional[str] = None,
@@ -55,6 +55,7 @@ def __init__(
5555
create_stream: bool = False,
5656
create_stream_shards: int = 1,
5757
describe_timeout: int = 60,
58+
idle_timeout: float = 2.0,
5859
timestamp: Optional[datetime] = None,
5960
) -> None:
6061

@@ -77,6 +78,8 @@ def __init__(
7778

7879
self.sleep_time_no_records = sleep_time_no_records
7980

81+
self.idle_timeout = idle_timeout
82+
8083
self.max_shard_consumers = max_shard_consumers
8184

8285
self.record_limit = record_limit
@@ -669,26 +672,16 @@ def get_shard_status(self):
669672
"shard_details": shard_details,
670673
}
671674

672-
async def start_consumer(self, wait_iterations=10, wait_sleep=0.25):
673-
674-
# Start task to fetch periodically
675-
675+
def _start_fetch_task(self):
676676
self.fetch_task = asyncio.create_task(self._fetch())
677677

678-
# Wait a while until we have some results
679-
for i in range(0, wait_iterations):
680-
if self.fetch_task and self.queue.qsize() == 0:
681-
await asyncio.sleep(wait_sleep)
682-
683-
log.debug("start_consumer completed.. queue size={}".format(self.queue.qsize()))
684-
685678
async def __anext__(self):
686679

687680
if not self.shards:
688681
await self.get_conn()
689682

690683
if not self.fetch_task:
691-
await self.start_consumer()
684+
self._start_fetch_task()
692685

693686
# Raise exception from Fetch Task to main task otherwise raise exception inside
694687
# Fetch Task will fail silently
@@ -702,23 +695,21 @@ async def __anext__(self):
702695

703696
while True:
704697
try:
705-
item = self.queue.get_nowait()
706-
707-
if item and isinstance(item, dict) and "__CHECKPOINT__" in item:
708-
if self.checkpointer:
709-
await self.checkpointer.checkpoint(
710-
item["__CHECKPOINT__"]["ShardId"],
711-
item["__CHECKPOINT__"]["SequenceNumber"],
712-
)
713-
checkpoint_count += 1
714-
if checkpoint_count >= max_checkpoints:
715-
log.warning(f"Processed {max_checkpoints} checkpoints, stopping iteration")
716-
raise StopAsyncIteration
717-
continue
718-
719-
return item
698+
item = await asyncio.wait_for(self.queue.get(), timeout=self.idle_timeout)
699+
except asyncio.TimeoutError:
700+
log.debug(f"Queue idle for {self.idle_timeout}s, stopping iteration")
701+
raise StopAsyncIteration from None
702+
703+
if item and isinstance(item, dict) and "__CHECKPOINT__" in item:
704+
if self.checkpointer:
705+
await self.checkpointer.checkpoint(
706+
item["__CHECKPOINT__"]["ShardId"],
707+
item["__CHECKPOINT__"]["SequenceNumber"],
708+
)
709+
checkpoint_count += 1
710+
if checkpoint_count >= max_checkpoints:
711+
log.warning(f"Processed {max_checkpoints} checkpoints, stopping iteration")
712+
raise StopAsyncIteration
713+
continue
720714

721-
except QueueEmpty:
722-
log.debug("Queue empty..")
723-
await asyncio.sleep(self.sleep_time_no_records)
724-
raise StopAsyncIteration
715+
return item

kinesis/producer.py

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import math
44
import time
55
from asyncio.queues import QueueEmpty
6+
from collections import deque
67
from typing import Any, Awaitable, Callable, Optional
78

89
from aiobotocore.session import AioSession
@@ -22,6 +23,7 @@ class Producer(Base):
2223
def __init__(
2324
self,
2425
stream_name: str,
26+
*,
2527
session: Optional[AioSession] = None,
2628
endpoint_url: Optional[str] = None,
2729
region_name: Optional[str] = None,
@@ -81,14 +83,14 @@ def __init__(
8183

8284
self._flush_lock = asyncio.Lock()
8385
self._stop_event = asyncio.Event()
84-
self.flush_task = asyncio.create_task(self._flush())
86+
self.flush_task = None
8587
self.after_flush_fun = after_flush_fun
8688

8789
# keep track of these (used by unit test only)
8890
self.throughput_exceeded_count = 0
8991

90-
# overflow buffer
91-
self.overflow = []
92+
# overflow buffer (deque for O(1) popleft in get_batch FIFO)
93+
self.overflow = deque()
9294

9395
self.flush_total_records = 0
9496
self.flush_total_size = 0
@@ -134,8 +136,10 @@ async def put(self, data: Any, partition_key: Optional[str] = None) -> None:
134136

135137
# Raise exception from Flush Task to main task otherwise raise exception inside
136138
# Flush Task will fail silently
137-
if self.flush_task.done():
138-
raise self.flush_task.exception()
139+
if self.flush_task and self.flush_task.done():
140+
exc = self.flush_task.exception()
141+
if exc:
142+
raise exc
139143

140144
if not self.stream_status == self.ACTIVE:
141145
await self.get_conn()
@@ -149,27 +153,29 @@ async def put(self, data: Any, partition_key: Optional[str] = None) -> None:
149153
# Update queue size metric
150154
self.metrics.gauge(MetricType.PRODUCER_QUEUE_SIZE, self.queue.qsize(), {"stream_name": self.stream_name})
151155

156+
async def start(self):
157+
await super().start()
158+
# (Re)start flush infrastructure now that we have a live client.
159+
self._stop_event = asyncio.Event()
160+
self.flush_task = asyncio.create_task(self._flush())
161+
152162
async def close(self):
153163
log.debug(f"Closing Connection.. (stream status:{self.stream_status})")
154-
if not self.stream_status == self.RECONNECT:
155-
# Signal flush task to stop gracefully (don't cancel — let in-progress flush complete)
156-
self._stop_event.set()
157164

158-
if self.flush_task and not self.flush_task.done():
159-
try:
160-
done, _ = await asyncio.wait([self.flush_task], timeout=2.0)
161-
if not done:
162-
log.debug("Flush task did not finish in time, cancelling")
163-
self.flush_task.cancel()
164-
try:
165-
await self.flush_task
166-
except asyncio.CancelledError:
167-
pass
168-
except Exception as e:
169-
log.debug(f"Error awaiting cancelled flush task: {e}")
170-
except Exception as e:
171-
log.debug(f"Error during flush task cleanup: {e}")
165+
# Always stop background flush task, even during reconnect,
166+
# to avoid a dangling task referencing a closed client.
167+
self._stop_event.set()
168+
169+
if self.flush_task and not self.flush_task.done():
170+
# Wait for the flush task to finish — don't cancel it.
171+
# _stop_event ensures the loop exits after the current flush() completes,
172+
# letting any in-flight shielded put_records() finish rather than
173+
# re-queuing items that were already delivered (duplicate prevention).
174+
done, _ = await asyncio.wait([self.flush_task], timeout=10.0)
175+
if not done:
176+
log.warning("Flush task did not finish within 10s, proceeding with close")
172177

178+
if self.stream_status != self.RECONNECT:
173179
# Final flush to send any remaining queued items
174180
await self.flush()
175181

@@ -202,8 +208,13 @@ async def flush(self, _skip_if_locked=False):
202208
async with self._flush_lock:
203209

204210
if self.processor.has_items():
205-
for output in self.processor.get_items():
206-
await self.queue.put(output)
211+
outputs = list(self.processor.get_items())
212+
for output in outputs:
213+
try:
214+
self.queue.put_nowait(output)
215+
except asyncio.QueueFull:
216+
self.overflow.append(output)
217+
log.debug("Queue full during flush, spilled %d items to overflow", 1)
207218

208219
while True:
209220

@@ -311,7 +322,7 @@ async def get_batch(self):
311322
async with self.put_rate_throttle:
312323

313324
if self.overflow:
314-
item = self.overflow.pop()
325+
item = self.overflow.popleft()
315326

316327
else:
317328
try:
@@ -415,8 +426,8 @@ async def _push_kinesis(self, items):
415426
except ClientConnectionError:
416427
await self.get_conn()
417428
except asyncio.CancelledError:
418-
# In-flight put_records continues (shielded), but we can't get the result.
419-
# Re-queue items so the final flush in close() can retry them.
429+
# close() no longer cancels the flush task (it awaits completion),
430+
# but if something else cancels us, re-queue for at-least-once delivery.
420431
log.debug("put_records cancelled, re-queuing %d items to overflow", len(items))
421432
self.overflow.extend(items)
422433
raise

kinesis/testing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ class MockProducer:
188188
def __init__(
189189
self,
190190
stream_name: str,
191+
*,
191192
processor: Optional[Processor] = None,
192193
# Accepted for signature compatibility — ignored
193194
session=None,
@@ -270,6 +271,7 @@ class MockConsumer:
270271
def __init__(
271272
self,
272273
stream_name: str,
274+
*,
273275
processor: Optional[Processor] = None,
274276
checkpointer=None,
275277
iterator_type: str = "TRIM_HORIZON",

tests.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
NetstringAggregator,
1919
NewlineAggregator,
2020
OutputItem,
21-
SimpleAggregator,
2221
)
2322
from kinesis.processors import (
2423
JsonLineProcessor,
@@ -570,7 +569,7 @@ async def test_producer_put_exceed_batch_size(self):
570569

571570
async def test_producer_and_consumer(self):
572571

573-
async with Producer(stream_name=self.stream_name, endpoint_url=ENDPOINT_URL) as producer:
572+
async with Producer(stream_name=self.stream_name, endpoint_url=ENDPOINT_URL):
574573
pass
575574

576575
async with Consumer(stream_name=self.stream_name, endpoint_url=ENDPOINT_URL):
@@ -815,7 +814,7 @@ async def test_producer_and_consumer_consume_with_checkpointer_and_latest(self):
815814
) as consumer:
816815

817816
# Manually start
818-
await consumer.start_consumer()
817+
consumer._start_fetch_task()
819818

820819
await producer.put("test.B")
821820

@@ -961,7 +960,7 @@ async def test_consumer_checkpoint(self):
961960
) as consumer:
962961

963962
# Manually start
964-
await consumer.start_consumer()
963+
consumer._start_fetch_task()
965964

966965
await producer.put("test")
967966

@@ -1018,7 +1017,7 @@ async def test_producer_producer_limit(self):
10181017
iterator_type="LATEST",
10191018
) as consumer:
10201019

1021-
await consumer.start_consumer()
1020+
consumer._start_fetch_task()
10221021

10231022
# Wait a bit just to be sure iterator is gonna get late
10241023
await asyncio.sleep(3)

0 commit comments

Comments
 (0)