Skip to content

Commit 6519b76

Browse files
committed
WIP
1 parent ecce903 commit 6519b76

File tree

3 files changed

+280
-3
lines changed

3 files changed

+280
-3
lines changed

src/task_processor/managers.py

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import typing
22

3+
from django.conf import settings
4+
from django.db import connections, transaction
35
from django.db.models import Manager
6+
from django.db.utils import ProgrammingError
7+
from django.utils.connection import ConnectionDoesNotExist
48

59
if typing.TYPE_CHECKING:
610
from django.db.models.query import RawQuerySet
@@ -9,8 +13,115 @@
913

1014

1115
class TaskManager(Manager["Task"]):
    """
    Manager for `Task` that retrieves tasks through the
    `get_tasks_to_process` database function.

    Tasks may live in either the "default" or the "task_processor" Django
    database; which one is considered "new" depends on whether a
    "task_processor" entry is present in `settings.DATABASES`.
    """

    def get_tasks_to_process(  # noqa: C901
        self,
        num_tasks: int,
        skip_old_database: bool = False,
    ) -> typing.Generator["Task", None, None]:
        """
        Retrieve tasks to process from the database

        This does its best effort to retrieve tasks from the old database first,
        then tops up from the new database with whatever is left of `num_tasks`.

        :param num_tasks: maximum number of tasks to yield.
        :param skip_old_database: when true, only the new database is queried.
        :raises ProgrammingError: if the new database still rejects the query
            after the `get_tasks_to_process` function has been (re)created.
        """
        if not skip_old_database:
            old_database = "default" if self._is_database_separate else "task_processor"
            old_tasks = self._fetch_tasks_from(old_database, num_tasks)

            # Fetch tasks from the previous database. Failures are tolerated
            # only while pulling the first row; anything raised while
            # consuming the remaining rows propagates to the caller.
            try:
                with transaction.atomic(using=old_database):
                    first_task = next(old_tasks)
            except StopIteration:
                pass  # Empty set
            except ProgrammingError:
                pass  # Function no longer exists in old database
            except ConnectionDoesNotExist:
                pass  # Database not available
            else:
                yield first_task
                num_tasks -= 1
                for task in old_tasks:
                    yield task
                    num_tasks -= 1

            # `<=` rather than `==`: never ask the new database for a
            # negative number of tasks if the old one over-delivered.
            if num_tasks <= 0:
                return

        new_database = "task_processor" if self._is_database_separate else "default"

        # Fetch tasks from the new database. If the database function is
        # missing, create it and retry exactly once; a second failure is
        # re-raised instead of retrying forever.
        function_created = False
        while True:
            new_tasks = self._fetch_tasks_from(new_database, num_tasks)
            try:
                with transaction.atomic(using=new_database):
                    first_task = next(new_tasks)
            except StopIteration:
                return  # Empty set
            except ProgrammingError:
                if function_created:
                    raise
                # Function doesn't exist in the database yet
                self._create_or_replace_function__get_tasks_to_process()
                function_created = True
            else:
                yield first_task
                yield from new_tasks
                return

    @property
    def _is_database_separate(self) -> bool:
        """
        Check whether the task processor database is separate from the default database
        """
        return "task_processor" in settings.DATABASES

    def _fetch_tasks_from(
        self, database: str, num_tasks: int
    ) -> typing.Iterator["Task"]:
        """
        Retrieve tasks from the specified Django database

        :param database: Django database alias to query.
        :param num_tasks: value passed to the `get_tasks_to_process` function.
        :return: iterator over the raw query results.
        """
        return (
            self.using(database)
            .raw("SELECT * FROM get_tasks_to_process(%s)", [num_tasks])
            .iterator()
        )

    def _create_or_replace_function__get_tasks_to_process(self) -> None:
        """
        Create or replace the function to get tasks to process.

        The function is installed in the "task_processor" database when one is
        configured, otherwise in the "default" database.
        """
        database = "task_processor" if self._is_database_separate else "default"
        with connections[database].cursor() as cursor:
            cursor.execute(
                """
                CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer)
                RETURNS SETOF task_processor_task AS $$
                DECLARE
                    row_to_return task_processor_task;
                BEGIN
                    -- Select the tasks that needs to be processed
                    FOR row_to_return IN
                        SELECT *
                        FROM task_processor_task
                        WHERE num_failures < 3 AND scheduled_for < NOW() AND completed = FALSE AND is_locked = FALSE
                        ORDER BY priority ASC, scheduled_for ASC, created_at ASC
                        LIMIT num_tasks
                        -- Select for update to ensure that no other workers can select these tasks while in this transaction block
                        FOR UPDATE SKIP LOCKED
                    LOOP
                        -- Lock every selected task(by updating `is_locked` to true)
                        UPDATE task_processor_task
                        -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this
                        -- transaction is complete (but the tasks are still being executed by the current worker)
                        SET is_locked = TRUE
                        WHERE id = row_to_return.id;
                        -- If we don't explicitly update the `is_locked` column here, the client will receive the row that is actually locked but has the `is_locked` value set to `False`.
                        row_to_return.is_locked := TRUE;
                        RETURN NEXT row_to_return;
                    END LOOP;

                    RETURN;
                END;
                $$ LANGUAGE plpgsql
                """
            )
