diff --git a/README.md b/README.md index 45c6d7f..9e54cbb 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,5 @@ performed during each such call. be renamed into `put_nowait`. - `get_multi` and `put_multi` methods, allowing getting and putting multiple items from queue with one call -- method for periodical requeueing of not acknowledged tasks - keeping track of times a task was requeued, dropping too old tasks or tasks with too many retries. diff --git a/aioredisqueue/exceptions.py b/aioredisqueue/exceptions.py index 1dd16ad..4c96294 100644 --- a/aioredisqueue/exceptions.py +++ b/aioredisqueue/exceptions.py @@ -4,3 +4,7 @@ class Full(Exception): class Empty(Exception): pass + + +class Stopped(Exception): + pass \ No newline at end of file diff --git a/aioredisqueue/load_scripts.py b/aioredisqueue/load_scripts.py index 3007045..899924c 100644 --- a/aioredisqueue/load_scripts.py +++ b/aioredisqueue/load_scripts.py @@ -4,7 +4,7 @@ __all__ = ('load_put', 'load_ack', 'load_fail', 'load_requeue', - 'load_get_nowait', 'load_get_nowait_l') + 'load_get_nowait', 'load_get_nowait_l', 'load_check_requeue_timestamp') def _load(name): @@ -39,3 +39,7 @@ def load_get_nowait(): def load_get_nowait_l(): return _load('get_nowait_l') + + +def load_check_requeue_timestamp(): + return _load('check_requeue_timestamp') diff --git a/aioredisqueue/lua/check_requeue_timestamp.lua b/aioredisqueue/lua/check_requeue_timestamp.lua new file mode 100644 index 0000000..c8aa3a3 --- /dev/null +++ b/aioredisqueue/lua/check_requeue_timestamp.lua @@ -0,0 +1,11 @@ +-- KEYS[1] last requeue timestamp +-- ARGV[1] current timestamp +-- ARGV[2] requeue interval + +local last_requeue = tonumber(redis.call('get', KEYS[1])) +if last_requeue ~= nil and tonumber(ARGV[1]) - last_requeue <= tonumber(ARGV[2]) then + return {'error'} +else + redis.call('set', KEYS[1], ARGV[1]) + return {'ok'} +end diff --git a/aioredisqueue/lua/requeue.lua b/aioredisqueue/lua/requeue.lua index 610089c..a4902cd 100644 --- a/aioredisqueue/lua/requeue.lua +++ b/aioredisqueue/lua/requeue.lua @@ -12,7 +12,7 @@ local kvs = redis.call('HGETALL', KEYS[1]) for i = 1, #kvs, 2 do job_id = kvs[i] v = tonumber(kvs[i + 1]) - if v < min_val then + if min_val == -1 or v < min_val then r = redis.call('hdel', KEYS[1], job_id) table.insert(results_hdel, r) if r then @@ -26,4 +26,4 @@ for i = 1, #kvs, 2 do end end -return {results_hdel, results_lpush, results_keys} +return {'ok', results_hdel, results_lpush, results_keys} diff --git a/aioredisqueue/queue.py b/aioredisqueue/queue.py index 6f63f6f..1821c8c 100644 --- a/aioredisqueue/queue.py +++ b/aioredisqueue/queue.py @@ -1,4 +1,5 @@ import asyncio +import timeit from . import task, load_scripts, exceptions @@ -10,8 +11,10 @@ def __init__(self, redis, key_prefix='aioredisqueue_', *, fetching_fifo_key=None, payloads_hash_key=None, ack_hash_key=None, + last_requeue_key=None, task_class=task.Task, lua_sha=None, + requeue_interval=10000, loop=None): if main_queue_key is None: @@ -26,6 +29,9 @@ def __init__(self, redis, key_prefix='aioredisqueue_', *, if ack_hash_key is None: ack_hash_key = key_prefix + 'ack' + if last_requeue_key is None: + last_requeue_key = key_prefix + 'last_requeue' + if loop is None: loop = asyncio.get_event_loop() @@ -34,6 +40,7 @@ def __init__(self, redis, key_prefix='aioredisqueue_', *, 'fifo': fetching_fifo_key, 'payload': payloads_hash_key, 'ack': ack_hash_key, + 'last_requeue': last_requeue_key, } self._redis = redis @@ -41,6 +48,14 @@ def __init__(self, redis, key_prefix='aioredisqueue_', *, self._task_class = task_class self._lua_sha = lua_sha if lua_sha is not None else {} self._locks = {} + self._requeue_interval = requeue_interval + self._is_stopped = False + + if self._requeue_interval != 0: + self._regular_requeue_task = \ + self._loop.create_task(self._requeue_regularly()) + else: + self._regular_requeue_task = None async def _load_scripts(self, primary): @@ -82,6 +97,9 @@ async def _put_lua(self, task_id, task_payload): ) def put(self, task, method='lua'): + if self._is_stopped: + raise exceptions.Stopped + task_id = self._task_class.generate_id() task_payload = self._task_class.format_payload(task) @@ -107,6 +125,9 @@ async def _get_nowait(self, ack_info, script='get_nowait_l', key='fifo'): raise exceptions.Empty(result[0]) async def get_nowait(self): + if self._is_stopped: + raise exceptions.Stopped + ack_info = self._task_class.generate_ack_info() try: @@ -117,6 +138,9 @@ async def get_nowait(self): return await self._get_nowait(ack_info, 'get_nowait', 'queue') async def get(self, retry_interval=1): + if self._is_stopped: + raise exceptions.Stopped + while self._loop.is_running(): await self._redis.brpoplpush(self._keys['queue'], self._keys['fifo'], @@ -145,12 +169,18 @@ async def _ack_lua(self, task_id): ) def _ack(self, task_id, method='multi'): + if self._is_stopped: + raise exceptions.Stopped + if method == 'multi': return self._ack_pipe(task_id) elif method == 'lua': return self._ack_lua(task_id) async def _fail(self, task_id): + if self._is_stopped: + raise exceptions.Stopped + if 'fail' not in self._lua_sha: await self._load_scripts('fail') @@ -160,12 +190,44 @@ async def _fail(self, task_id): args=[task_id], ) - async def _requeue(self, before=None): + async def requeue(self, before=-1): + if self._is_stopped: + raise exceptions.Stopped + if 'requeue' not in self._lua_sha: await self._load_scripts('requeue') + if 'check_requeue_timestamp' not in self._lua_sha: + await self._load_scripts('check_requeue_timestamp') + + now = int(timeit.default_timer() * 1000) + results = await self._redis.evalsha( + self._lua_sha['check_requeue_timestamp'], + keys=[self._keys['last_requeue']], + args=[now, self._requeue_interval], + ) + if results[0] == b'error': + return results + return await self._redis.evalsha( self._lua_sha['requeue'], keys=[self._keys['ack'], self._keys['queue']], args=[before], ) + + def stop(self): + if self._is_stopped: + raise exceptions.Stopped + + if self._regular_requeue_task is not None: + self._regular_requeue_task.cancel() + + self._is_stopped = True + + async def _requeue_regularly(self): + before = int(timeit.default_timer() * 1000) + while (not self._redis.closed and self._loop.is_running() + and not self._is_stopped): + await asyncio.sleep(self._requeue_interval / 1000) + await self.requeue(before) + before += self._requeue_interval diff --git a/test/test_requeueing.py b/test/test_requeueing.py new file mode 100644 index 0000000..ae42571 --- /dev/null +++ b/test/test_requeueing.py @@ -0,0 +1,67 @@ +import timeit +import asyncio +import pytest +import aioredis +import aioredisqueue + +@pytest.mark.asyncio +async def test_requeue_method(): + redis = await aioredis.create_redis(('localhost', 6379), db=0) + queue = aioredisqueue.queue.Queue(redis) + + await queue.put(b'payload') + await queue.get() + main_queue_len = await redis.llen(queue._keys['queue']) + ack_len = await redis.hlen(queue._keys['ack']) + assert ack_len >= 1 + await redis.delete(queue._keys['last_requeue']) + now = int(timeit.default_timer() * 1000) + results = await queue.requeue() + assert results[0] == b'ok' + assert len(results[1]) == ack_len + assert len(results[3]) == ack_len + assert results[2][-1] == main_queue_len + ack_len + assert int(await redis.get(queue._keys['last_requeue'])) == now + + queue.stop() + redis.close() + await redis.wait_closed() + + +@pytest.mark.asyncio +async def test_too_early(): + redis = await aioredis.create_redis(('localhost', 6379), db=0) + queue = aioredisqueue.queue.Queue(redis) + + now = int(timeit.default_timer() * 1000) + await redis.set(queue._keys['last_requeue'], now) + results = await queue.requeue() + assert results[0] == b'error' + + queue.stop() + redis.close() + await redis.wait_closed() + + +@pytest.mark.asyncio +async def test_in_background(): + redis = await aioredis.create_redis(('localhost', 6379), db=0) + requeue_interval = 100 + eps = 0.01 + queue = aioredisqueue.queue.Queue(redis, requeue_interval=requeue_interval) + + await redis.delete(queue._keys['last_requeue']) + + await queue.put(b'payload') + await asyncio.sleep(eps) + task = await queue.get() + + await asyncio.sleep(requeue_interval / 1000) + assert await redis.hget(queue._keys['ack'], task.id) is not None + + await asyncio.sleep(requeue_interval / 1000) + assert await redis.hget(queue._keys['ack'], task.id) is None + + queue.stop() + redis.close() + await redis.wait_closed() diff --git a/test/test_task_usage.py b/test/test_task_usage.py index 8e82abc..e7f80b5 100644 --- a/test/test_task_usage.py +++ b/test/test_task_usage.py @@ -1,5 +1,5 @@ -import pytest import asyncio +import pytest import aioredis import aioredisqueue @@ -14,6 +14,10 @@ async def test_basic_put_get_and_get(): result = await queue.get() assert result.payload == message + queue.stop() + r.close() + await r.wait_closed() + async def get_ack(queue): task = await queue.get() @@ -43,6 +47,11 @@ async def test_ack(): await get_ack(queue) + queue.stop() + r.close() + await r.wait_closed() + + @pytest.mark.asyncio async def test_fail_ack(): r = await aioredis.create_redis(('localhost', 6379), db=0) @@ -51,3 +60,7 @@ async def test_fail_ack(): message = b'payload' await queue.put(message) await get_fail_ack(queue) + + queue.stop() + r.close() + await r.wait_closed()