Skip to content

Commit 8002ca6

Browse files
authored
Merge pull request #8 from bigjools/max_concurrency
Add max_concurrency to tasks
2 parents d4510b8 + a10247c commit 8002ca6

21 files changed

+670
-63
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Distinctive features:
1414

1515
- At-least-once or at-most-once delivery per task
1616
- Periodic tasks without an additional process
17+
- Concurrency limits on queued jobs
1718
- Scheduling of tasks in batch
1819
- Integrations with `Flask, Django, Logging, Sentry and Datadog
1920
<https://spinach.readthedocs.io/en/stable/user/integrations.html>`_

doc/hacking/contributing.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,24 @@ The code base follows `pep8 <https://www.python.org/dev/peps/pep-0008/>`_
3232
guidelines with lines wrapping at the 79th character. You can verify that the
3333
code follows the conventions with::
3434

35-
$ pep8 spinach tests
35+
$ pycodestyle --ignore=E252,W503,W504 spinach tests
3636

3737
Running tests is an invaluable help when adding a new feature or when
3838
refactoring. Try to add the proper test cases in ``tests/`` together with your
3939
patch. The test suite can be run with pytest::
4040

4141
$ pytest tests
4242

43+
Because the Redis broker tests require a running Redis server, there is also a
44+
convenience ``tox.ini`` that runs all the tests and pycodestyle checks for you after
45+
starting Redis in a container via docker-compose. Simply running::
46+
47+
$ tox
48+
49+
will build a virtualenv, install Spinach and its dependencies into it,
50+
start the Redis server in the container, and run tests and pycodestyle,
51+
tearing down the Redis server container when done.
52+
4353
Compatibility
4454
-------------
4555

doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Distinctive features:
99

