Skip to content

Commit fb4e3f6

Browse files
committed
Pass tasks around directly, test failed tasks
1 parent 399034b commit fb4e3f6

File tree

4 files changed

+46
-35
lines changed

4 files changed

+46
-35
lines changed

src/dbtasks/management/commands/taskrunner.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,18 @@ def add_arguments(self, parser: CommandParser):
4141
default=0.5,
4242
help="Loop delay [default=0.5]",
4343
)
44+
parser.add_argument(
45+
"--no-periodic",
46+
action="store_false",
47+
default=True,
48+
dest="periodic",
49+
)
4450

4551
def handle(self, *args, **options):
4652
Runner(
4753
workers=options["workers"],
4854
worker_id=options["worker_id"],
4955
backend=options["backend"],
5056
loop_delay=options["delay"],
57+
init_periodic=options["periodic"],
5158
).run()

src/dbtasks/runner.py

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,23 @@
1717
task,
1818
task_backends,
1919
)
20+
from django.tasks.signals import task_finished, task_started
2021
from django.utils import timezone
2122
from django.utils.module_loading import import_string
2223

2324
from .backend import DatabaseBackend
2425
from .models import ScheduledTask
2526
from .periodic import Periodic
27+
from .schedule import Duration
2628

2729
logger = logging.getLogger(__name__)
2830

2931

