Commit ab1a21a

Feat(Stream): Use redis stream

1 parent: 7a911f3

7 files changed: +325, -33 lines

arq/connections.py

Lines changed: 39 additions & 2 deletions

@@ -13,8 +13,16 @@
 from redis.asyncio.sentinel import Sentinel
 from redis.exceptions import RedisError, WatchError

-from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix
+from .constants import (
+    default_queue_name,
+    expires_extra_ms,
+    job_key_prefix,
+    job_message_id_prefix,
+    result_key_prefix,
+    stream_key_suffix,
+)
 from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job
+from .lua_script import publish_job_lua
 from .utils import timestamp_ms, to_ms, to_unix_ms

 logger = logging.getLogger('arq.connections')
@@ -114,6 +122,7 @@ def __init__(
         if pool_or_conn:
             kwargs['connection_pool'] = pool_or_conn
         self.expires_extra_ms = expires_extra_ms
+        self.publish_job_sha = None
         super().__init__(**kwargs)

     async def enqueue_job(
@@ -126,6 +135,7 @@ async def enqueue_job(
         _defer_by: Union[None, int, float, timedelta] = None,
         _expires: Union[None, int, float, timedelta] = None,
         _job_try: Optional[int] = None,
+        _use_stream: bool = False,
         **kwargs: Any,
     ) -> Optional[Job]:
         """
@@ -145,6 +155,7 @@ async def enqueue_job(
         """
         if _queue_name is None:
             _queue_name = self.default_queue_name
+
         job_id = _job_id or uuid4().hex
         job_key = job_key_prefix + job_id
         if _defer_until and _defer_by:
@@ -153,6 +164,9 @@ async def enqueue_job(
         defer_by_ms = to_ms(_defer_by)
         expires_ms = to_ms(_expires)

+        if _use_stream is True and self.publish_job_sha is None:
+            self.publish_job_sha = await self.script_load(publish_job_lua)  # type: ignore[no-untyped-call]
+
         async with self.pipeline(transaction=True) as pipe:
             await pipe.watch(job_key)
             if await pipe.exists(job_key, result_key_prefix + job_id):
@@ -172,14 +186,37 @@ async def enqueue_job(
             job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer)
             pipe.multi()
             pipe.psetex(job_key, expires_ms, job)
-            pipe.zadd(_queue_name, {job_id: score})
+
+            if _use_stream is False:
+                pipe.zadd(_queue_name, {job_id: score})
+            else:
+                stream_key = _queue_name + stream_key_suffix
+                job_message_id_key = job_message_id_prefix + job_id
+
+                pipe.evalsha(
+                    self.publish_job_sha,
+                    2,
+                    # keys
+                    stream_key,
+                    job_message_id_key,
+                    # args
+                    job_id,
+                    str(enqueue_time_ms),
+                    str(expires_ms),
+                )
             try:
                 await pipe.execute()
             except WatchError:
                 # job got enqueued since we checked 'job_exists'
                 return None
         return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer)

+    async def get_stream_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int:
+        if queue_name is None:
+            queue_name = self.default_queue_name
+
+        return await self.xlen(queue_name + stream_key_suffix)
+
     async def _get_job_result(self, key: bytes) -> JobResult:
         job_id = key[len(result_key_prefix) :].decode()
         job = Job(job_id, self, _deserializer=self.job_deserializer)
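
For context, a minimal usage sketch of the new enqueue path, assuming the API introduced by this commit ('my_task' is a hypothetical worker function registered elsewhere):

    import asyncio

    from arq import create_pool
    from arq.connections import RedisSettings


    async def main() -> None:
        redis = await create_pool(RedisSettings())
        # Route the job through the stream instead of the sorted set; the
        # publish_job_lua script is loaded lazily on the first such call.
        job = await redis.enqueue_job('my_task', 42, _use_stream=True)
        if job is not None:
            print(job.job_id)
        # XLEN of '<queue_name>:stream' via the new helper.
        print(await redis.get_stream_size())


    asyncio.run(main())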

arq/constants.py

Lines changed: 3 additions & 0 deletions

@@ -1,9 +1,12 @@
 default_queue_name = 'arq:queue'
 job_key_prefix = 'arq:job:'
 in_progress_key_prefix = 'arq:in-progress:'
+job_message_id_prefix = 'arq:message-id:'
 result_key_prefix = 'arq:result:'
 retry_key_prefix = 'arq:retry:'
 abort_jobs_ss = 'arq:abort'
+stream_key_suffix = ':stream'
+default_consumer_group = 'arq:consumers'
 # age of items in the abort_key sorted set after which they're deleted
 abort_job_max_age = 60
 health_check_key_suffix = ':health-check'
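
To make the key layout concrete, here is how these constants compose into the Redis keys used elsewhere in this commit ('deadbeef' is a hypothetical job id):

    from arq.constants import default_queue_name, job_message_id_prefix, stream_key_suffix

    job_id = 'deadbeef'
    stream_key = default_queue_name + stream_key_suffix   # 'arq:queue:stream'
    message_id_key = job_message_id_prefix + job_id       # 'arq:message-id:deadbeef'
    print(stream_key, message_id_key)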

arq/jobs.py

Lines changed: 35 additions & 7 deletions

@@ -9,8 +9,17 @@

 from redis.asyncio import Redis

-from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix
-from .utils import ms_to_datetime, poll, timestamp_ms
+from .constants import (
+    abort_jobs_ss,
+    default_queue_name,
+    in_progress_key_prefix,
+    job_key_prefix,
+    job_message_id_prefix,
+    result_key_prefix,
+    stream_key_suffix,
+)
+from .lua_script import get_job_from_stream_lua
+from .utils import _list_to_dict, ms_to_datetime, poll, timestamp_ms

 logger = logging.getLogger('arq.jobs')

@@ -105,7 +114,8 @@ async def result(
         async with self._redis.pipeline(transaction=True) as tr:
             tr.get(result_key_prefix + self.job_id)
             tr.zscore(self._queue_name, self.job_id)
-            v, s = await tr.execute()
+            tr.get(job_message_id_prefix + self.job_id)
+            v, s, m = await tr.execute()

         if v:
             info = deserialize_result(v, deserializer=self._deserializer)
@@ -115,7 +125,7 @@
                 raise info.result
             else:
                 raise SerializationError(info.result)
-        elif s is None:
+        elif s is None and m is None:
             raise ResultNotFound(
                 'Not waiting for job result because the job is not in queue. '
                 'Is the worker function configured to keep result?'
@@ -134,8 +144,23 @@ async def info(self) -> Optional[JobDef]:
         if v:
             info = deserialize_job(v, deserializer=self._deserializer)
     if info:
-        s = await self._redis.zscore(self._queue_name, self.job_id)
-        info.score = None if s is None else int(s)
+        async with self._redis.pipeline(transaction=True) as tr:
+            tr.zscore(self._queue_name, self.job_id)
+            tr.eval(
+                get_job_from_stream_lua,
+                2,
+                self._queue_name + stream_key_suffix,
+                job_message_id_prefix + self.job_id,
+            )
+            delayed_score, job_info = await tr.execute()
+
+        if delayed_score:
+            info.score = int(delayed_score)
+        elif job_info:
+            _, job_info_payload = job_info
+            info.score = int(_list_to_dict(job_info_payload)[b'score'])
+        else:
+            info.score = None
     return info

 async def result_info(self) -> Optional[JobResult]:
@@ -157,12 +182,15 @@ async def status(self) -> JobStatus:
             tr.exists(result_key_prefix + self.job_id)
             tr.exists(in_progress_key_prefix + self.job_id)
             tr.zscore(self._queue_name, self.job_id)
-            is_complete, is_in_progress, score = await tr.execute()
+            tr.exists(job_message_id_prefix + self.job_id)
+            is_complete, is_in_progress, score, queued = await tr.execute()

         if is_complete:
             return JobStatus.complete
         elif is_in_progress:
             return JobStatus.in_progress
+        elif queued:
+            return JobStatus.queued
         elif score:
             return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued
         else:
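
A short sketch of how these changes surface to callers, assuming the commit's API and a hypothetical 'my_task' worker function: a stream-enqueued job reports JobStatus.queued while its 'arq:message-id:<job_id>' key exists, and info() falls back to the 'score' field of the stream entry when the job is not in the sorted set.

    import asyncio

    from arq import create_pool
    from arq.connections import RedisSettings


    async def main() -> None:
        redis = await create_pool(RedisSettings())
        job = await redis.enqueue_job('my_task', _use_stream=True)
        if job is not None:
            # JobStatus.queued, provided no worker has picked the job up yet
            print(await job.status())
            info = await job.info()
            if info is not None:
                print(info.score)  # read back from the stream entry


    asyncio.run(main())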

arq/lua_script.py

Lines changed: 24 additions & 0 deletions

@@ -0,0 +1,24 @@
+publish_job_lua = """
+local stream_key = KEYS[1]
+local job_message_id_key = KEYS[2]
+local job_id = ARGV[1]
+local score = ARGV[2]
+local job_message_id_expire_ms = ARGV[3]
+local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score)
+redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms)
+return message_id
+"""
+
+get_job_from_stream_lua = """
+local stream_key = KEYS[1]
+local job_message_id_key = KEYS[2]
+local message_id = redis.call('get', job_message_id_key)
+if message_id == false then
+    return nil
+end
+local job = redis.call('xrange', stream_key, message_id, message_id)
+if job == nil then
+    return nil
+end
+return job[1]
+"""

arq/utils.py

Lines changed: 4 additions & 0 deletions

@@ -148,3 +148,7 @@ def import_string(dotted_path: str) -> Any:
         return getattr(module, class_name)
     except AttributeError as e:
         raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e
+
+
+def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]:
+    return dict(zip(input_list[::2], input_list[1::2], strict=True))
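
_list_to_dict pairs up the flat [field1, value1, field2, value2, ...] reply that XRANGE returns for a stream entry, as used by Job.info() above; a quick illustration with hypothetical values:

    from arq.utils import _list_to_dict

    payload = [b'job_id', b'deadbeef', b'score', b'1700000000000']
    print(_list_to_dict(payload))
    # {b'job_id': b'deadbeef', b'score': b'1700000000000'}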