1010
- At-least-once or at-most-once delivery per task
1111
- Periodic tasks without an additional process
12+
- Concurrency limits on queued jobs
1213
- Scheduling of tasks in batch
1314
- Embeddable workers for easier testing
1415
- Integrations with :ref:`Flask, Django, Logging, Sentry and Datadog

doc/user/tasks.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,22 @@ A task can also raise a :class:`AbortException` for short-circuit behavior:
109109

110110
.. autoclass:: spinach.task.AbortException
111111

112+
Limiting task concurrency
113+
-------------------------
114+
115+
If a task is idempotent it may also have a limit on the number of
116+
concurrent jobs spawned across all workers. These types of tasks are
117+
defined with a positive `max_concurrency` value::
118+
119+
@tasks.task(name='foo', max_retries=10, max_concurrency=1)
120+
def foo(a, b):
121+
pass
122+
123+
With this definition, no more than one instance of the Task will ever be
124+
spawned as a running Job, no matter how many are queued and waiting to
125+
run.
126+
127+
112128
Periodic tasks
113129
--------------
114130

setup.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@
6767
'flush.lua',
6868
'get_jobs_from_queue.lua',
6969
'move_future_jobs.lua',
70-
'register_periodic_tasks.lua'
70+
'register_periodic_tasks.lua',
71+
'remove_job_from_running.lua',
72+
'set_concurrency_keys.lua',
7173
],
7274
},
7375
)

spinach/brokers/base.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@ def _to_namespaced(self, value: str) -> str:
8585
def register_periodic_tasks(self, tasks: Iterable[Task]):
8686
"""Register tasks that need to be scheduled periodically."""
8787

88+
@abstractmethod
89+
def set_concurrency_keys(self, tasks: Iterable[Task]):
90+
"""Register concurrency data for Tasks.
91+
92+
Set up anything in the Broker that is required to track
93+
concurrency on Tasks, where a Task defines max_concurrency.
94+
"""
95+
96+
@abstractmethod
97+
def is_queue_empty(self, queue: str) -> bool:
98+
"""Return True if the provided queue is empty."""
99+
88100
@abstractmethod
89101
def inspect_periodic_tasks(self) -> List[Tuple[int, str]]:
90102
"""Get the next periodic task schedule.
@@ -93,7 +105,7 @@ def inspect_periodic_tasks(self) -> List[Tuple[int, str]]:
93105
"""
94106

95107
@abstractmethod
96-
def enqueue_jobs(self, jobs: Iterable[Job]):
108+
def enqueue_jobs(self, jobs: Iterable[Job], from_failure: bool):
97109
"""Enqueue a batch of jobs."""
98110

99111
@abstractmethod

spinach/brokers/memory.py

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ def __init__(self):
2323
self._future_jobs = list()
2424
self._running_jobs = list()
2525
self._scheduler = sched.scheduler()
26+
self._max_concurrency_keys = dict()
27+
self._cur_concurrency_keys = dict()
2628

2729
def _get_queue(self, queue_name: str):
2830
queue_name = self._to_namespaced(queue_name)
@@ -34,9 +36,16 @@ def _get_queue(self, queue_name: str):
3436
self._queues[queue_name] = queue
3537
return queue
3638

37-
def enqueue_jobs(self, jobs: Iterable[Job]):
39+
def enqueue_jobs(self, jobs: Iterable[Job], from_failure: bool=False):
3840
"""Enqueue a batch of jobs."""
3941
for job in jobs:
42+
with self._lock:
43+
if from_failure:
44+
max_concurrency = self._max_concurrency_keys[
45+
job.task_name
46+
]
47+
if max_concurrency is not None:
48+
self._cur_concurrency_keys[job.task_name] -= 1
4049
if job.should_start:
4150
job.status = JobStatus.QUEUED
4251
queue = self._get_queue(job.queue)
@@ -71,6 +80,11 @@ def move_future_jobs(self) -> int:
7180

7281
return num_jobs_moved
7382

83+
def set_concurrency_keys(self, tasks: Iterable[Task]):
84+
for task in tasks:
85+
self._max_concurrency_keys[task.name] = task.max_concurrency
86+
self._cur_concurrency_keys[task.name] = 0
87+
7488
def register_periodic_tasks(self, tasks: Iterable[Task]):
7589
"""Register tasks that need to be scheduled periodically."""
7690
for task in tasks:
@@ -121,18 +135,46 @@ def _get_next_future_job(self) -> Optional[Job]:
121135
except IndexError:
122136
return None
123137

138+
def is_queue_empty(self, queue: str):
139+
return self._get_queue(queue).qsize() == 0
140+
124141
def get_jobs_from_queue(self, queue: str, max_jobs: int) -> List[Job]:
125142
"""Get jobs from a queue."""
126143
rv = list()
127-
while len(rv) < max_jobs:
128-
try:
129-
job_json_string = self._get_queue(queue).get(block=False)
130-
except Empty:
131-
break
132-
133-
job = Job.deserialize(job_json_string)
134-
job.status = JobStatus.RUNNING
135-
rv.append(job)
144+
jobs_to_re_add = list()
145+
with self._lock:
146+
while len(rv) < max_jobs:
147+
try:
148+
job_json_string = self._get_queue(queue).get(block=False)
149+
except Empty:
150+
break
151+
152+
job = Job.deserialize(job_json_string)
153+
max_concurrency = self._max_concurrency_keys.get(job.task_name)
154+
cur_concurrency = self._cur_concurrency_keys.get(job.task_name)
155+
if (
156+
max_concurrency is not None and
157+
cur_concurrency >= max_concurrency
158+
):
159+
jobs_to_re_add.append(job_json_string)
160+
161+
else:
162+
job.status = JobStatus.RUNNING
163+
rv.append(job)
164+
if max_concurrency is not None:
165+
self._cur_concurrency_keys[job.task_name] += 1
166+
167+
# Re-add jobs that could not be run due to max_concurrency
168+
# limits. Queue does not have a way to insert at the front, so
169+
# sadly they go straight to the back again. Given that
170+
# MemoryBroker is generally only used for testing, this should
171+
# not be a great hardship.
172+
logger.debug(
173+
"Re-adding %s jobs due to concurrency limits",
174+
len(jobs_to_re_add)
175+
)
176+
for job in jobs_to_re_add:
177+
self._get_queue(queue).put(job)
136178

137179
return rv
138180

@@ -154,5 +196,13 @@ def remove_job_from_running(self, job: Job):
154196
155197
Easy, the memory broker doesn't track running jobs. If the broker dies
156198
there is nothing we can do.
199+
200+
We still need to decrement the current_concurrency count,
201+
however, if it exists.
157202
"""
203+
with self._lock:
204+
max_concurrency = self._max_concurrency_keys[job.task_name]
205+
if max_concurrency is not None:
206+
self._cur_concurrency_keys[job.task_name] -= 1
207+
158208
self._something_happened.set()

spinach/brokers/redis.py

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from ..const import (
1818
FUTURE_JOBS_KEY, NOTIFICATIONS_KEY, RUNNING_JOBS_KEY,
1919
PERIODIC_TASKS_HASH_KEY, PERIODIC_TASKS_QUEUE_KEY,
20-
DEFAULT_ENQUEUE_JOB_RETRIES, ALL_BROKERS_HASH_KEY, ALL_BROKERS_ZSET_KEY
20+
DEFAULT_ENQUEUE_JOB_RETRIES, ALL_BROKERS_HASH_KEY, ALL_BROKERS_ZSET_KEY,
21+
MAX_CONCURRENCY_KEY, CURRENT_CONCURRENCY_KEY,
2122
)
2223
from ..utils import run_forever, call_with_retry
2324

@@ -57,10 +58,16 @@ def __init__(self, redis: Optional[StrictRedis]=None,
5758
self._get_jobs_from_queue = self._load_script(
5859
'get_jobs_from_queue.lua'
5960
)
61+
self._remove_job_from_running = self._load_script(
62+
'remove_job_from_running.lua'
63+
)
6064
self._move_future_jobs = self._load_script('move_future_jobs.lua')
6165
self._register_periodic_tasks = self._load_script(
6266
'register_periodic_tasks.lua'
6367
)
68+
self._set_concurrency_keys = self._load_script(
69+
'set_concurrency_keys.lua'
70+
)
6471
self._reset()
6572

6673
def _reset(self):
@@ -107,7 +114,11 @@ def _run_script(self, script: Script, *args):
107114

108115
return rv
109116

110-
def enqueue_jobs(self, jobs: Iterable[Job]):
117+
def is_queue_empty(self, queue: str) -> bool:
118+
"""Return True if the provided queue is empty."""
119+
return self._r.llen(self._to_namespaced(queue)) == 0
120+
121+
def enqueue_jobs(self, jobs: Iterable[Job], from_failure: bool=False):
111122
"""Enqueue a batch of jobs."""
112123
jobs_to_queue = list()
113124
for job in jobs:
@@ -124,6 +135,9 @@ def enqueue_jobs(self, jobs: Iterable[Job]):
124135
self._to_namespaced(RUNNING_JOBS_KEY.format(self._id)),
125136
self.namespace,
126137
self._to_namespaced(FUTURE_JOBS_KEY),
138+
self._to_namespaced(MAX_CONCURRENCY_KEY),
139+
self._to_namespaced(CURRENT_CONCURRENCY_KEY),
140+
1 if from_failure else 0,
127141
*jobs_to_queue
128142
)
129143

@@ -188,7 +202,9 @@ def get_jobs_from_queue(self, queue: str, max_jobs: int) -> List[Job]:
188202
self._to_namespaced(queue),
189203
self._to_namespaced(RUNNING_JOBS_KEY.format(self._id)),
190204
JobStatus.RUNNING.value,
191-
max_jobs
205+
max_jobs,
206+
self._to_namespaced(MAX_CONCURRENCY_KEY),
207+
self._to_namespaced(CURRENT_CONCURRENCY_KEY),
192208
)
193209

194210
jobs = json.loads(jobs_json_string.decode())
@@ -198,10 +214,14 @@ def get_jobs_from_queue(self, queue: str, max_jobs: int) -> List[Job]:
198214

199215
def remove_job_from_running(self, job: Job):
200216
if job.max_retries > 0:
201-
self._r.hdel(
217+
self._run_script(
218+
self._remove_job_from_running,
202219
self._to_namespaced(RUNNING_JOBS_KEY.format(self._id)),
203-
str(job.id)
220+
self._to_namespaced(MAX_CONCURRENCY_KEY),
221+
self._to_namespaced(CURRENT_CONCURRENCY_KEY),
222+
job.serialize(),
204223
)
224+
205225
self._something_happened.set()
206226

207227
def _subscriber_func(self):
@@ -272,19 +292,38 @@ def enqueue_jobs_from_dead_broker(self, dead_broker_id: uuid.UUID) -> int:
272292
self._to_namespaced(ALL_BROKERS_HASH_KEY),
273293
self._to_namespaced(ALL_BROKERS_ZSET_KEY),
274294
self.namespace,
275-
self._to_namespaced(NOTIFICATIONS_KEY)
295+
self._to_namespaced(NOTIFICATIONS_KEY),
296+
self._to_namespaced(MAX_CONCURRENCY_KEY),
297+
self._to_namespaced(CURRENT_CONCURRENCY_KEY),
276298
)
277299

278300
def register_periodic_tasks(self, tasks: Iterable[Task]):
279301
"""Register tasks that need to be scheduled periodically."""
280-
tasks = [task.serialize() for task in tasks]
281-
self._number_periodic_tasks = len(tasks)
302+
_tasks = [task.serialize() for task in tasks]
303+
self._number_periodic_tasks = len(_tasks)
282304
self._run_script(
283305
self._register_periodic_tasks,
284306
math.ceil(datetime.now(timezone.utc).timestamp()),
285307
self._to_namespaced(PERIODIC_TASKS_HASH_KEY),
286308
self._to_namespaced(PERIODIC_TASKS_QUEUE_KEY),
287-
*tasks
309+
*_tasks
310+
)
311+
312+
def set_concurrency_keys(self, tasks: Iterable[Task]):
313+
"""For each Task, set up its concurrency keys.
314+
315+
The Lua script handles the logic of:
316+
- removing dead keys where a Task was removed
317+
- only setting keys where max_concurrency > 0
318+
"""
319+
_tasks = [task.serialize() for task in tasks]
320+
if not _tasks:
321+
return
322+
self._run_script(
323+
self._set_concurrency_keys,
324+
self._to_namespaced(MAX_CONCURRENCY_KEY),
325+
self._to_namespaced(CURRENT_CONCURRENCY_KEY),
326+
*_tasks,
288327
)
289328

290329
def inspect_periodic_tasks(self) -> List[Tuple[int, str]]:

spinach/brokers/redis_scripts/enqueue_job.lua

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,27 @@ local notifications = ARGV[2]
44
local running_jobs_key = ARGV[3]
55
local namespace = ARGV[4]
66
local future_jobs = ARGV[5]
7-
-- jobs starting at ARGV[6]
7+
local max_concurrency_key = ARGV[6]
8+
local current_concurrency_key = ARGV[7]
9+
local from_failure = ARGV[8]
10+
11+
-- jobs starting at ARGV[9]
812

913
if not redis.call('set', idempotency_token, 'true', 'EX', 3600, 'NX') then
1014
redis.log(redis.LOG_WARNING, "Not reprocessing script")
1115
return -1
1216
end
1317

14-
for i=6, #ARGV do
18+
for i=9, #ARGV do
1519
local job_json = ARGV[i]
1620
local job = cjson.decode(job_json)
21+
if tonumber(from_failure) == 1 then
22+
-- job is being requeued after a failure, decrement its concurrency
23+
local max_concurrency = tonumber(redis.call('hget', max_concurrency_key, job['task_name']))
24+
if max_concurrency ~= nil and max_concurrency ~= -1 then
25+
redis.call('hincrby', current_concurrency_key, job['task_name'], -1)
26+
end
27+
end
1728
if job["status"] == 2 then
1829
-- job status is queued
1930
local queue = string.format("%s/%s", namespace, job["queue"])

spinach/brokers/redis_scripts/enqueue_jobs_from_dead_broker.lua

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ local all_brokers_hash_key = ARGV[3]
44
local all_brokers_zset_key = ARGV[4]
55
local namespace = ARGV[5]
66
local notifications = ARGV[6]
7+
local max_concurrency_key = ARGV[7]
8+
local current_concurrency_key = ARGV[8]
79

810
local num_enqueued_jobs = 0
911

@@ -26,6 +28,13 @@ for _, job_json in ipairs(jobs_json) do
2628
-- Serialize the job so that it can be put in the queue
2729
local job_json = cjson.encode(job)
2830

31+
-- Decrement the current concurrency if we are tracking
32+
-- concurrency on the Task.
33+
local max_concurrency = tonumber(redis.call('hget', max_concurrency_key, job['task_name']))
34+
if max_concurrency ~= nil and max_concurrency ~= -1 then
35+
redis.call('hincrby', current_concurrency_key, job['task_name'], -1)
36+
end
37+
2938
-- Enqueue the job
3039
local queue = string.format("%s/%s", namespace, job["queue"])
3140
redis.call('rpush', queue, job_json)

0 commit comments

Comments (0)