Skip to content

Commit 5bc175d

Browse files
authored
add on_job_start and on_job_end hooks (#274)
1 parent e015d3d commit 5bc175d

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

arq/worker.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,8 @@ class Worker:
138138
:param burst: whether to stop the worker once all jobs have been run
139139
:param on_startup: coroutine function to run at startup
140140
:param on_shutdown: coroutine function to run at shutdown
141+
:param on_job_start: coroutine function to run on job start
142+
:param on_job_end: coroutine function to run on job end
141143
:param handle_signals: default true, register signal handlers,
142144
set to false when running inside other async framework
143145
:param max_jobs: maximum number of jobs to run at a time
@@ -169,6 +171,8 @@ def __init__(
169171
burst: bool = False,
170172
on_startup: Optional['StartupShutdown'] = None,
171173
on_shutdown: Optional['StartupShutdown'] = None,
174+
on_job_start: Optional['StartupShutdown'] = None,
175+
on_job_end: Optional['StartupShutdown'] = None,
172176
handle_signals: bool = True,
173177
max_jobs: int = 10,
174178
job_timeout: 'SecondsTimedelta' = 300,
@@ -202,6 +206,8 @@ def __init__(
202206
self.burst = burst
203207
self.on_startup = on_startup
204208
self.on_shutdown = on_shutdown
209+
self.on_job_start = on_job_start
210+
self.on_job_end = on_job_end
205211
self.sem = asyncio.BoundedSemaphore(max_jobs)
206212
self.job_timeout_s = to_seconds(job_timeout)
207213
self.keep_result_s = to_seconds(keep_result)
@@ -509,6 +515,10 @@ async def job_failed(exc: BaseException) -> None:
509515
'score': score,
510516
}
511517
ctx = {**self.ctx, **job_ctx}
518+
519+
if self.on_job_start:
520+
await self.on_job_start(ctx)
521+
512522
start_ms = timestamp_ms()
513523
success = False
514524
try:
@@ -585,6 +595,9 @@ async def job_failed(exc: BaseException) -> None:
585595
serializer=self.job_serializer,
586596
)
587597

598+
if self.on_job_end:
599+
await self.on_job_end(ctx)
600+
588601
await asyncio.shield(
589602
self.finish_job(
590603
job_id, finish, result_data, result_timeout_s, keep_result_forever, incr_score, keep_in_progress,

tests/test_worker.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,3 +835,32 @@ async def longfunc(ctx):
835835
assert worker.jobs_retried == 0
836836
log = re.sub(r'\d+.\d\ds', 'X.XXs', '\n'.join(r.message for r in caplog.records))
837837
assert 'X.XXs ! testing:longfunc failed, TimeoutError:' in log
838+
839+
840+
async def test_on_job(arq_redis: ArqRedis, worker):
841+
result = {'called': 0}
842+
843+
async def on_start(ctx):
844+
assert ctx['job_id'] == 'testing'
845+
result['called'] += 1
846+
847+
async def on_end(ctx):
848+
assert ctx['job_id'] == 'testing'
849+
result['called'] += 1
850+
851+
async def test(ctx):
852+
return
853+
854+
await arq_redis.enqueue_job('func', _job_id='testing')
855+
worker: Worker = worker(
856+
functions=[func(test, name='func')], on_job_start=on_start, on_job_end=on_end, job_timeout=0.2, poll_delay=0.1,
857+
)
858+
assert worker.jobs_complete == 0
859+
assert worker.jobs_failed == 0
860+
assert worker.jobs_retried == 0
861+
assert result['called'] == 0
862+
await worker.main()
863+
assert worker.jobs_complete == 1
864+
assert worker.jobs_failed == 0
865+
assert worker.jobs_retried == 0
866+
assert result['called'] == 2

0 commit comments

Comments
 (0)