Skip to content

Commit 9de8df5

Browse files
committed
exclude backed off tasks from retries, revert registry changes
1 parent b289839 commit 9de8df5

File tree

5 files changed

+11
-27
lines changed

5 files changed

+11
-27
lines changed

src/task_processor/decorators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949
f,
5050
task_name,
5151
)
52-
task_registry.register_task(task_identifier, self)
52+
task_registry.register_task(task_identifier, f)
5353

5454
def __call__(
5555
self,

src/task_processor/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def run(self) -> None:
7474
@property
7575
def callable(self) -> TaskCallable[typing.Any]:
7676
task = get_task(self.task_identifier)
77-
return task.task_callable
77+
return task.task_function
7878

7979

8080
class Task(AbstractBaseTask):

src/task_processor/processor.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def run_tasks(database: str, num_tasks: int = 1) -> list[TaskRun]:
5151
if executed_tasks:
5252
Task.objects.using(database).bulk_update(
5353
executed_tasks,
54-
fields=["completed", "num_failures", "is_locked"],
54+
fields=["completed", "num_failures", "is_locked", "scheduled_for"],
5555
)
5656

5757
if task_runs:
@@ -165,18 +165,10 @@ def _run_task(
165165
)
166166
if typing.TYPE_CHECKING:
167167
assert isinstance(task, Task)
168-
assert registered_task.task_handler
169168
delay_until = e.delay_until or timezone.now() + timedelta(
170169
seconds=settings.TASK_BACKOFF_DEFAULT_DELAY_SECONDS,
171170
)
172-
assert registered_task.task_handler, (
173-
"Attempt to back off a recurring task (currently not supported)"
174-
)
175-
registered_task.task_handler.delay(
176-
delay_until=delay_until,
177-
args=task.args,
178-
kwargs=task.kwargs,
179-
)
171+
task.scheduled_for = delay_until
180172
logger.info(
181173
"Backoff requested. Task '%s' set to retry at %s",
182174
task_identifier,

src/task_processor/task_registry.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
from task_processor.exceptions import TaskProcessingError
77
from task_processor.types import TaskCallable
88

9-
if typing.TYPE_CHECKING:
10-
from task_processor.decorators import TaskHandler # noqa: F401
11-
129
logger = logging.getLogger(__name__)
1310

1411

@@ -20,9 +17,8 @@ class TaskType(enum.Enum):
2017
@dataclass
2118
class RegisteredTask:
2219
task_identifier: str
23-
task_callable: TaskCallable[typing.Any]
20+
task_function: TaskCallable[typing.Any]
2421
task_type: TaskType = TaskType.STANDARD
25-
task_handler: "TaskHandler[typing.Any] | None" = None
2622
task_kwargs: dict[str, typing.Any] | None = None
2723

2824

@@ -60,21 +56,20 @@ def get_task(task_identifier: str) -> RegisteredTask:
6056

6157
def register_task(
6258
task_identifier: str,
63-
task_handler: "TaskHandler[typing.Any]",
59+
callable_: TaskCallable[typing.Any],
6460
) -> None:
6561
global registered_tasks
6662

6763
registered_task = RegisteredTask(
6864
task_identifier=task_identifier,
69-
task_handler=task_handler,
70-
task_callable=task_handler.unwrapped,
65+
task_function=callable_,
7166
)
7267
registered_tasks[task_identifier] = registered_task
7368

7469

7570
def register_recurring_task(
7671
task_identifier: str,
77-
task_callable: TaskCallable[typing.Any],
72+
callable_: TaskCallable[typing.Any],
7873
**task_kwargs: typing.Any,
7974
) -> None:
8075
global registered_tasks
@@ -83,7 +78,7 @@ def register_recurring_task(
8378

8479
registered_task = RegisteredTask(
8580
task_identifier=task_identifier,
86-
task_callable=task_callable,
81+
task_function=callable_,
8782
task_type=TaskType.RECURRING,
8883
task_kwargs=task_kwargs,
8984
)

tests/unit/task_processor/test_unit_task_processor_processor.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -684,11 +684,8 @@ def backoff_task() -> None:
684684
assert [
685685
record.message for record in caplog.records if record.levelno == logging.INFO
686686
] == [expected_log_message]
687-
assert Task.objects.using(current_database).count() == 2
688-
assert (
689-
Task.objects.using(current_database).latest("created_at").scheduled_for
690-
== expected_scheduled_for
691-
)
687+
task.refresh_from_db(using=current_database)
688+
assert task.scheduled_for == expected_scheduled_for
692689

693690

694691
@pytest.mark.multi_database

0 commit comments

Comments
 (0)