Skip to content

Commit 399034b

Browse files
committed
Periodic task tests, ready/finished runner events
1 parent 05164a6 commit 399034b

File tree

7 files changed

+129
-42
lines changed

7 files changed

+129
-42
lines changed

src/dbtasks/periodic.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
1-
from datetime import datetime
1+
from datetime import datetime, timedelta
22
from typing import Callable
33

4-
from .schedule import Crontab, Schedule
4+
from .schedule import Crontab, Every, Schedule
55

66

77
class Periodic:
8+
schedule: Schedule
9+
810
def __init__(
911
self,
10-
spec: Schedule | str,
12+
spec: Schedule | str | int | timedelta,
1113
args: list | tuple | Callable[[], list | tuple] | None = None,
1214
kwargs: dict | Callable[[], dict] | None = None,
1315
):
14-
self.schedule = spec if isinstance(spec, Schedule) else Crontab(spec)
16+
if isinstance(spec, Schedule):
17+
self.schedule = spec
18+
elif isinstance(spec, str) and len(spec.split()) == 5:
19+
self.schedule = Crontab(spec)
20+
else:
21+
self.schedule = Every(spec)
1522
self._args = args
1623
self._kwargs = kwargs
1724

src/dbtasks/runner.py

Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def run_task(pk: str) -> TaskResultStatus:
4343
# Seems like this might be a welcome addition:
4444
# https://discuss.python.org/t/adding-finalizer-to-the-threading-library/54186
4545
connection.close()
46-
pass
4746

4847

