Commit 0607086

WIP: New PoC for processing tasks from multiple databases
1 parent 8624614 commit 0607086

File tree

4 files changed: +74 -53 lines

src/task_processor/managers.py
src/task_processor/processor.py
src/task_processor/threads.py
tests/unit/task_processor/test_unit_task_processor_processor.py

src/task_processor/managers.py

Lines changed: 15 additions & 6 deletions
@@ -3,16 +3,25 @@
 from django.db.models import Manager
 
 if typing.TYPE_CHECKING:
-    from django.db.models.query import RawQuerySet
-
     from task_processor.models import RecurringTask, Task
 
 
 class TaskManager(Manager["Task"]):
-    def get_tasks_to_process(self, num_tasks: int) -> "RawQuerySet[Task]":
-        return self.raw("SELECT * FROM get_tasks_to_process(%s)", [num_tasks])
+    def get_tasks_to_process(
+        self,
+        database: str,
+        num_tasks: int,
+    ) -> typing.List["Task"]:
+        return list(
+            self.using(database).raw(
+                "SELECT * FROM get_tasks_to_process(%s)",
+                [num_tasks],
+            ),
+        )
 
 
 class RecurringTaskManager(Manager["RecurringTask"]):
-    def get_tasks_to_process(self) -> "RawQuerySet[RecurringTask]":
-        return self.raw("SELECT * FROM get_recurringtasks_to_process()")
+    def get_tasks_to_process(self, database: str) -> typing.List["RecurringTask"]:
+        return list(
+            self.using(database).raw("SELECT * FROM get_recurringtasks_to_process()"),
+        )
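
The manager change routes the raw SQL through .using(database) and returns an eagerly evaluated list rather than a lazy RawQuerySet. A minimal usage sketch, assuming a Django DATABASES setting that defines the "default" and "task_processor" aliases used elsewhere in this commit and that the stored functions exist in each database; the sketch itself is illustrative, not part of the commit:

# Illustrative usage only - not part of this commit.
# Assumes settings.DATABASES defines "default" and "task_processor" aliases
# and that get_tasks_to_process()/get_recurringtasks_to_process() exist in each.
from task_processor.models import RecurringTask, Task

# Pop up to five pending tasks from the dedicated task-processor database.
tasks = Task.objects.get_tasks_to_process("task_processor", 5)

# The manager returns a plain list, so the query has already been executed here.
recurring = RecurringTask.objects.get_tasks_to_process("default")
print(len(tasks), len(recurring))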

src/task_processor/processor.py

Lines changed: 14 additions & 8 deletions
@@ -27,14 +27,14 @@
 UNREGISTERED_RECURRING_TASK_GRACE_PERIOD = timedelta(minutes=30)
 
 
-def run_tasks(num_tasks: int = 1) -> list[TaskRun]:
+def run_tasks(database: str, num_tasks: int = 1) -> list[TaskRun]:
     if num_tasks < 1:
         raise ValueError("Number of tasks to process must be at least one")
 
-    tasks = Task.objects.get_tasks_to_process(num_tasks)
+    tasks = list(Task.objects.get_tasks_to_process(database, num_tasks))
 
     if tasks:
-        logger.debug(f"Running {len(tasks)} task(s)")
+        logger.debug(f"Running {len(tasks)} task(s) from database '{database}'")
 
         executed_tasks = []
         task_runs = []
@@ -54,20 +54,24 @@ def run_tasks(num_tasks: int = 1) -> list[TaskRun]:
 
         if task_runs:
             TaskRun.objects.bulk_create(task_runs)
-            logger.debug(f"Finished running {len(task_runs)} task(s)")
+            logger.debug(
+                f"Finished running {len(task_runs)} task(s) from database '{database}'",
+            )
 
         return task_runs
 
     return []
 
 
-def run_recurring_tasks() -> list[RecurringTaskRun]:
+def run_recurring_tasks(database: str) -> list[RecurringTaskRun]:
     # NOTE: We will probably see a lot of delay in the execution of recurring tasks
     # if the tasks take longer then `run_every` to execute. This is not
     # a problem for now, but we should be mindful of this limitation
-    tasks = RecurringTask.objects.get_tasks_to_process()
+    tasks = RecurringTask.objects.get_tasks_to_process(database)
     if tasks:
-        logger.debug(f"Running {len(tasks)} recurring task(s)")
+        logger.debug(
+            f"Running {len(tasks)} recurring task(s) from database '{database}'",
+        )
 
         task_runs = []
 
@@ -95,7 +95,9 @@ def run_recurring_tasks() -> list[RecurringTaskRun]:
 
         if task_runs:
             RecurringTaskRun.objects.bulk_create(task_runs)
-            logger.debug(f"Finished running {len(task_runs)} recurring task(s)")
+            logger.debug(
+                f"Finished running {len(task_runs)} recurring task(s) from database '{database}'",
+            )
 
         return task_runs
 
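
Both entry points now take the database alias as their first argument, and the debug logging names the alias so runs against different databases can be told apart. A hedged sketch of how a caller could drain several databases in one pass; only the two function signatures are taken from the diff, the helper and its name are invented for illustration:

# Illustrative helper - not part of this commit.
from task_processor.processor import run_recurring_tasks, run_tasks


def process_all_databases(databases: list[str], num_tasks: int = 1) -> int:
    """Run ordinary and recurring tasks against each database alias in turn.

    Returns the total number of task runs recorded across all databases.
    """
    total = 0
    for database in databases:
        total += len(run_tasks(database, num_tasks))
        total += len(run_recurring_tasks(database))
    return total


# e.g. process_all_databases(["default", "task_processor"], num_tasks=5)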

src/task_processor/threads.py

Lines changed: 18 additions & 12 deletions
@@ -95,18 +95,24 @@ def run(self) -> None:
             time.sleep(self.sleep_interval_millis / 1000)
 
     def run_iteration(self) -> None:
-        try:
-            run_tasks(self.queue_pop_size)
-            run_recurring_tasks()
-        except Exception as e:
-            # To prevent task threads from dying if they get an error retrieving the tasks from the
-            # database this will allow the thread to continue trying to retrieve tasks if it can
-            # successfully re-establish a connection to the database.
-            # TODO: is this also what is causing tasks to get stuck as locked? Can we unlock
-            # tasks here?
-
-            logger.error("Received error retrieving tasks: %s.", e, exc_info=e)
-            close_old_connections()
+        for database in ["default", "task_processor"]:
+            try:
+                run_tasks(database, self.queue_pop_size)
+                run_recurring_tasks(database)
+            except Exception as exception:
+                # To prevent task threads from dying if they get an error retrieving the tasks from the
+                # database this will allow the thread to continue trying to retrieve tasks if it can
+                # successfully re-establish a connection to the database.
+                # TODO: is this also what is causing tasks to get stuck as locked? Can we unlock
+                # tasks here?
+
+                logger.error(
+                    "Received error retrieving tasks from database '%s': %s.",
+                    database,
+                    repr(exception),
+                    exc_info=exception,
+                )
+                close_old_connections()
 
     def stop(self) -> None:
         self._stopped = True
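
run_iteration now loops over a hard-coded pair of aliases and recovers per database, so a failure against one alias no longer prevents tasks on the other from being picked up. If the alias list were later driven by settings instead of being hard-coded, it might look like the sketch below; TASK_PROCESSOR_DATABASES is an invented setting name, not something this commit introduces:

# Hypothetical follow-up - not part of this commit.
from django.conf import settings


def get_task_databases() -> list[str]:
    # Fall back to the same aliases the commit hard-codes in run_iteration().
    return getattr(settings, "TASK_PROCESSOR_DATABASES", ["default", "task_processor"])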

tests/unit/task_processor/test_unit_task_processor_processor.py

Lines changed: 27 additions & 27 deletions
@@ -85,7 +85,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_success(
     task.save()
 
     # When
-    task_runs = run_tasks()
+    task_runs = run_tasks("default")
 
     # Then
     assert cache.get(DEFAULT_CACHE_KEY)
@@ -116,7 +116,7 @@ def test_run_task_kills_task_after_timeout(
     task.save()
 
     # When
-    task_runs = run_tasks()
+    task_runs = run_tasks("default")
 
     # Then
     assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1
@@ -157,7 +157,7 @@ def _dummy_recurring_task() -> None:
         task_identifier="test_unit_task_processor_processor._dummy_recurring_task",
     )
     # When
-    task_runs = run_recurring_tasks()
+    task_runs = run_recurring_tasks("default")
 
     # Then
     assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1
@@ -195,7 +195,7 @@ def _dummy_recurring_task() -> None:
         task_identifier="test_unit_task_processor_processor._dummy_recurring_task",
    )
     # When
-    task_runs = run_recurring_tasks()
+    task_runs = run_recurring_tasks("default")
 
     # Then
     assert cache.get(DEFAULT_CACHE_KEY)
@@ -227,7 +227,7 @@ def _dummy_recurring_task() -> None:
 
     # When
     assert cache.get(DEFAULT_CACHE_KEY) is None
-    task_runs = run_recurring_tasks()
+    task_runs = run_recurring_tasks("default")
 
     # Then
     assert cache.get(DEFAULT_CACHE_KEY) == DEFAULT_CACHE_VALUE
@@ -261,16 +261,16 @@ def _dummy_recurring_task() -> None:
     )
 
     # When
-    first_task_runs = run_recurring_tasks()
+    first_task_runs = run_recurring_tasks("default")
 
     # run the process again before the task is scheduled to run again to ensure
     # that tasks are unlocked when they are picked up by the task processor but
     # not executed.
-    no_task_runs = run_recurring_tasks()
+    no_task_runs = run_recurring_tasks("default")
 
     time.sleep(0.3)
 
-    second_task_runs = run_recurring_tasks()
+    second_task_runs = run_recurring_tasks("default")
 
     # Then
     assert len(first_task_runs) == 1
@@ -310,7 +310,7 @@ def _dummy_recurring_task_3() -> None:
 
     # When, we call run_recurring_tasks in a loop few times
     for _ in range(4):
-        run_recurring_tasks()
+        run_recurring_tasks("default")
 
     # Then - we should have exactly one RecurringTaskRun for each task
     for i in range(1, 4):
@@ -339,8 +339,8 @@ def _dummy_recurring_task() -> None:
     )
 
     # When - we call run_recurring_tasks twice
-    run_recurring_tasks()
-    run_recurring_tasks()
+    run_recurring_tasks("default")
+    run_recurring_tasks("default")
 
     # Then - we expect the task to have been run once
 
@@ -367,7 +367,7 @@ def _a_task() -> None:
     registered_tasks.pop(task_identifier)
 
     # When
-    task_runs = run_recurring_tasks()
+    task_runs = run_recurring_tasks("default")
 
     # Then
     assert len(task_runs) == 0
@@ -395,7 +395,7 @@ def _a_task() -> None:
     registered_tasks.pop(task_identifier)
 
     # When
-    task_runs = run_recurring_tasks()
+    task_runs = run_recurring_tasks("default")
 
     # Then
     assert len(task_runs) == 0
@@ -418,7 +418,7 @@ def test_run_task_runs_task_and_creates_task_run_object_when_failure(
     task.save()
 
     # When
-    task_runs = run_tasks()
+    task_runs = run_tasks("default")
 
     # Then
     assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1
@@ -460,10 +460,10 @@ def test_run_task_runs_failed_task_again(
     task.save()
 
     # When
-    first_task_runs = run_tasks()
+    first_task_runs = run_tasks("default")
 
     # Now, let's run the task again
-    second_task_runs = run_tasks()
+    second_task_runs = run_tasks("default")
 
     # Then
     task_runs = first_task_runs + second_task_runs
@@ -498,7 +498,7 @@ def _raise_exception(organisation_name: str) -> None:
     task = RecurringTask.objects.get(task_identifier=task_identifier)
 
     # When
-    task_runs = run_recurring_tasks()
+    task_runs = run_recurring_tasks("default")
 
     # Then
     assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1
@@ -513,7 +513,7 @@ def _raise_exception(organisation_name: str) -> None:
 def test_run_task_does_nothing_if_no_tasks() -> None:
     # Given - no tasks
     # When
-    result = run_tasks()
+    result = run_tasks("default")
     # Then
     assert result == []
     assert not TaskRun.objects.exists()
@@ -551,9 +551,9 @@ def test_run_task_runs_tasks_in_correct_priority(
     task_3.save()
 
     # When
-    task_runs_1 = run_tasks()
-    task_runs_2 = run_tasks()
-    task_runs_3 = run_tasks()
+    task_runs_1 = run_tasks("default")
+    task_runs_2 = run_tasks("default")
+    task_runs_3 = run_tasks("default")
 
     # Then
     assert task_runs_1[0].task == task_3
@@ -573,7 +573,7 @@ def test_run_tasks__fails_if_not_in_task_processor_mode(
 
     # When
     with pytest.raises(AssertionError):
-        run_tasks()
+        run_tasks("default")
 
 
 @pytest.mark.django_db(transaction=True)
@@ -609,8 +609,8 @@ def _fake_recurring_task() -> None:
     ).save()
 
     # When
-    run_tasks(2)
-    run_recurring_tasks()
+    run_tasks("default", 2)
+    run_recurring_tasks("default")
 
     # Then
     assert_metric(
@@ -703,7 +703,7 @@ def test_run_tasks_skips_locked_tasks(
 
     # and subsequently attempt to run another task in the main thread
     time.sleep(1)  # wait for the thread to start and hold the task
-    task_runs = run_tasks()
+    task_runs = run_tasks("default")
 
     # Then
     # the second task is run while the 1st task is held
@@ -730,7 +730,7 @@ def test_run_more_than_one_task(dummy_task: TaskHandler[[str, str]]) -> None:
     Task.objects.bulk_create(tasks)
 
     # When
-    task_runs = run_tasks(5)
+    task_runs = run_tasks("default", 5)
 
     # Then
     assert len(task_runs) == num_tasks
@@ -772,7 +772,7 @@ def my_task() -> None:
     )
 
     # When
-    run_recurring_tasks()
+    run_recurring_tasks("default")
 
     # Then
     recurring_task.refresh_from_db()
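
The test changes are mechanical: every existing call site now passes the "default" alias explicitly. If coverage for the second database is wanted later, parametrizing over the alias is one option, sketched below under the assumption that pytest-django is configured with a "task_processor" test database; the commit itself only threads "default" through the existing tests:

# Illustrative sketch - not part of this commit.
import pytest

from task_processor.processor import run_tasks


@pytest.mark.django_db(databases=["default", "task_processor"])
@pytest.mark.parametrize("database", ["default", "task_processor"])
def test_run_tasks_does_nothing_when_target_database_is_empty(database: str) -> None:
    # Given - no tasks in the target database
    # When
    result = run_tasks(database)
    # Then
    assert result == []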
