Skip to content

Commit ec6c2e5

Browse files
committed
WIP: Fix managing tasks when in multi-database mode
1 parent 1f95cc0 commit ec6c2e5

File tree

9 files changed

+228
-37
lines changed

9 files changed

+228
-37
lines changed

docker/docker-compose.local.yml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ name: flagsmith
44

55
volumes:
66
pg_data:
7+
task_processor_pg_data:
78

89
services:
910
db:
1011
image: postgres:15.5-alpine
11-
pull_policy: always
1212
restart: unless-stopped
1313
volumes:
1414
- pg_data:/var/lib/postgresql/data
@@ -17,3 +17,21 @@ services:
1717
environment:
1818
POSTGRES_DB: flagsmith
1919
POSTGRES_PASSWORD: password
20+
healthcheck:
21+
test: pg_isready -Upostgres
22+
interval: 1s
23+
timeout: 30s
24+
25+
task-processor-db:
26+
image: postgres:15.5-alpine
27+
restart: unless-stopped
28+
volumes:
29+
- task_processor_pg_data:/var/lib/postgresql/data
30+
ports:
31+
- 5433:5432
32+
environment:
33+
POSTGRES_HOST_AUTH_METHOD: trust
34+
healthcheck:
35+
test: pg_isready -Upostgres
36+
interval: 1s
37+
timeout: 30s

