diff --git a/arq/connections.py b/arq/connections.py index c1058890..ab482b66 100644 --- a/arq/connections.py +++ b/arq/connections.py @@ -13,8 +13,16 @@ from redis.asyncio.sentinel import Sentinel from redis.exceptions import RedisError, WatchError -from .constants import default_queue_name, expires_extra_ms, job_key_prefix, result_key_prefix +from .constants import ( + default_queue_name, + expires_extra_ms, + job_key_prefix, + job_message_id_prefix, + result_key_prefix, + stream_key_suffix, +) from .jobs import Deserializer, Job, JobDef, JobResult, Serializer, deserialize_job, serialize_job +from .lua_script import publish_job_lua from .utils import timestamp_ms, to_ms, to_unix_ms logger = logging.getLogger('arq.connections') @@ -114,6 +122,7 @@ def __init__( if pool_or_conn: kwargs['connection_pool'] = pool_or_conn self.expires_extra_ms = expires_extra_ms + self.publish_job_sha = None super().__init__(**kwargs) async def enqueue_job( @@ -126,6 +135,7 @@ async def enqueue_job( _defer_by: Union[None, int, float, timedelta] = None, _expires: Union[None, int, float, timedelta] = None, _job_try: Optional[int] = None, + _use_stream: bool = False, **kwargs: Any, ) -> Optional[Job]: """ @@ -145,6 +155,7 @@ async def enqueue_job( """ if _queue_name is None: _queue_name = self.default_queue_name + job_id = _job_id or uuid4().hex job_key = job_key_prefix + job_id if _defer_until and _defer_by: @@ -153,6 +164,9 @@ async def enqueue_job( defer_by_ms = to_ms(_defer_by) expires_ms = to_ms(_expires) + if _use_stream is True and self.publish_job_sha is None: + self.publish_job_sha = await self.script_load(publish_job_lua) # type: ignore[no-untyped-call] + async with self.pipeline(transaction=True) as pipe: await pipe.watch(job_key) if await pipe.exists(job_key, result_key_prefix + job_id): @@ -172,7 +186,24 @@ async def enqueue_job( job = serialize_job(function, args, kwargs, _job_try, enqueue_time_ms, serializer=self.job_serializer) pipe.multi() pipe.psetex(job_key, expires_ms, job) - pipe.zadd(_queue_name, {job_id: score}) + + if _use_stream is False: + pipe.zadd(_queue_name, {job_id: score}) + else: + stream_key = _queue_name + stream_key_suffix + job_message_id_key = job_message_id_prefix + job_id + + pipe.evalsha( + self.publish_job_sha, + 2, + # keys + stream_key, + job_message_id_key, + # args + job_id, + str(enqueue_time_ms), + str(expires_ms), + ) try: await pipe.execute() except WatchError: @@ -180,6 +211,12 @@ async def enqueue_job( return None return Job(job_id, redis=self, _queue_name=_queue_name, _deserializer=self.job_deserializer) + async def get_stream_size(self, queue_name: str | None = None, include_delayed_tasks: bool = True) -> int: + if queue_name is None: + queue_name = self.default_queue_name + + return await self.xlen(queue_name + stream_key_suffix) + async def _get_job_result(self, key: bytes) -> JobResult: job_id = key[len(result_key_prefix) :].decode() job = Job(job_id, self, _deserializer=self.job_deserializer) diff --git a/arq/constants.py b/arq/constants.py index 84c009aa..a4ed1cf3 100644 --- a/arq/constants.py +++ b/arq/constants.py @@ -1,9 +1,12 @@ default_queue_name = 'arq:queue' job_key_prefix = 'arq:job:' in_progress_key_prefix = 'arq:in-progress:' +job_message_id_prefix = 'arq:message-id:' result_key_prefix = 'arq:result:' retry_key_prefix = 'arq:retry:' abort_jobs_ss = 'arq:abort' +stream_key_suffix = ':stream' +default_consumer_group = 'arq:consumers' # age of items in the abort_key sorted set after which they're deleted abort_job_max_age = 60 
health_check_key_suffix = ':health-check' diff --git a/arq/jobs.py b/arq/jobs.py index 15b7231e..fc853bbe 100644 --- a/arq/jobs.py +++ b/arq/jobs.py @@ -9,8 +9,17 @@ from redis.asyncio import Redis -from .constants import abort_jobs_ss, default_queue_name, in_progress_key_prefix, job_key_prefix, result_key_prefix -from .utils import ms_to_datetime, poll, timestamp_ms +from .constants import ( + abort_jobs_ss, + default_queue_name, + in_progress_key_prefix, + job_key_prefix, + job_message_id_prefix, + result_key_prefix, + stream_key_suffix, +) +from .lua_script import get_job_from_stream_lua +from .utils import _list_to_dict, ms_to_datetime, poll, timestamp_ms logger = logging.getLogger('arq.jobs') @@ -105,7 +114,8 @@ async def result( async with self._redis.pipeline(transaction=True) as tr: tr.get(result_key_prefix + self.job_id) tr.zscore(self._queue_name, self.job_id) - v, s = await tr.execute() + tr.get(job_message_id_prefix + self.job_id) + v, s, m = await tr.execute() if v: info = deserialize_result(v, deserializer=self._deserializer) @@ -115,7 +125,7 @@ async def result( raise info.result else: raise SerializationError(info.result) - elif s is None: + elif s is None and m is None: raise ResultNotFound( 'Not waiting for job result because the job is not in queue. ' 'Is the worker function configured to keep result?' @@ -134,8 +144,23 @@ async def info(self) -> Optional[JobDef]: if v: info = deserialize_job(v, deserializer=self._deserializer) if info: - s = await self._redis.zscore(self._queue_name, self.job_id) - info.score = None if s is None else int(s) + async with self._redis.pipeline(transaction=True) as tr: + tr.zscore(self._queue_name, self.job_id) + tr.eval( + get_job_from_stream_lua, + 2, + self._queue_name + stream_key_suffix, + job_message_id_prefix + self.job_id, + ) + delayed_score, job_info = await tr.execute() + + if delayed_score: + info.score = int(delayed_score) + elif job_info: + _, job_info_payload = job_info + info.score = int(_list_to_dict(job_info_payload)[b'score']) + else: + info.score = None return info async def result_info(self) -> Optional[JobResult]: @@ -157,12 +182,15 @@ async def status(self) -> JobStatus: tr.exists(result_key_prefix + self.job_id) tr.exists(in_progress_key_prefix + self.job_id) tr.zscore(self._queue_name, self.job_id) - is_complete, is_in_progress, score = await tr.execute() + tr.exists(job_message_id_prefix + self.job_id) + is_complete, is_in_progress, score, queued = await tr.execute() if is_complete: return JobStatus.complete elif is_in_progress: return JobStatus.in_progress + elif queued: + return JobStatus.queued elif score: return JobStatus.deferred if score > timestamp_ms() else JobStatus.queued else: diff --git a/arq/lua_script.py b/arq/lua_script.py new file mode 100644 index 00000000..f7ec24e0 --- /dev/null +++ b/arq/lua_script.py @@ -0,0 +1,24 @@ +publish_job_lua = """ +local stream_key = KEYS[1] +local job_message_id_key = KEYS[2] +local job_id = ARGV[1] +local score = ARGV[2] +local job_message_id_expire_ms = ARGV[3] +local message_id = redis.call('xadd', stream_key, '*', 'job_id', job_id, 'score', score) +redis.call('set', job_message_id_key, message_id, 'px', job_message_id_expire_ms) +return message_id +""" + +get_job_from_stream_lua = """ +local stream_key = KEYS[1] +local job_message_id_key = KEYS[2] +local message_id = redis.call('get', job_message_id_key) +if message_id == false then + return nil +end +local job = redis.call('xrange', stream_key, message_id, message_id) +if job == nil then + return nil +end +return 
job[1] +""" diff --git a/arq/utils.py b/arq/utils.py index 2cbde056..f7ad1aeb 100644 --- a/arq/utils.py +++ b/arq/utils.py @@ -148,3 +148,7 @@ def import_string(dotted_path: str) -> Any: return getattr(module, class_name) except AttributeError as e: raise ImportError(f'Module "{module_path}" does not define a "{class_name}" attribute') from e + + +def _list_to_dict(input_list: list[Any]) -> dict[Any, Any]: + return dict(zip(input_list[::2], input_list[1::2], strict=True)) diff --git a/arq/worker.py b/arq/worker.py index f1e613c9..29a174ad 100644 --- a/arq/worker.py +++ b/arq/worker.py @@ -3,12 +3,14 @@ import inspect import logging import signal +from contextlib import suppress from dataclasses import dataclass from datetime import datetime, timedelta, timezone from functools import partial from signal import Signals from time import time from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union, cast +from uuid import uuid4 from redis.exceptions import ResponseError, WatchError @@ -19,15 +21,19 @@ from .constants import ( abort_job_max_age, abort_jobs_ss, + default_consumer_group, default_queue_name, expires_extra_ms, health_check_key_suffix, in_progress_key_prefix, job_key_prefix, + job_message_id_prefix, keep_cronjob_progress, result_key_prefix, retry_key_prefix, + stream_key_suffix, ) +from .lua_script import publish_job_lua from .utils import ( args_to_string, import_string, @@ -57,6 +63,13 @@ class Function: max_tries: Optional[int] +@dataclass +class JobMetaInfo: + job_id: str + message_id: str | None = None + score: int | None = None + + def func( coroutine: Union[str, Function, 'WorkerCoroutine'], *, @@ -188,6 +201,9 @@ def __init__( functions: Sequence[Union[Function, 'WorkerCoroutine']] = (), *, queue_name: Optional[str] = default_queue_name, + use_stream: bool = False, + consumer_group_name: str = default_consumer_group, + worker_id: Optional[str] = None, cron_jobs: Optional[Sequence[CronJob]] = None, redis_settings: Optional[RedisSettings] = None, redis_pool: Optional[ArqRedis] = None, @@ -204,6 +220,7 @@ def __init__( keep_result: 'SecondsTimedelta' = 3600, keep_result_forever: bool = False, poll_delay: 'SecondsTimedelta' = 0.5, + stream_block: 'SecondsTimedelta' = 0.5, queue_read_limit: Optional[int] = None, max_tries: int = 5, health_check_interval: 'SecondsTimedelta' = 3600, @@ -224,7 +241,11 @@ def __init__( queue_name = redis_pool.default_queue_name else: raise ValueError('If queue_name is absent, redis_pool must be present.') + self.queue_name = queue_name + self.use_stream = use_stream + self.consumer_group_name = consumer_group_name + self.worker_id = worker_id or str(uuid4().hex) self.cron_jobs: List[CronJob] = [] if cron_jobs is not None: if not all(isinstance(cj, CronJob) for cj in cron_jobs): @@ -248,6 +269,7 @@ def __init__( self.keep_result_s = to_seconds(keep_result) self.keep_result_forever = keep_result_forever self.poll_delay_s = to_seconds(poll_delay) + self.stream_block_s = to_seconds(stream_block) self.queue_read_limit = queue_read_limit or max(max_jobs * 5, 100) self._queue_read_offset = 0 self.max_tries = max_tries @@ -295,6 +317,7 @@ def __init__( self.job_deserializer = job_deserializer self.expires_extra_ms = expires_extra_ms self.log_results = log_results + self.publish_job_sha = None # default to system timezone self.timezone = datetime.now().astimezone().tzinfo if timezone is None else timezone @@ -351,23 +374,128 @@ async def main(self) -> None: expires_extra_ms=self.expires_extra_ms, ) + if 
self.use_stream: + self.publish_job_sha = await self.pool.script_load(publish_job_lua) # type: ignore[no-untyped-call] + logger.info('Starting worker for %d functions: %s', len(self.functions), ', '.join(self.functions)) await log_redis_info(self.pool, logger.info) self.ctx['redis'] = self.pool if self.on_startup: await self.on_startup(self.ctx) + if self.use_stream: + await self.create_consumer_group() + await self.run_stream_reader() + else: + await self._run_pool_iteration() + + async def run_stream_reader(self) -> None: + while True: + await self._read_stream_iteration() + + if await self._is_burst_broken() is True: + return None + + async def create_consumer_group(self) -> None: + with suppress(ResponseError): + await self.pool.xgroup_create( + name=self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + id='0', + mkstream=True, + ) + + async def _read_stream_iteration(self) -> None: + """ + Get ids of pending jobs from the stream and start those jobs, remove + any finished tasks from self.tasks. + """ + count = self.max_jobs - self.job_counter + if self.burst and self.max_burst_jobs >= 0: + burst_jobs_remaining = self.max_burst_jobs - self._jobs_started() + if burst_jobs_remaining < 1: + return + count = min(burst_jobs_remaining, count) + if self.allow_pick_jobs: + if self.job_counter < self.max_jobs: + stream_msgs = await self._get_idle_tasks(count) + msgs_count = sum([len(msgs) for _, msgs in stream_msgs]) + + count -= msgs_count + + if count > 0: + stream_msgs.extend( + await self.pool.xreadgroup( + groupname=self.consumer_group_name, + consumername=self.worker_id, + streams={self.queue_name + stream_key_suffix: '>'}, + count=count, + block=int(max(self.stream_block_s * 1000, 1)), + ) + ) + + jobs = [] + + for _, msgs in stream_msgs: + for msg_id, job in msgs: + jobs.append( + JobMetaInfo( + message_id=msg_id.decode(), + job_id=job[b'job_id'].decode(), + score=int(job[b'score']), + ) + ) + + await self.start_jobs(jobs) + + if self.allow_abort_jobs: + await self._cancel_aborted_jobs() + + for job_id, t in list(self.tasks.items()): + if t.done(): + del self.tasks[job_id] + # required to make sure errors in run_job get propagated + t.result() + + await self.heart_beat() + + async def _get_idle_tasks(self, count: int) -> list[tuple[bytes, list[Any]]]: + resp = await self.pool.xautoclaim( + self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + consumername=self.worker_id, + min_idle_time=int(self.in_progress_timeout_s * 1000), + count=count, + ) + + if not resp: + return [] + + _, msgs, __ = resp + if not msgs: + return [] + + # cast to the same format as the xreadgroup response + return [((self.queue_name + stream_key_suffix).encode(), msgs)] + + async def _run_pool_iteration(self) -> None: async for _ in poll(self.poll_delay_s): await self._poll_iteration() - if self.burst: - if 0 <= self.max_burst_jobs <= self._jobs_started(): - await asyncio.gather(*self.tasks.values()) - return None - queued_jobs = await self.pool.zcard(self.queue_name) - if queued_jobs == 0: - await asyncio.gather(*self.tasks.values()) - return None + if await self._is_burst_broken() is True: + return None + + async def _is_burst_broken(self) -> bool: + if self.burst: + if 0 <= self.max_burst_jobs <= self._jobs_started(): + await asyncio.gather(*self.tasks.values()) + return True + queued_jobs = await self.pool.zcard(self.queue_name) + if queued_jobs == 0: + await asyncio.gather(*self.tasks.values()) + return True + + return False async def _poll_iteration(self) -> 
None: """ @@ -387,7 +515,8 @@ async def _poll_iteration(self) -> None: self.queue_name, min=float('-inf'), start=self._queue_read_offset, num=count, max=now ) - await self.start_jobs(job_ids) + jobs = [JobMetaInfo(job_id=_id.decode()) for _id in job_ids] + await self.start_jobs(jobs) if self.allow_abort_jobs: await self._cancel_aborted_jobs() @@ -428,11 +557,13 @@ def _release_sem_dec_counter_on_complete(self) -> None: self.job_counter = self.job_counter - 1 self.sem.release() - async def start_jobs(self, job_ids: List[bytes]) -> None: + async def start_jobs(self, jobs: list[JobMetaInfo]) -> None: """ For each job id, get the job definition, check it's not running and start it in a task """ - for job_id_b in job_ids: + for job in jobs: + job_id = job.job_id + await self.sem.acquire() if self.job_counter >= self.max_jobs: @@ -441,12 +572,14 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: self.job_counter = self.job_counter + 1 - job_id = job_id_b.decode() in_progress_key = in_progress_key_prefix + job_id async with self.pool.pipeline(transaction=True) as pipe: await pipe.watch(in_progress_key) ongoing_exists = await pipe.exists(in_progress_key) - score = await pipe.zscore(self.queue_name, job_id) + if self.use_stream: + score = job.score + else: + score = await pipe.zscore(self.queue_name, job_id) if ongoing_exists or not score or score > timestamp_ms(): # job already started elsewhere, or already finished and removed from queue # if score > ts_now, @@ -466,11 +599,13 @@ async def start_jobs(self, job_ids: List[bytes]) -> None: self.sem.release() logger.debug('multi-exec error, job %s already started elsewhere', job_id) else: - t = self.loop.create_task(self.run_job(job_id, int(score))) + t = self.loop.create_task(self.run_job(job, score)) t.add_done_callback(lambda _: self._release_sem_dec_counter_on_complete()) self.tasks[job_id] = t - async def run_job(self, job_id: str, score: int) -> None: # noqa: C901 + # sentry_sdk/integrations/arq.py patch_run_job needs to be updated + async def run_job(self, job_info: JobMetaInfo, score: int) -> None: # noqa: C901 + job_id = job_info.job_id start_ms = timestamp_ms() async with self.pool.pipeline(transaction=True) as pipe: pipe.get(job_key_prefix + job_id) @@ -504,7 +639,7 @@ async def job_failed(exc: BaseException) -> None: queue_name=self.queue_name, job_id=job_id, ) - await asyncio.shield(self.finish_failed_job(job_id, result_data_)) + await asyncio.shield(self.finish_failed_job(job_id, job_info.message_id, result_data_)) if not v: logger.warning('job %s expired', job_id) @@ -561,7 +696,7 @@ async def job_failed(exc: BaseException) -> None: job_id=job_id, serializer=self.job_serializer, ) - return await asyncio.shield(self.finish_failed_job(job_id, result_data)) + return await asyncio.shield(self.finish_failed_job(job_id, job_info.message_id, result_data)) result = no_result exc_extra = None @@ -662,6 +797,8 @@ async def job_failed(exc: BaseException) -> None: await asyncio.shield( self.finish_job( job_id, + job_info.message_id, + score, finish, result_data, result_timeout_s, @@ -677,6 +814,8 @@ async def job_failed(exc: BaseException) -> None: async def finish_job( self, job_id: str, + message_id: str | None, + score: int, finish: bool, result_data: Optional[bytes], result_timeout_s: Optional[float], @@ -687,33 +826,80 @@ async def finish_job( async with self.pool.pipeline(transaction=True) as tr: delete_keys = [] in_progress_key = in_progress_key_prefix + job_id + stream_key = self.queue_name + stream_key_suffix if keep_in_progress 
is None: delete_keys += [in_progress_key] else: tr.pexpire(in_progress_key, to_ms(keep_in_progress)) + if self.use_stream: + tr.xack( + stream_key, + self.consumer_group_name, + message_id, + ) + tr.xdel(stream_key, message_id) + if finish: if result_data: expire = None if keep_result_forever else result_timeout_s tr.set(result_key_prefix + job_id, result_data, px=to_ms(expire)) - delete_keys += [retry_key_prefix + job_id, job_key_prefix + job_id] + delete_keys += [retry_key_prefix + job_id, job_key_prefix + job_id, job_message_id_prefix + job_id] tr.zrem(abort_jobs_ss, job_id) tr.zrem(self.queue_name, job_id) elif incr_score: - tr.zincrby(self.queue_name, incr_score, job_id) + if self.use_stream: + job_message_id_expire = score - timestamp_ms() + self.expires_extra_ms + tr.evalsha( + self.publish_job_sha, + 2, + # keys + stream_key, + job_message_id_prefix + job_id, + # args + job_id, + str(score + incr_score), + str(job_message_id_expire), + ) + else: + tr.zincrby(self.queue_name, incr_score, job_id) + else: + if self.use_stream: + job_message_id_expire = score - timestamp_ms() + self.expires_extra_ms + tr.evalsha( + self.publish_job_sha, + 2, + # keys + stream_key, + job_message_id_prefix + job_id, + # args + job_id, + str(score), + str(job_message_id_expire), + ) if delete_keys: tr.delete(*delete_keys) await tr.execute() - async def finish_failed_job(self, job_id: str, result_data: Optional[bytes]) -> None: + async def finish_failed_job(self, job_id: str, message_id: str | None, result_data: Optional[bytes]) -> None: + stream_key = self.queue_name + stream_key_suffix async with self.pool.pipeline(transaction=True) as tr: tr.delete( retry_key_prefix + job_id, in_progress_key_prefix + job_id, job_key_prefix + job_id, + job_message_id_prefix + job_id, ) tr.zrem(abort_jobs_ss, job_id) - tr.zrem(self.queue_name, job_id) + if self.use_stream: + tr.xack( + stream_key, + self.consumer_group_name, + message_id, + ) + tr.xdel(stream_key, message_id) + else: + tr.zrem(self.queue_name, job_id) # result_data would only be None if serializing the result fails keep_result = self.keep_result_forever or self.keep_result_s > 0 if result_data is not None and keep_result: # pragma: no branch @@ -869,6 +1055,15 @@ async def close(self) -> None: await self.pool.delete(self.health_check_key) if self.on_shutdown: await self.on_shutdown(self.ctx) + + if self.use_stream: + # delete consumer + await self.pool.xgroup_delconsumer( # type: ignore[no-untyped-call] + name=self.queue_name + stream_key_suffix, + groupname=self.consumer_group_name, + consumername=self.worker_id, + ) + await self.pool.close(close_connection_pool=True) self._pool = None diff --git a/tests/test_worker.py b/tests/test_worker.py index 93fbc7f0..55b322bb 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -17,6 +17,7 @@ from arq.worker import ( FailedJobs, JobExecutionFailed, + JobMetaInfo, Retry, RetryJob, Worker, @@ -179,10 +180,10 @@ async def retry_job(ctx): assert worker_two.jobs_failed == 0 assert worker_two.jobs_retried == 0 - await worker_one.start_jobs([job_id.encode()]) + await worker_one.start_jobs([JobMetaInfo(job_id=job_id)]) await asyncio.gather(*worker_one.tasks.values()) - await worker_two.start_jobs([job_id.encode()]) + await worker_two.start_jobs([JobMetaInfo(job_id=job_id)]) await asyncio.gather(*worker_two.tasks.values()) assert worker_one.jobs_complete == 0 @@ -846,7 +847,7 @@ async def foo(ctx, v): caplog.set_level(logging.DEBUG, logger='arq.worker') await arq_redis.enqueue_job('foo', 1, _job_id='testing') 
     worker: Worker = worker(functions=[func(foo, name='foo')])
-    await asyncio.gather(*[worker.start_jobs([b'testing']) for _ in range(5)])
+    await asyncio.gather(*[worker.start_jobs([JobMetaInfo(job_id='testing')]) for _ in range(5)])
     # debug(caplog.text)
     await worker.main()
     assert c == 1
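
For reviewers trying the change out, a minimal producer-side sketch against a local Redis (the 'download_content' task name is illustrative, not part of this diff):

import asyncio

from arq.connections import ArqRedis, RedisSettings, create_pool


async def main() -> None:
    redis: ArqRedis = await create_pool(RedisSettings())

    # With _use_stream=True the job reference is XADDed to 'arq:queue:stream'
    # via publish_job_lua instead of ZADDed to the 'arq:queue' sorted set.
    job = await redis.enqueue_job('download_content', 'https://example.com', _use_stream=True)
    assert job is not None

    # New helper in this diff: XLEN of '<queue_name>' + stream_key_suffix.
    print(await redis.get_stream_size())  # 1 until a worker claims the entry

    # status() now also checks the arq:message-id:<job_id> key, so a job that only
    # lives on the stream still reports JobStatus.queued.
    print(await job.status())


asyncio.run(main())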
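
The point of publish_job_lua is to keep the XADD and the message-id bookkeeping key in a single atomic step; roughly, it is the Lua form of the following (deliberately non-atomic) redis-py sequence, shown only to spell out the intent, using the key names added in constants.py:

from redis.asyncio import Redis

from arq.constants import job_message_id_prefix, stream_key_suffix


async def publish_job(redis: Redis, queue_name: str, job_id: str, score: int, expires_ms: int) -> bytes:
    # XADD an entry carrying the job id and its score onto the queue's stream...
    message_id = await redis.xadd(queue_name + stream_key_suffix, {'job_id': job_id, 'score': score})
    # ...and record which stream entry belongs to this job id, with the same TTL as
    # the job itself, so Job.info()/Job.status() and re-publishing on retry can find it.
    await redis.set(job_message_id_prefix + job_id, message_id, px=expires_ms)
    return message_id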
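
For anyone less familiar with stream consumer groups, the worker's create_consumer_group / _read_stream_iteration / _get_idle_tasks path follows the usual claim-read-ack pattern. A stripped-down standalone sketch of that pattern (no arq bookkeeping or concurrency control; assumes the Redis 7 style three-element XAUTOCLAIM reply that _get_idle_tasks also unpacks):

from contextlib import suppress

from redis.asyncio import Redis
from redis.exceptions import ResponseError


async def read_loop(redis: Redis, stream: str, group: str, consumer: str) -> None:
    # Idempotent group creation; BUSYGROUP just means the group already exists.
    with suppress(ResponseError):
        await redis.xgroup_create(name=stream, groupname=group, id='0', mkstream=True)

    while True:
        # Reclaim entries another consumer claimed but left pending too long
        # (the worker passes in_progress_timeout_s * 1000 here)...
        _, claimed, _ = await redis.xautoclaim(
            stream, groupname=group, consumername=consumer, min_idle_time=300_000, count=10
        )
        # ...then read entries never delivered to this group ('>'), blocking briefly.
        batches = await redis.xreadgroup(
            groupname=group, consumername=consumer, streams={stream: '>'}, count=10, block=500
        )
        messages = claimed + [msg for _, msgs in batches for msg in msgs]

        for message_id, fields in messages:
            job_id = fields[b'job_id'].decode()
            ...  # run the job, then acknowledge and drop the entry,
            # mirroring what finish_job()/finish_failed_job() do with XACK + XDEL.
            await redis.xack(stream, group, message_id)
            await redis.xdel(stream, message_id)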
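
And the consumer side wired together: a sketch of a worker configured for stream consumption (placeholder job body; burst=True only so main() returns instead of running forever):

import asyncio

from arq.connections import RedisSettings
from arq.worker import Worker, func


async def download_content(ctx, url: str) -> int:
    return len(url)  # placeholder job body


async def main() -> None:
    worker = Worker(
        functions=[func(download_content, name='download_content')],
        redis_settings=RedisSettings(),
        use_stream=True,                      # consume via XREADGROUP instead of polling the sorted set
        consumer_group_name='arq:consumers',  # default_consumer_group from constants.py
        worker_id='worker-1',                 # defaults to uuid4().hex; used as the stream consumer name
        burst=True,
    )
    try:
        await worker.main()
    finally:
        # close() now also runs XGROUP DELCONSUMER for this worker_id.
        await worker.close()


asyncio.run(main())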