Skip to content
6 changes: 5 additions & 1 deletion aioredisqueue/load_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


__all__ = ('load_put', 'load_ack', 'load_fail', 'load_requeue',
'load_get_nowait', 'load_get_nowait_l')
'load_get_nowait', 'load_get_nowait_l', 'load_check_requeue_timestamp')


def _load(name):
Expand Down Expand Up @@ -39,3 +39,7 @@ def load_get_nowait():

def load_get_nowait_l():
return _load('get_nowait_l')


def load_check_requeue_timestamp():
return _load('check_requeue_timestamp')
11 changes: 11 additions & 0 deletions aioredisqueue/lua/check_requeue_timestamp.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
-- KEYS[1] last requeue timestamp
-- ARGV[1] current timestamp
-- ARGV[2] requeue interval

local last_requeue = tonumber(redis.call('get', KEYS[1]))
if last_requeue ~= nil and tonumber(ARGV[1]) - last_requeue <= tonumber(ARGV[2]) then
return {'error'}
else
redis.call('set', KEYS[1], ARGV[1])
return {'ok'}
end
4 changes: 2 additions & 2 deletions aioredisqueue/lua/requeue.lua
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ local kvs = redis.call('HGETALL', KEYS[1])
for i = 1, #kvs, 2 do
job_id = kvs[i]
v = tonumber(kvs[i + 1])
if v < min_val then
if min_val == -1 or v < min_val then
r = redis.call('hdel', KEYS[1], job_id)
table.insert(results_hdel, r)
if r then
Expand All @@ -26,4 +26,4 @@ for i = 1, #kvs, 2 do
end
end

return {results_hdel, results_lpush, results_keys}
return {'ok', results_hdel, results_lpush, results_keys}
37 changes: 36 additions & 1 deletion aioredisqueue/queue.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import timeit

from . import task, load_scripts, exceptions

Expand All @@ -10,8 +11,10 @@ def __init__(self, redis, key_prefix='aioredisqueue_', *,
fetching_fifo_key=None,
payloads_hash_key=None,
ack_hash_key=None,
last_requeue_key=None,
task_class=task.Task,
lua_sha=None,
requeue_interval=10000,
loop=None):

if main_queue_key is None:
Expand All @@ -26,6 +29,9 @@ def __init__(self, redis, key_prefix='aioredisqueue_', *,
if ack_hash_key is None:
ack_hash_key = key_prefix + 'ack'

if last_requeue_key is None:
last_requeue_key = key_prefix + 'last_requeue'

if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -34,13 +40,19 @@ def __init__(self, redis, key_prefix='aioredisqueue_', *,
'fifo': fetching_fifo_key,
'payload': payloads_hash_key,
'ack': ack_hash_key,
'last_requeue': last_requeue_key,
}

self._redis = redis
self._loop = loop
self._task_class = task_class
self._lua_sha = lua_sha if lua_sha is not None else {}
self._locks = {}
self._requeue_interval = requeue_interval
self._requeueing_is_stopped = self._requeue_interval == 0

if self._requeue_interval != 0:
self._loop.create_task(self._requeue_periodically())

async def _load_scripts(self, primary):

Expand Down Expand Up @@ -160,12 +172,35 @@ async def _fail(self, task_id):
args=[task_id],
)

async def _requeue(self, before=None):
async def requeue(self, before=-1):
if 'requeue' not in self._lua_sha:
await self._load_scripts('requeue')

if 'check_requeue_timestamp' not in self._lua_sha:
await self._load_scripts('check_requeue_timestamp')

now = int(timeit.default_timer() * 1000)
results = await self._redis.evalsha(
self._lua_sha['check_requeue_timestamp'],
keys=[self._keys['last_requeue']],
args=[now, self._requeue_interval],
)
if results[0] == b'error':
return results

return await self._redis.evalsha(
self._lua_sha['requeue'],
keys=[self._keys['ack'], self._keys['queue']],
args=[before],
)

def stop_requeueing(self):
self._requeueing_is_stopped = True

async def _requeue_periodically(self):
before = int(timeit.default_timer() * 1000)
while (not self._redis.closed and self._loop.is_running()
and not self._requeueing_is_stopped):
await asyncio.sleep(self._requeue_interval / 1000)
await self.requeue(before)
before += self._requeue_interval
64 changes: 64 additions & 0 deletions test/test_requeueing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import aioredis
import aioredisqueue
import asyncio
import pytest
import timeit

@pytest.mark.asyncio
async def test_requeue_method():
r = await aioredis.create_redis(('localhost', 6379), db=0)
q = aioredisqueue.queue.Queue(r)

await q.put(b'payload')
await q.get()
main_queue_len = await r.llen(q._keys['queue'])
ack_len = await r.hlen(q._keys['ack'])
assert ack_len >= 1
await r.delete(q._keys['last_requeue'])
now = int(timeit.default_timer() * 1000)
results = await q.requeue()
assert results[0] == b'ok'
assert len(results[1]) == ack_len
assert len(results[3]) == ack_len
assert results[2][-1] == main_queue_len + ack_len
assert int(await r.get(q._keys['last_requeue'])) == now

r.close()
await r.wait_closed()


@pytest.mark.asyncio
async def test_too_early():
r = await aioredis.create_redis(('localhost', 6379), db=0)
q = aioredisqueue.queue.Queue(r)

now = int(timeit.default_timer() * 1000)
await r.set(q._keys['last_requeue'], now)
results = await q.requeue()
assert results[0] == b'error'

r.close()
await r.wait_closed()


@pytest.mark.asyncio
async def test_in_background():
r = await aioredis.create_redis(('localhost', 6379), db=0)
requeue_interval = 100
eps = 0.01
q = aioredisqueue.queue.Queue(r, requeue_interval=requeue_interval)

await r.delete(q._keys['last_requeue'])

await q.put(b'payload')
await asyncio.sleep(eps)
task = await q.get()

await asyncio.sleep(requeue_interval / 1000)
assert await r.hget(q._keys['ack'], task.id) is not None

await asyncio.sleep(requeue_interval / 1000)
assert await r.hget(q._keys['ack'], task.id) is None

r.close()
await r.wait_closed()