30-
def run_task(pk: str) -> TaskResultStatus:
32+
def run_task(task: ScheduledTask) -> TaskResultStatus:
3133
"""
3234
Fetches, runs, and updates a `ScheduledTask`. Runs in a worker thread.
3335
"""
3436
try:
35-
task = ScheduledTask.objects.get(pk=pk)
3637
logger.info(f"Running {task}")
3738
return task.run_and_update()
3839
finally:
@@ -95,15 +96,7 @@ def __init__(
9596
if retain := self.backend.options.get("retain"):
9697
# If the task backend specifies a retention period, schedule a periodic task
9798
# to delete finished tasks older than that period.
98-
retain_secs = 0
99-
if isinstance(retain, int):
100-
retain_secs = retain
101-
elif isinstance(retain, datetime.timedelta):
102-
retain_secs = int(retain.total_seconds())
103-
else:
104-
raise ImproperlyConfigured(
105-
"Backend `retain` option should be an `int` or `timedelta`"
106-
)
99+
retain_secs = int(Duration(retain).total_seconds())
107100
self.periodic[f"{__name__}.cleanup"] = Periodic(
108101
"~ * * * *", args=[retain_secs]
109102
)
@@ -141,9 +134,7 @@ def get_tasks(self, number: int) -> list[ScheduledTask]:
141134

142135
def task_done(
143136
self,
144-
pk: str,
145-
task_path: str,
146-
was_periodic: bool,
137+
task: ScheduledTask,
147138
fut: concurrent.futures.Future,
148139
):
149140
"""
@@ -153,21 +144,21 @@ def task_done(
153144
"""
154145
with self.lock:
155146
self.processed += 1
156-
del self.tasks[pk]
147+
del self.tasks[task.task_id]
157148

158149
try:
159150
status = fut.result()
160-
logger.info(f"Task {task_path} ({pk}) finished with status {status}")
151+
logger.info(f"Task {task} finished with status {status}")
161152
except Exception as ex:
162-
logger.info(f"Task {task_path} ({pk}) raised {ex}")
153+
logger.info(f"Task {task} raised {ex}")
163154

164-
if was_periodic and (schedule := self.periodic.get(task_path)):
155+
if task.periodic and (schedule := self.periodic.get(task.task_path)):
165156
after = timezone.make_aware(schedule.next())
166157
# Since this can run in the task's thread, we need to clean up the
167158
# connection afterwards since it may not be closed at the end of `run`.
168159
with connection.temporary_connection():
169160
t = ScheduledTask.objects.create(
170-
task_path=task_path,
161+
task_path=task.task_path,
171162
args=schedule.args,
172163
kwargs=schedule.kwargs,
173164
backend=self.backend.alias,
@@ -176,34 +167,33 @@ def task_done(
176167
)
177168
logger.info(f"Re-scheduled {t} for {after}")
178169

170+
task_finished.send(type(self.backend), task_result=task.result)
171+
179172
# If anyone is waiting on this task, wake them up.
180-
if event := self.waiting.get(pk):
173+
if event := self.waiting.get(task.task_id):
181174
event.set()
182175

183176
def submit_task(self, task: ScheduledTask, start: bool = True) -> TaskResult:
184177
"""
185178
Submits a `ScheduledTask` for execution, marking it as RUNNING and setting its
186179
`started_at` timestamp if `start=True`.
180+
181+
Note that `task` is passed directly to a separate thread, so callers should not
182+
modify it until after the task is complete.
187183
"""
188184
if start:
189185
task.status = TaskResultStatus.RUNNING
190186
task.started_at = timezone.now()
191187
task.worker_ids.append(self.worker_id)
192188
task.save(update_fields=["status", "started_at", "worker_ids"])
193189
logger.debug(f"Submitting {task} for execution")
194-
f = self.executor.submit(run_task, task.task_id)
190+
task_started.send(type(self.backend), task_result=task.result)
191+
f = self.executor.submit(run_task, task)
195192
with self.lock:
196193
# Keep track of task modules we've seen, so we can reload them.
197194
self.seen_modules.add(task.task_path.rsplit(".", 1)[0])
198195
self.tasks[task.task_id] = f
199-
f.add_done_callback(
200-
functools.partial(
201-
self.task_done,
202-
task.task_id,
203-
task.task_path,
204-
task.periodic,
205-
),
206-
)
196+
f.add_done_callback(functools.partial(self.task_done, task))
207197
return task.result
208198

209199
def schedule_tasks(self) -> float:
@@ -242,13 +232,12 @@ def init_periodic(self):
242232
Removes any outstanding scheduled periodic tasks, and schedules the next runs
243233
for each.
244234
"""
245-
# First delete any un-started periodic tasks.
246235
ScheduledTask.objects.filter(
247236
status=TaskResultStatus.READY,
248237
periodic=True,
249238
).delete()
250-
# Then schedule the next run of each periodic task. Subsequent runs will be
251-
# scheduled on completion.
239+
# Schedule the next run of each periodic task. Subsequent runs will be scheduled
240+
# on completion.
252241
for task_path, schedule in self.periodic.items():
253242
after = timezone.make_aware(schedule.next())
254243
t = ScheduledTask.objects.create(
@@ -267,6 +256,7 @@ def run(self):
267256
"""
268257
logger.info(f"Starting task runner with {self.workers} workers")
269258
self.processed = 0
259+
self.stopsign.clear()
270260
if self.should_init_periodic:
271261
with transaction.atomic(durable=True):
272262
self.init_periodic()
@@ -282,6 +272,7 @@ def run(self):
282272
finally:
283273
self.executor.shutdown()
284274
connection.close()
275+
self.ready.clear()
285276
self.finished.set()
286277

287278
def wait_for(self, result: TaskResult, timeout: float | None = None) -> bool:
@@ -290,7 +281,7 @@ def wait_for(self, result: TaskResult, timeout: float | None = None) -> bool:
290281
"""
291282
if result.status in (TaskResultStatus.SUCCESSFUL, TaskResultStatus.FAILED):
292283
return True
293-
logger.info(f"Waiting for {result.id}...")
284+
logger.debug(f"Waiting for {result.id}...")
294285
event = threading.Event()
295286
self.waiting[result.id] = event
296287
success = event.wait(timeout)

tests/tasks.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,8 @@ def send_mail(to: str, message: str):
1414
@task
1515
def maintenance():
1616
pass
17+
18+
19+
@task
20+
def kaboom(msg: str):
21+
raise ValueError(msg)

tests/tests/test_tasks.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from dbtasks.models import ScheduledTask
44

5-
from ..tasks import maintenance, send_mail
5+
from ..tasks import kaboom, maintenance, send_mail
66
from ..utils import LoggedRunnerTestCase
77

88

@@ -23,7 +23,6 @@ def test_lots_of_tasks(self):
2323
send_mail.enqueue(f"user-{k}@example.com", "hello!")
2424
expected.add(f"Sending mail to user-{k}@example.com: hello!")
2525
self.runner.wait()
26-
self.assertEqual(self.runner.processed, 100)
2726
self.assertEqual(set(self.task_logs["INFO"]), expected)
2827

2928
def test_periodic(self):
@@ -63,3 +62,12 @@ def test_periodic(self):
6362
)
6463
second.refresh_from_db()
6564
self.assertEqual(second.status, TaskResultStatus.READY)
65+
66+
def test_failed_task(self):
67+
result: TaskResult = kaboom.enqueue("Boom goes the dynamite!")
68+
self.assertTrue(self.runner.wait_for(result))
69+
self.assertEqual(result.status, TaskResultStatus.FAILED)
70+
with self.assertRaises(ValueError):
71+
result.return_value
72+
self.assertEqual(len(result.errors), 1)
73+
self.assertEqual(result.errors[0].exception_class_path, "builtins.ValueError")

0 commit comments

Comments
 (0)