1717from ..const import (
1818 FUTURE_JOBS_KEY , NOTIFICATIONS_KEY , RUNNING_JOBS_KEY ,
1919 PERIODIC_TASKS_HASH_KEY , PERIODIC_TASKS_QUEUE_KEY ,
20- DEFAULT_ENQUEUE_JOB_RETRIES , ALL_BROKERS_HASH_KEY , ALL_BROKERS_ZSET_KEY
20+ DEFAULT_ENQUEUE_JOB_RETRIES , ALL_BROKERS_HASH_KEY , ALL_BROKERS_ZSET_KEY ,
21+ MAX_CONCURRENCY_KEY , CURRENT_CONCURRENCY_KEY ,
2122)
2223from ..utils import run_forever , call_with_retry
2324
@@ -57,10 +58,16 @@ def __init__(self, redis: Optional[StrictRedis]=None,
5758 self ._get_jobs_from_queue = self ._load_script (
5859 'get_jobs_from_queue.lua'
5960 )
61+ self ._remove_job_from_running = self ._load_script (
62+ 'remove_job_from_running.lua'
63+ )
6064 self ._move_future_jobs = self ._load_script ('move_future_jobs.lua' )
6165 self ._register_periodic_tasks = self ._load_script (
6266 'register_periodic_tasks.lua'
6367 )
68+ self ._set_concurrency_keys = self ._load_script (
69+ 'set_concurrency_keys.lua'
70+ )
6471 self ._reset ()
6572
6673 def _reset (self ):
@@ -107,7 +114,11 @@ def _run_script(self, script: Script, *args):
107114
108115 return rv
109116
110- def enqueue_jobs (self , jobs : Iterable [Job ]):
117+ def is_queue_empty (self , queue : str ) -> bool :
118+ """Return True if the provided queue is empty."""
119+ return self ._r .llen (self ._to_namespaced (queue )) == 0
120+
121+ def enqueue_jobs (self , jobs : Iterable [Job ], from_failure : bool = False ):
111122 """Enqueue a batch of jobs."""
112123 jobs_to_queue = list ()
113124 for job in jobs :
@@ -124,6 +135,9 @@ def enqueue_jobs(self, jobs: Iterable[Job]):
124135 self ._to_namespaced (RUNNING_JOBS_KEY .format (self ._id )),
125136 self .namespace ,
126137 self ._to_namespaced (FUTURE_JOBS_KEY ),
138+ self ._to_namespaced (MAX_CONCURRENCY_KEY ),
139+ self ._to_namespaced (CURRENT_CONCURRENCY_KEY ),
140+ 1 if from_failure else 0 ,
127141 * jobs_to_queue
128142 )
129143
@@ -188,7 +202,9 @@ def get_jobs_from_queue(self, queue: str, max_jobs: int) -> List[Job]:
188202 self ._to_namespaced (queue ),
189203 self ._to_namespaced (RUNNING_JOBS_KEY .format (self ._id )),
190204 JobStatus .RUNNING .value ,
191- max_jobs
205+ max_jobs ,
206+ self ._to_namespaced (MAX_CONCURRENCY_KEY ),
207+ self ._to_namespaced (CURRENT_CONCURRENCY_KEY ),
192208 )
193209
194210 jobs = json .loads (jobs_json_string .decode ())
@@ -198,10 +214,14 @@ def get_jobs_from_queue(self, queue: str, max_jobs: int) -> List[Job]:
198214
199215 def remove_job_from_running (self , job : Job ):
200216 if job .max_retries > 0 :
201- self ._r .hdel (
217+ self ._run_script (
218+ self ._remove_job_from_running ,
202219 self ._to_namespaced (RUNNING_JOBS_KEY .format (self ._id )),
203- str (job .id )
220+ self ._to_namespaced (MAX_CONCURRENCY_KEY ),
221+ self ._to_namespaced (CURRENT_CONCURRENCY_KEY ),
222+ job .serialize (),
204223 )
224+
205225 self ._something_happened .set ()
206226
207227 def _subscriber_func (self ):
@@ -272,19 +292,38 @@ def enqueue_jobs_from_dead_broker(self, dead_broker_id: uuid.UUID) -> int:
272292 self ._to_namespaced (ALL_BROKERS_HASH_KEY ),
273293 self ._to_namespaced (ALL_BROKERS_ZSET_KEY ),
274294 self .namespace ,
275- self ._to_namespaced (NOTIFICATIONS_KEY )
295+ self ._to_namespaced (NOTIFICATIONS_KEY ),
296+ self ._to_namespaced (MAX_CONCURRENCY_KEY ),
297+ self ._to_namespaced (CURRENT_CONCURRENCY_KEY ),
276298 )
277299
278300 def register_periodic_tasks (self , tasks : Iterable [Task ]):
279301 """Register tasks that need to be scheduled periodically."""
280- tasks = [task .serialize () for task in tasks ]
281- self ._number_periodic_tasks = len (tasks )
302+ _tasks = [task .serialize () for task in tasks ]
303+ self ._number_periodic_tasks = len (_tasks )
282304 self ._run_script (
283305 self ._register_periodic_tasks ,
284306 math .ceil (datetime .now (timezone .utc ).timestamp ()),
285307 self ._to_namespaced (PERIODIC_TASKS_HASH_KEY ),
286308 self ._to_namespaced (PERIODIC_TASKS_QUEUE_KEY ),
287- * tasks
309+ * _tasks
310+ )
311+
312+ def set_concurrency_keys (self , tasks : Iterable [Task ]):
313+ """For each Task, set up its concurrency keys.
314+
315+ The Lua script handles the logic of:
316+ - removing dead keys where a Task was removed
317+ - only setting keys where max_concurrency > 0
318+ """
319+ _tasks = [task .serialize () for task in tasks ]
320+ if not _tasks :
321+ return
322+ self ._run_script (
323+ self ._set_concurrency_keys ,
324+ self ._to_namespaced (MAX_CONCURRENCY_KEY ),
325+ self ._to_namespaced (CURRENT_CONCURRENCY_KEY ),
326+ * _tasks ,
288327 )
289328
290329 def inspect_periodic_tasks (self ) -> List [Tuple [int , str ]]:
0 commit comments