4948
@task
@@ -63,6 +62,7 @@ def __init__(
6362
worker_id: str | None = None,
6463
backend: str = DEFAULT_TASK_BACKEND_ALIAS,
6564
loop_delay: float = 0.5,
65+
init_periodic: bool = True,
6666
):
6767
self.workers = workers
6868
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers)
@@ -78,15 +78,20 @@ def __init__(
7878
self.backend = task_backends[backend]
7979
if not isinstance(self.backend, DatabaseBackend):
8080
raise ImproperlyConfigured("Backend must be a `DatabaseBackend`")
81+
# Signaled when the runner is ready and processing tasks.
82+
self.ready = threading.Event()
8183
# Signaled when the runner should stop.
8284
self.stopsign = threading.Event()
83-
# Signaled when the queue is empty (no READY tasks).
85+
# Signaled when the runner is finished stopping.
86+
self.finished = threading.Event()
87+
# Signaled each time the queue is empty (no READY tasks).
8488
self.empty = threading.Event()
8589
# Covers `self.tasks`, `self.seen_modules`, and `self.processed` access.
8690
self.lock = threading.Lock()
8791
# Allows callers to block on a single task being completed.
8892
self.waiting: dict[str, threading.Event] = {}
8993
self.periodic: dict[str, Periodic] = {}
94+
self.should_init_periodic = init_periodic
9095
if retain := self.backend.options.get("retain"):
9196
# If the task backend specifies a retention period, schedule a periodic task
9297
# to delete finished tasks older than that period.
@@ -156,20 +161,50 @@ def task_done(
156161
except Exception as ex:
157162
logger.info(f"Task {task_path} ({pk}) raised {ex}")
158163

164+
if was_periodic and (schedule := self.periodic.get(task_path)):
165+
after = timezone.make_aware(schedule.next())
166+
# Since this can run in the task's thread, we need to clean up the
167+
# connection afterwards since it may not be closed at the end of `run`.
168+
with connection.temporary_connection():
169+
t = ScheduledTask.objects.create(
170+
task_path=task_path,
171+
args=schedule.args,
172+
kwargs=schedule.kwargs,
173+
backend=self.backend.alias,
174+
run_after=after,
175+
periodic=True,
176+
)
177+
logger.info(f"Re-scheduled {t} for {after}")
178+
179+
# If anyone is waiting on this task, wake them up.
159180
if event := self.waiting.get(pk):
160181
event.set()
161182

162-
if was_periodic and (schedule := self.periodic.get(task_path)):
163-
after = timezone.make_aware(schedule.next())
164-
t = ScheduledTask.objects.create(
165-
task_path=task_path,
166-
args=schedule.args,
167-
kwargs=schedule.kwargs,
168-
backend=self.backend.alias,
169-
run_after=after,
170-
periodic=True,
171-
)
172-
logger.info(f"Re-scheduled {t} for {after}")
183+
def submit_task(self, task: ScheduledTask, start: bool = True) -> TaskResult:
184+
"""
185+
Submits a `ScheduledTask` for execution, marking it as RUNNING and setting its
186+
`started_at` timestamp if `start=True`.
187+
"""
188+
if start:
189+
task.status = TaskResultStatus.RUNNING
190+
task.started_at = timezone.now()
191+
task.worker_ids.append(self.worker_id)
192+
task.save(update_fields=["status", "started_at", "worker_ids"])
193+
logger.debug(f"Submitting {task} for execution")
194+
f = self.executor.submit(run_task, task.task_id)
195+
with self.lock:
196+
# Keep track of task modules we've seen, so we can reload them.
197+
self.seen_modules.add(task.task_path.rsplit(".", 1)[0])
198+
self.tasks[task.task_id] = f
199+
f.add_done_callback(
200+
functools.partial(
201+
self.task_done,
202+
task.task_id,
203+
task.task_path,
204+
task.periodic,
205+
),
206+
)
207+
return task.result
173208

174209
def schedule_tasks(self) -> float:
175210
"""
@@ -193,20 +228,8 @@ def schedule_tasks(self) -> float:
193228
self.empty.clear()
194229

195230
for t in tasks:
196-
logger.debug(f"Submitting {t} for execution")
197-
f = self.executor.submit(run_task, t.task_id)
198-
with self.lock:
199-
# Keep track of task modules we've seen, so we can reload them.
200-
self.seen_modules.add(t.task_path.rsplit(".", 1)[0])
201-
self.tasks[t.task_id] = f
202-
f.add_done_callback(
203-
functools.partial(
204-
self.task_done,
205-
t.task_id,
206-
t.task_path,
207-
t.periodic,
208-
),
209-
)
231+
# get_tasks starts all of the returned tasks atomically, no need to here.
232+
self.submit_task(t, start=False)
210233

211234
if len(tasks) >= available:
212235
# We got a full batch, try again immediately.
@@ -244,7 +267,12 @@ def run(self):
244267
"""
245268
logger.info(f"Starting task runner with {self.workers} workers")
246269
self.processed = 0
247-
self.init_periodic()
270+
if self.should_init_periodic:
271+
with transaction.atomic(durable=True):
272+
self.init_periodic()
273+
transaction.on_commit(self.ready.set)
274+
else:
275+
self.ready.set()
248276
try:
249277
while not self.stopsign.is_set():
250278
delay = self.schedule_tasks()
@@ -254,6 +282,7 @@ def run(self):
254282
finally:
255283
self.executor.shutdown()
256284
connection.close()
285+
self.finished.set()
257286

258287
def wait_for(self, result: TaskResult, timeout: float | None = None) -> bool:
259288
"""

src/dbtasks/testing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,18 @@ def setUpClass(cls):
1515
super().setUpClass()
1616
# Run with a very short loop delay to speed up tests. None of our test tasks
1717
# take very long, so there's not much point in waiting aside from not flooding
18-
# the database with queries for new tasks.
19-
cls.runner = Runner(workers=os.cpu_count() - 1, loop_delay=0.01)
20-
cls.runner_thread = threading.Thread(target=cls.runner.run)
21-
cls.runner_thread.start()
18+
# the database with queries for new tasks. Also we don't initialize periodic
19+
# tasks - individual tests can call `self.runner.init_periodic()`.
20+
cls.runner = Runner(
21+
workers=max(1, (os.cpu_count() or 4) - 1),
22+
loop_delay=0.01,
23+
init_periodic=False,
24+
)
25+
threading.Thread(target=cls.runner.run).start()
26+
cls.runner.ready.wait()
2227

2328
@classmethod
2429
def tearDownClass(cls):
2530
super().tearDownClass()
2631
cls.runner.stop()
27-
cls.runner_thread.join()
32+
cls.runner.finished.wait()

tests/settings.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import datetime
21
import os
32

43
INSTALLED_APPS = ["dbtasks", "tests"]
@@ -45,8 +44,9 @@
4544
"BACKEND": "dbtasks.backend.DatabaseBackend",
4645
"OPTIONS": {
4746
"immediate": False,
48-
"retain": datetime.timedelta(days=7),
49-
"periodic": {},
47+
"periodic": {
48+
"tests.tasks.maintenance": "1h",
49+
},
5050
},
5151
},
5252
}

tests/tasks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,8 @@
99
def send_mail(to: str, message: str):
1010
logger.info(f"Sending mail to {to}: {message}")
1111
return {"sent": True}
12+
13+
14+
@task
15+
def maintenance():
16+
pass

tests/tests/test_tasks.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from django.tasks import TaskResult, TaskResultStatus
22

3-
from ..tasks import send_mail
3+
from dbtasks.models import ScheduledTask
4+
5+
from ..tasks import maintenance, send_mail
46
from ..utils import LoggedRunnerTestCase
57

68

@@ -23,3 +25,41 @@ def test_lots_of_tasks(self):
2325
self.runner.wait()
2426
self.assertEqual(self.runner.processed, 100)
2527
self.assertEqual(set(self.task_logs["INFO"]), expected)
28+
29+
def test_periodic(self):
30+
self.assertIn(maintenance.module_path, self.runner.periodic)
31+
self.runner.init_periodic()
32+
first = ScheduledTask.objects.get(
33+
task_path=maintenance.module_path,
34+
status=TaskResultStatus.READY,
35+
periodic=True,
36+
)
37+
# Run the initial scheduled task manually.
38+
result = self.runner.submit_task(first)
39+
self.assertTrue(self.runner.wait_for(result))
40+
first.refresh_from_db()
41+
self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)
42+
self.assertEqual(first.status, TaskResultStatus.SUCCESSFUL)
43+
# Make sure a second periodic task got scheduled when the first completed.
44+
second = ScheduledTask.objects.get(
45+
task_path=maintenance.module_path,
46+
status=TaskResultStatus.READY,
47+
periodic=True,
48+
)
49+
self.assertNotEqual(first.id, second.id)
50+
self.assertGreater(second.enqueued_at, first.enqueued_at)
51+
# Now run the maintenance task manually, not periodically.
52+
result: TaskResult = maintenance.enqueue()
53+
self.assertTrue(self.runner.wait_for(result))
54+
manual = ScheduledTask.objects.get(pk=result.id)
55+
# Make sure the manual run was not marked periodic, and that no new tasks were
56+
# automatically scheduled afterwards.
57+
self.assertFalse(manual.periodic)
58+
self.assertEqual(
59+
ScheduledTask.objects.filter(
60+
task_path=maintenance.module_path, status=TaskResultStatus.READY
61+
).count(),
62+
1,
63+
)
64+
second.refresh_from_db()
65+
self.assertEqual(second.status, TaskResultStatus.READY)

tests/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ class LoggedRunnerTestCase(RunnerTestCase):
77
def setUp(self):
88
with stashlock:
99
stash.clear()
10+
super().setUp()
1011

1112
def tearDown(self):
12-
pass
13+
super().tearDown()
1314

1415
@property
1516
def logs(self) -> dict:

0 commit comments

Comments
 (0)