settings/dev.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,17 @@
2929
env(
3030
"DATABASE_URL",
3131
default="postgresql://postgres:password@localhost:5432/flagsmith",
32-
)
33-
)
32+
),
33+
),
34+
"task_processor": dj_database_url.parse(
35+
env(
36+
"TASK_PROCESSOR_DATABASE_URL",
37+
default="postgresql://postgres@localhost:5433/postgres",
38+
),
39+
),
3440
}
41+
DATABASE_ROUTERS = ["task_processor.routers.TaskProcessorRouter"]
42+
TASK_PROCESSOR_DATABASES = ["default"]
3543
INSTALLED_APPS = [
3644
"django.contrib.auth",
3745
"django.contrib.contenttypes",
@@ -62,5 +70,3 @@
6270

6371
# Avoid models.W042 warnings
6472
DEFAULT_AUTO_FIELD = "django.db.models.AutoField"
65-
66-
TASK_PROCESSOR_DATABASES = ["default"]

src/task_processor/managers.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,17 @@
77

88

99
class TaskManager(Manager["Task"]):
10-
def get_tasks_to_process(
11-
self,
12-
database: str,
13-
num_tasks: int,
14-
) -> typing.List["Task"]:
10+
def get_tasks_to_process(self, num_tasks: int) -> typing.List["Task"]:
1511
return list(
16-
self.using(database).raw(
12+
self.raw(
1713
"SELECT * FROM get_tasks_to_process(%s)",
1814
[num_tasks],
1915
),
2016
)
2117

2218

2319
class RecurringTaskManager(Manager["RecurringTask"]):
24-
def get_tasks_to_process(self, database: str) -> typing.List["RecurringTask"]:
20+
def get_tasks_to_process(self) -> typing.List["RecurringTask"]:
2521
return list(
26-
self.using(database).raw("SELECT * FROM get_recurringtasks_to_process()"),
22+
self.raw("SELECT * FROM get_recurringtasks_to_process()"),
2723
)

src/task_processor/processor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from django.utils import timezone
1010

1111
from task_processor import metrics
12+
from task_processor.managers import TaskManager
1213
from task_processor.models import (
1314
AbstractBaseTask,
1415
RecurringTask,
@@ -31,7 +32,8 @@ def run_tasks(database: str, num_tasks: int = 1) -> list[TaskRun]:
3132
if num_tasks < 1:
3233
raise ValueError("Number of tasks to process must be at least one")
3334

34-
tasks = Task.objects.get_tasks_to_process(database, num_tasks)
35+
task_manager: TaskManager = Task.objects.db_manager(database)
36+
tasks = task_manager.get_tasks_to_process(num_tasks)
3537
if tasks:
3638
logger.debug(f"Running {len(tasks)} task(s) from database '{database}'")
3739

@@ -46,13 +48,13 @@ def run_tasks(database: str, num_tasks: int = 1) -> list[TaskRun]:
4648
task_runs.append(task_run)
4749

4850
if executed_tasks:
49-
Task.objects.bulk_update(
51+
Task.objects.using(database).bulk_update(
5052
executed_tasks,
5153
fields=["completed", "num_failures", "is_locked"],
5254
)
5355

5456
if task_runs:
55-
TaskRun.objects.bulk_create(task_runs)
57+
TaskRun.objects.using(database).bulk_create(task_runs)
5658
logger.debug(
5759
f"Finished running {len(task_runs)} task(s) from database '{database}'"
5860
)
@@ -66,7 +68,7 @@ def run_recurring_tasks(database: str) -> list[RecurringTaskRun]:
6668
# NOTE: We will probably see a lot of delay in the execution of recurring tasks
6769
# if the tasks take longer then `run_every` to execute. This is not
6870
# a problem for now, but we should be mindful of this limitation
69-
tasks = RecurringTask.objects.get_tasks_to_process(database)
71+
tasks = RecurringTask.objects.db_manager(database).get_tasks_to_process()
7072
if tasks:
7173
logger.debug(f"Running {len(tasks)} recurring task(s)")
7274

src/task_processor/routers.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from django.conf import settings
2+
from django.db.models import Model
3+
4+
5+
class TaskProcessorRouter:
6+
"""
7+
Routing of database operations for task processor models
8+
"""
9+
10+
route_app_labels = ["task_processor"]
11+
12+
@property
13+
def is_enabled(self) -> bool:
14+
return "task_processor" in settings.TASK_PROCESSOR_DATABASES
15+
16+
def db_for_read(self, model: type[Model], **hints: None) -> str | None:
17+
"""
18+
If enabled, route "task_processor" models to a separate database
19+
"""
20+
if not self.is_enabled:
21+
return None
22+
23+
if model._meta.app_label in self.route_app_labels:
24+
return "task_processor"
25+
26+
return None
27+
28+
def db_for_write(self, model: type[Model], **hints: None) -> str | None:
29+
"""
30+
Attempts to write task processor models go to 'task_processor' database.
31+
"""
32+
if not self.is_enabled:
33+
return None
34+
35+
if model._meta.app_label in self.route_app_labels:
36+
return "task_processor"
37+
38+
return None
39+
40+
def allow_relation(self, obj1: Model, obj2: Model, **hints: None) -> bool | None:
41+
"""
42+
Relations between objects are allowed if both objects are
43+
in the task processor database.
44+
"""
45+
if not self.is_enabled:
46+
return None
47+
48+
both_objects_from_task_processor = (
49+
obj1._meta.app_label in self.route_app_labels
50+
and obj2._meta.app_label in self.route_app_labels
51+
)
52+
53+
if both_objects_from_task_processor:
54+
return True
55+
56+
return None
57+
58+
def allow_migrate(
59+
self,
60+
db: str,
61+
app_label: str,
62+
**hints: None,
63+
) -> bool | None:
64+
"""
65+
Allow migrations for task processor models to run in both databases
66+
67+
NOTE: Even if, from a fresh install, the task processor tables are not
68+
required in both databases, this is required to allow for easier
69+
transition between a single database and a multi-database setup.
70+
"""
71+
if not self.is_enabled:
72+
return None
73+
74+
if app_label in self.route_app_labels:
75+
return db in ["default", "task_processor"]
76+
77+
return None

src/task_processor/threads.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,18 @@ def run_iteration(self) -> None:
104104
separate database setup.
105105
"""
106106
database_is_separate = "task_processor" in settings.TASK_PROCESSOR_DATABASES
107+
if database_is_separate:
108+
assert (
109+
"task_processor.routers.TaskProcessorRouter"
110+
in settings.DATABASE_ROUTERS
111+
), (
112+
"DATABASE_ROUTERS must include 'task_processor.routers.TaskProcessorRouter' "
113+
"when using a separate task processor database."
114+
) # This is for our own sanity
115+
assert "task_processor" in settings.DATABASES, (
116+
"DATABASES must include 'task_processor' when using a separate task processor database."
117+
) # ¯\_(ツ)_/¯ One has to read the documentation and fix it: https://docs.flagsmith.com/deployment/configuration/task-processor
118+
107119
for database in settings.TASK_PROCESSOR_DATABASES:
108120
try:
109121
run_tasks(database, self.queue_pop_size)

tests/unit/task_processor/test_unit_task_processor_processor.py

Lines changed: 47 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import time
3-
import typing
43
import uuid
54
from datetime import timedelta
65
from threading import Thread
@@ -9,6 +8,7 @@
98
from django.core.cache import cache
109
from django.utils import timezone
1110
from freezegun import freeze_time
11+
from pytest_django.fixtures import SettingsWrapper
1212
from pytest_mock import MockerFixture
1313

1414
from common.test_tools import AssertMetricFixture
@@ -37,8 +37,7 @@
3737

3838

3939
@pytest.fixture(autouse=True)
40-
def reset_cache() -> typing.Generator[None, None, None]:
41-
yield
40+
def reset_cache() -> None:
4241
cache.clear()
4342

4443

@@ -73,38 +72,49 @@ def _sleep_task(seconds: int) -> None:
7372
return _sleep_task
7473

7574

75+
@pytest.mark.django_db(databases=["default", "task_processor"])
7676
@pytest.mark.task_processor_mode
77+
@pytest.mark.parametrize(
78+
"database",
79+
["default", "task_processor"],
80+
)
7781
def test_run_task_runs_task_and_creates_task_run_object_when_success(
82+
database: str,
7883
dummy_task: TaskHandler[[str, str]],
7984
) -> None:
8085
# Given
81-
task = Task.create(
82-
dummy_task.task_identifier,
83-
scheduled_for=timezone.now(),
84-
)
85-
task.save()
86+
task = Task.create(dummy_task.task_identifier, scheduled_for=timezone.now())
87+
task.save(using=database)
8688

8789
# When
88-
task_runs = run_tasks("default")
90+
task_runs = run_tasks(database)
8991

9092
# Then
9193
assert cache.get(DEFAULT_CACHE_KEY)
9294

93-
assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1
95+
assert (
96+
len(task_runs) == TaskRun.objects.using(database).filter(task=task).count() == 1
97+
)
9498
task_run = task_runs[0]
9599
assert task_run.result == TaskResult.SUCCESS.value
96100
assert task_run.started_at
97101
assert task_run.finished_at
98102
assert task_run.error_details is None
99103

100-
task.refresh_from_db()
104+
task.refresh_from_db(using=database)
101105
assert task.completed
102106

103107

108+
@pytest.mark.django_db(databases=["default", "task_processor"])
104109
@pytest.mark.task_processor_mode
110+
@pytest.mark.parametrize(
111+
"database",
112+
["default", "task_processor"],
113+
)
105114
def test_run_task_kills_task_after_timeout(
106-
sleep_task: TaskHandler[[int]],
107115
caplog: pytest.LogCaptureFixture,
116+
database: str,
117+
sleep_task: TaskHandler[[int]],
108118
) -> None:
109119
# Given
110120
task = Task.create(
@@ -113,21 +123,23 @@ def test_run_task_kills_task_after_timeout(
113123
args=(1,),
114124
timeout=timedelta(microseconds=1),
115125
)
116-
task.save()
126+
task.save(using=database)
117127

118128
# When
119-
task_runs = run_tasks("default")
129+
task_runs = run_tasks(database)
120130

121131
# Then
122-
assert len(task_runs) == TaskRun.objects.filter(task=task).count() == 1
132+
assert (
133+
len(task_runs) == TaskRun.objects.using(database).filter(task=task).count() == 1
134+
)
123135
task_run = task_runs[0]
124136
assert task_run.result == TaskResult.FAILURE.value
125137
assert task_run.started_at
126138
assert task_run.finished_at is None
127139
assert task_run.error_details
128140
assert "TimeoutError" in task_run.error_details
129141

130-
task.refresh_from_db()
142+
task.refresh_from_db(using=database)
131143

132144
assert task.completed is False
133145
assert task.num_failures == 1
@@ -139,12 +151,20 @@ def test_run_task_kills_task_after_timeout(
139151
)
140152

141153

142-
@pytest.mark.django_db
154+
@pytest.mark.django_db(databases=["default", "task_processor"])
143155
@pytest.mark.task_processor_mode
156+
@pytest.mark.parametrize(
157+
"database",
158+
["default", "task_processor"],
159+
)
144160
def test_run_recurring_task_kills_task_after_timeout(
145161
caplog: pytest.LogCaptureFixture,
162+
database: str,
163+
settings: SettingsWrapper,
146164
) -> None:
147165
# Given
166+
settings.TASK_PROCESSOR_DATABASES = [database]
167+
148168
@register_recurring_task(
149169
run_every=timedelta(seconds=1), timeout=timedelta(microseconds=1)
150170
)
@@ -157,18 +177,22 @@ def _dummy_recurring_task() -> None:
157177
task_identifier="test_unit_task_processor_processor._dummy_recurring_task",
158178
)
159179
# When
160-
task_runs = run_recurring_tasks("default")
180+
task_runs = run_recurring_tasks(database)
161181

162182
# Then
163-
assert len(task_runs) == RecurringTaskRun.objects.filter(task=task).count() == 1
183+
assert (
184+
len(task_runs)
185+
== RecurringTaskRun.objects.using(database).filter(task=task).count()
186+
== 1
187+
)
164188
task_run = task_runs[0]
165189
assert task_run.result == TaskResult.FAILURE.value
166190
assert task_run.started_at
167191
assert task_run.finished_at is None
168192
assert task_run.error_details
169193
assert "TimeoutError" in task_run.error_details
170194

171-
task.refresh_from_db()
195+
task.refresh_from_db(using=database)
172196

173197
assert task.locked_at is None
174198
assert task.is_locked is False
@@ -179,6 +203,9 @@ def _dummy_recurring_task() -> None:
179203
)
180204

181205

206+
# TODO: Need to parametrize all/most tests below to run on both databases
207+
208+
182209
@pytest.mark.django_db
183210
@pytest.mark.task_processor_mode
184211
def test_run_recurring_tasks_runs_task_and_creates_recurring_task_run_object_when_success() -> (
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# TODO

0 commit comments

Comments
 (0)