Skip to content

Commit aff50b2

Browse files
committed
feat: Support task backoff
1 parent b6a250b commit aff50b2

File tree

5 files changed

+58
-16
lines changed

5 files changed

+58
-16
lines changed

settings/dev.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
ENABLE_CLEAN_UP_OLD_TASKS = True
5959
ENABLE_TASK_PROCESSOR_HEALTH_CHECK = True
6060
RECURRING_TASK_RUN_RETENTION_DAYS = 15
61+
TASK_BACKOFF_DEFAULT_SECONDS = 5
6162
TASK_DELETE_BATCH_SIZE = 2000
6263
TASK_DELETE_INCLUDE_FAILED_TASKS = False
6364
TASK_DELETE_RETENTION_DAYS = 15

src/task_processor/exceptions.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from datetime import datetime
2+
3+
14
class TaskProcessingError(Exception):
25
pass
36

@@ -6,5 +9,21 @@ class InvalidArgumentsError(TaskProcessingError):
69
pass
710

811

12+
class TaskBackoffError(TaskProcessingError):
13+
"""
14+
Raise this exception inside a task to indicate that it should be retried after a delay.
15+
This is typically used when a task fails due to a temporary issue, such as
16+
a network error or a service being unavailable.
17+
"""
18+
19+
def __init__(
20+
self,
21+
delay_until: datetime | None = None,
22+
) -> None:
23+
super().__init__()
24+
delay_until = delay_until
25+
self.delay_until = delay_until
26+
27+
928
class TaskQueueFullError(Exception):
1029
pass

src/task_processor/models.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from django.db import models
88
from django.utils import timezone
99

10-
from task_processor.exceptions import TaskProcessingError, TaskQueueFullError
10+
from task_processor.exceptions import TaskQueueFullError
1111
from task_processor.managers import RecurringTaskManager, TaskManager
12-
from task_processor.task_registry import registered_tasks
12+
from task_processor.task_registry import get_task, registered_tasks
1313
from task_processor.types import TaskCallable
1414

1515
_django_json_encoder_default = DjangoJSONEncoder().default
@@ -75,15 +75,8 @@ def run(self) -> None:
7575

7676
@property
7777
def callable(self) -> TaskCallable[typing.Any]:
78-
try:
79-
task = registered_tasks[self.task_identifier]
80-
return task.task_function
81-
except KeyError as e:
82-
raise TaskProcessingError(
83-
"No task registered with identifier '%s'. Ensure your task is "
84-
"decorated with @register_task_handler.",
85-
self.task_identifier,
86-
) from e
78+
task = get_task(self.task_identifier)
79+
return task.task_handler.unwrapped
8780

8881

8982
class Task(AbstractBaseTask):

src/task_processor/processor.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from contextlib import ExitStack
66
from datetime import timedelta
77

8+
from dateutil.relativedelta import relativedelta
89
from django.conf import settings
910
from django.utils import timezone
1011

1112
from task_processor import metrics
13+
from task_processor.exceptions import TaskBackoffError
1214
from task_processor.managers import RecurringTaskManager, TaskManager
1315
from task_processor.models import (
1416
AbstractBaseTask,
@@ -120,6 +122,7 @@ def _run_task(
120122
ctx.enter_context(timer)
121123

122124
task_identifier = task.task_identifier
125+
registered_task = get_task(task_identifier)
123126

124127
logger.debug(
125128
f"Running task {task_identifier} id={task.pk} args={task.args} kwargs={task.kwargs}"
@@ -157,9 +160,24 @@ def _run_task(
157160
exc_info=True,
158161
)
159162

163+
if isinstance(e, TaskBackoffError):
164+
delay_until = e.delay_until or timezone.now() + relativedelta(
165+
seconds=settings.TASK_BACKOFF_DEFAULT_SECONDS,
166+
)
167+
registered_task.task_handler.delay(
168+
delay_until=delay_until,
169+
args=task.args,
170+
kwargs=task.kwargs,
171+
)
172+
logger.info(
173+
"Backoff requested. Task '%s' set to retry at %s",
174+
task_identifier,
175+
delay_until,
176+
)
177+
160178
labels = {
161179
"task_identifier": task_identifier,
162-
"task_type": get_task(task_identifier).task_type.value.lower(),
180+
"task_type": registered_task.task_type.value.lower(),
163181
"result": result.lower(),
164182
}
165183

src/task_processor/task_registry.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
import typing
44
from dataclasses import dataclass
55

6+
from task_processor.exceptions import TaskProcessingError
67
from task_processor.types import TaskCallable
78

9+
if typing.TYPE_CHECKING:
10+
from task_processor.decorators import TaskHandler # noqa: F401
11+
812
logger = logging.getLogger(__name__)
913

1014

@@ -16,7 +20,7 @@ class TaskType(enum.Enum):
1620
@dataclass
1721
class RegisteredTask:
1822
task_identifier: str
19-
task_function: TaskCallable[typing.Any]
23+
task_handler: "TaskHandler"
2024
task_type: TaskType = TaskType.STANDARD
2125
task_kwargs: dict[str, typing.Any] | None = None
2226

@@ -43,18 +47,25 @@ def initialise() -> None:
4347
def get_task(task_identifier: str) -> RegisteredTask:
4448
global registered_tasks
4549

46-
return registered_tasks[task_identifier]
50+
try:
51+
return registered_tasks[task_identifier]
52+
except KeyError:
53+
raise TaskProcessingError(
54+
"No task registered with identifier '%s'. Ensure your task is "
55+
"decorated with @register_task_handler.",
56+
task_identifier,
57+
)
4758

4859

4960
def register_task(
5061
task_identifier: str,
51-
callable_: TaskCallable[typing.Any],
62+
task_handler: "TaskHandler",
5263
) -> None:
5364
global registered_tasks
5465

5566
registered_task = RegisteredTask(
5667
task_identifier=task_identifier,
57-
task_function=callable_,
68+
task_handler=task_handler,
5869
)
5970
registered_tasks[task_identifier] = registered_task
6071

0 commit comments

Comments
 (0)