14125

15126

16127
class RecurringTaskManager(Manager["RecurringTask"]):

src/task_processor/processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def run_tasks(num_tasks: int = 1) -> list[TaskRun]:
3131
if num_tasks < 1:
3232
raise ValueError("Number of tasks to process must be at least one")
3333

34-
tasks = Task.objects.get_tasks_to_process(num_tasks)
34+
tasks = list(Task.objects.get_tasks_to_process(num_tasks))
3535

3636
if tasks:
3737
logger.debug(f"Running {len(tasks)} task(s)")
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
from collections.abc import Callable
2+
from datetime import timedelta
3+
from typing import Generator
4+
5+
import pytest
6+
from django.db.utils import ProgrammingError
7+
from django.utils import timezone
8+
from pytest_django.fixtures import SettingsWrapper
9+
from pytest_mock import MockerFixture
10+
11+
from task_processor import managers
12+
from task_processor.models import Task
13+
14+
# Reference timestamps shared by the tests below; both are computed once at
# import time.
now = timezone.now()
one_hour_ago = now - timedelta(hours=1)

# Apply the `django_db` marker to every test in this module.
pytestmark = pytest.mark.django_db
18+
19+
20+
@pytest.fixture
def task_processor_database_tasks(
    mocker: MockerFixture,
    settings: SettingsWrapper,
) -> Generator[Callable[..., list[Task]], None, None]:
    """
    Fake a "task_processor" database and let the test decide which tasks it holds.

    The yielded callable takes the tasks the fake database should serve,
    patches `TaskManager._fetch_tasks_from` accordingly and returns those
    tasks as a list. Fetches against any other database alias are delegated
    to the implementation that was in place when the callable was invoked.
    """

    def patch(*tasks: Task) -> list[Task]:
        # Capture the current implementation before replacing it so that
        # other aliases keep their behaviour.
        original_fetch = managers.TaskManager._fetch_tasks_from

        def fake_fetch(
            self: managers.TaskManager,
            database: str,
            num_tasks: int,
        ) -> Generator[Task, None, None]:
            if database != "task_processor":
                yield from original_fetch(self, database, num_tasks)
                return
            # The fake database serves every given task, whatever `num_tasks` is.
            yield from tasks

        mocker.patch.object(
            managers.TaskManager, "_fetch_tasks_from", new=fake_fetch
        )
        return list(tasks)

    # `transaction` is mocked so `transaction.atomic(...)` never touches the
    # placeholder database entry registered below.
    mocker.patch.object(managers, "transaction")
    settings.DATABASES["task_processor"] = {"not": "undefined"}
    yield patch
    del settings.DATABASES["task_processor"]
50+
51+
52+
@pytest.fixture
def missing_database_function(
    mocker: MockerFixture,
) -> Callable[[str, Exception], None]:
    """
    Pretends the `get_tasks_to_process` database function does not exist

    Returns a callable `patch(target_database, exception)` that replaces
    `TaskManager._fetch_tasks_from` so that fetching from `target_database`
    raises `exception` when the returned iterator is first advanced. Fetches
    against any other database alias are delegated to the implementation that
    was in place when `patch` was called.
    """

    def patch(target_database: str, exception: Exception) -> None:
        orig_fetch_tasks = managers.TaskManager._fetch_tasks_from

        def new_fetch_tasks(
            self: managers.TaskManager,
            database: str,
            num_tasks: int,
        ) -> Generator[Task, None, None]:
            if target_database == database:
                # This is a generator function, so the exception surfaces on
                # the first `next()` call. The previous helper class defined
                # only `__next__`, so `yield from` raised TypeError instead.
                raise exception
            yield from orig_fetch_tasks(self, database, num_tasks)

        mocker.patch.object(
            managers.TaskManager, "_fetch_tasks_from", new=new_fetch_tasks
        )

    return patch
