|
1 | 1 | import typing |
2 | 2 |
|
3 | | -from django.conf import settings |
4 | | -from django.db import connections, transaction |
5 | 3 | from django.db.models import Manager |
6 | | -from django.db.utils import ProgrammingError |
7 | | -from django.utils.connection import ConnectionDoesNotExist |
8 | 4 |
|
9 | 5 | if typing.TYPE_CHECKING: |
10 | | - from django.db.models.query import RawQuerySet |
11 | | - |
12 | 6 | from task_processor.models import RecurringTask, Task |
13 | 7 |
|
14 | 8 |
|
15 | 9 | class TaskManager(Manager["Task"]): |
16 | | - def get_tasks_to_process( # noqa: C901 |
| 10 | + def get_tasks_to_process( |
17 | 11 | self, |
| 12 | + database: str, |
18 | 13 | num_tasks: int, |
19 | | - skip_old_database: bool = False, |
20 | | - ) -> typing.Generator["Task", None, None]: |
21 | | - """ |
22 | | - Retrieve tasks to process from the database |
23 | | -
|
24 | | - This does its best effort to retrieve tasks from the old database first |
25 | | - """ |
26 | | - if not skip_old_database: |
27 | | - old_database = "default" if self._is_database_separate else "task_processor" |
28 | | - old_tasks = self._fetch_tasks_from(old_database, num_tasks) |
29 | | - |
30 | | - # Fetch tasks from the previous database |
31 | | - try: |
32 | | - with transaction.atomic(using=old_database): |
33 | | - first_task = next(old_tasks) |
34 | | - except StopIteration: |
35 | | - pass # Empty set |
36 | | - except ProgrammingError: |
37 | | - pass # Function no longer exists in old database |
38 | | - except ConnectionDoesNotExist: |
39 | | - pass # Database not available |
40 | | - else: |
41 | | - yield first_task |
42 | | - num_tasks -= 1 |
43 | | - for task in old_tasks: |
44 | | - yield task |
45 | | - num_tasks -= 1 |
46 | | - |
47 | | - if num_tasks == 0: |
48 | | - return |
49 | | - |
50 | | - new_database = "task_processor" if self._is_database_separate else "default" |
51 | | - new_tasks = self._fetch_tasks_from(new_database, num_tasks) |
52 | | - |
53 | | - # Fetch tasks from the new database |
54 | | - try: |
55 | | - with transaction.atomic(using=new_database): |
56 | | - first_task = next(new_tasks) |
57 | | - except StopIteration: |
58 | | - pass # Empty set |
59 | | - except ProgrammingError: |
60 | | - # Function doesn't exist in the database yet |
61 | | - self._create_or_replace_function__get_tasks_to_process() |
62 | | - yield from self.get_tasks_to_process(num_tasks, skip_old_database=True) |
63 | | - else: |
64 | | - yield first_task |
65 | | - yield from new_tasks |
66 | | - |
67 | | - @property |
68 | | - def _is_database_separate(self) -> bool: |
69 | | - """ |
70 | | - Check whether the task processor database is separate from the default database |
71 | | - """ |
72 | | - return "task_processor" in settings.DATABASES |
73 | | - |
74 | | - def _fetch_tasks_from( |
75 | | - self, database: str, num_tasks: int |
76 | | - ) -> typing.Iterator["Task"]: |
77 | | - """ |
78 | | - Retrieve tasks from the specified Django database |
79 | | - """ |
80 | | - return ( |
81 | | - self.using(database) |
82 | | - .raw("SELECT * FROM get_tasks_to_process(%s)", [num_tasks]) |
83 | | - .iterator() |
| 14 | + ) -> typing.List["Task"]: |
| 15 | + return list( |
| 16 | + self.using(database).raw( |
| 17 | + "SELECT * FROM get_tasks_to_process(%s)", |
| 18 | + [num_tasks], |
| 19 | + ), |
84 | 20 | ) |
85 | 21 |
|
86 | | - def _create_or_replace_function__get_tasks_to_process(self) -> None: |
87 | | - """ |
88 | | - Create or replace the function to get tasks to process. |
89 | | - """ |
90 | | - database = "task_processor" if self._is_database_separate else "default" |
91 | | - with connections[database].cursor() as cursor: |
92 | | - cursor.execute( |
93 | | - """ |
94 | | - CREATE OR REPLACE FUNCTION get_tasks_to_process(num_tasks integer) |
95 | | - RETURNS SETOF task_processor_task AS $$ |
96 | | - DECLARE |
97 | | - row_to_return task_processor_task; |
98 | | - BEGIN |
99 | | - -- Select the tasks that needs to be processed |
100 | | - FOR row_to_return IN |
101 | | - SELECT * |
102 | | - FROM task_processor_task |
103 | | - WHERE num_failures < 3 AND scheduled_for < NOW() AND completed = FALSE AND is_locked = FALSE |
104 | | - ORDER BY priority ASC, scheduled_for ASC, created_at ASC |
105 | | - LIMIT num_tasks |
106 | | - -- Select for update to ensure that no other workers can select these tasks while in this transaction block |
107 | | - FOR UPDATE SKIP LOCKED |
108 | | - LOOP |
109 | | - -- Lock every selected task(by updating `is_locked` to true) |
110 | | - UPDATE task_processor_task |
111 | | - -- Lock this row by setting is_locked True, so that no other workers can select these tasks after this |
112 | | - -- transaction is complete (but the tasks are still being executed by the current worker) |
113 | | - SET is_locked = TRUE |
114 | | - WHERE id = row_to_return.id; |
115 | | - -- If we don't explicitly update the `is_locked` column here, the client will receive the row that is actually locked but has the `is_locked` value set to `False`. |
116 | | - row_to_return.is_locked := TRUE; |
117 | | - RETURN NEXT row_to_return; |
118 | | - END LOOP; |
119 | | -
|
120 | | - RETURN; |
121 | | - END; |
122 | | - $$ LANGUAGE plpgsql |
123 | | - """ |
124 | | - ) |
125 | | - |
126 | 22 |
|
127 | 23 | class RecurringTaskManager(Manager["RecurringTask"]): |
128 | | - def get_tasks_to_process(self) -> "RawQuerySet[RecurringTask]": |
129 | | - return self.raw("SELECT * FROM get_recurringtasks_to_process()") |
| 24 | + def get_tasks_to_process(self, database: str) -> typing.List["RecurringTask"]: |
| 25 | + return list( |
| 26 | + self.using(database).raw("SELECT * FROM get_recurringtasks_to_process()"), |
| 27 | + ) |
0 commit comments