80+
81+
82+
def test_TaskManager__default_database__consumes_tasks_from_database(
    mocker: MockerFixture,
) -> None:
    """Only the requested number of stored tasks is handed out."""
    # Given
    mocker.patch.object(managers.TaskManager, "_is_database_separate", new=False)
    stored_tasks = [
        Task(task_identifier="some.identifier", scheduled_for=one_hour_ago)
        for _ in range(2)
    ]
    Task.objects.bulk_create(stored_tasks)

    # When
    fetched_tasks = list(Task.objects.get_tasks_to_process(1))

    # Then
    expected_tasks = stored_tasks[:1]
    assert fetched_tasks == expected_tasks
98+
99+
100+
def test_TaskManager__default_database__task_processor_database_with_pending_tasks__consumes_tasks_from_old_database_first(
    mocker: MockerFixture,
    task_processor_database_tasks: Callable[..., list[Task]],
) -> None:
    """Tasks from the faked "task_processor" database come out first."""
    # Given
    mocker.patch.object(managers.TaskManager, "_is_database_separate", new=False)
    pending_old_tasks = [
        Task(task_identifier="some.identifier", scheduled_for=one_hour_ago)
        for _ in range(2)
    ]
    old_tasks = task_processor_database_tasks(*pending_old_tasks)
    stored_tasks = [
        Task(task_identifier="some.identifier", scheduled_for=one_hour_ago)
        for _ in range(2)
    ]
    Task.objects.bulk_create(stored_tasks)

    # When
    fetched_tasks = list(Task.objects.get_tasks_to_process(3))

    # Then
    expected_tasks = old_tasks + stored_tasks[:1]
    assert fetched_tasks == expected_tasks
121+
122+
123+
def test_TaskManager__default_database__task_processor_database_with_no_pending_tasks__consumes_tasks_from_new_database(
    mocker: MockerFixture,
    task_processor_database_tasks: Callable[..., list[Task]],
) -> None:
    """With an empty faked "task_processor" database, stored tasks are returned."""
    # Given
    mocker.patch.object(managers.TaskManager, "_is_database_separate", new=False)
    task_processor_database_tasks()  # The fake database holds no tasks
    stored_tasks = [
        Task(task_identifier="some.identifier", scheduled_for=one_hour_ago)
        for _ in range(2)
    ]
    Task.objects.bulk_create(stored_tasks)

    # When
    fetched_tasks = list(Task.objects.get_tasks_to_process(3))

    # Then
    assert fetched_tasks == stored_tasks
141+
142+
143+
def test_TaskManager__default_database__task_processor_database_missing_function__consumes_tasks_from_default_database(
    mocker: MockerFixture,
) -> None:
    """Placeholder: scenario not implemented yet, so the test fails explicitly."""
    # Given
    mocker.patch.object(managers.TaskManager, "_is_database_separate", new=False)

    # A bare `raise` with no active exception only produces an opaque
    # RuntimeError; fail with an explicit marker instead.
    raise NotImplementedError("Test not implemented yet")
149+
150+
151+
def test_TaskManager__task_processor_database__default_database_with_pending_tasks__consumes_tasks_from_old_database_first(
    mocker: MockerFixture,
) -> None:
    """Placeholder: scenario not implemented yet, so the test fails explicitly."""
    # A bare `raise` with no active exception only produces an opaque
    # RuntimeError; fail with an explicit marker instead.
    raise NotImplementedError("Test not implemented yet")
155+
156+
157+
def test_TaskManager__task_processor_database__default_database_with_no_pending_tasks__consumes_tasks_from_new_database(
    mocker: MockerFixture,
) -> None:
    """Placeholder: scenario not implemented yet, so the test fails explicitly."""
    # A bare `raise` with no active exception only produces an opaque
    # RuntimeError; fail with an explicit marker instead.
    raise NotImplementedError("Test not implemented yet")
161+
162+
163+
def test_TaskManager__task_processor_database__default_database_unaware_of_tasks__consumes_tasks_from_new_database(
    mocker: MockerFixture,
) -> None:
    """Placeholder: scenario not implemented yet, so the test fails explicitly."""
    # A bare `raise` with no active exception only produces an opaque
    # RuntimeError; fail with an explicit marker instead.
    raise NotImplementedError("Test not implemented yet")

0 commit comments

Comments
 (0)