@@ -43,7 +43,6 @@ def run_task(pk: str) -> TaskResultStatus:
4343 # Seems like this might be a welcome addition:
4444 # https://discuss.python.org/t/adding-finalizer-to-the-threading-library/54186
4545 connection .close ()
46- pass
4746
4847
4948@task
@@ -63,6 +62,7 @@ def __init__(
6362 worker_id : str | None = None ,
6463 backend : str = DEFAULT_TASK_BACKEND_ALIAS ,
6564 loop_delay : float = 0.5 ,
65+ init_periodic : bool = True ,
6666 ):
6767 self .workers = workers
6868 self .executor = concurrent .futures .ThreadPoolExecutor (max_workers = workers )
@@ -78,15 +78,20 @@ def __init__(
7878 self .backend = task_backends [backend ]
7979 if not isinstance (self .backend , DatabaseBackend ):
8080 raise ImproperlyConfigured ("Backend must be a `DatabaseBackend`" )
81+ # Signaled when the runner is ready and processing tasks.
82+ self .ready = threading .Event ()
8183 # Signaled when the runner should stop.
8284 self .stopsign = threading .Event ()
83- # Signaled when the queue is empty (no READY tasks).
85+ # Signaled when the runner is finished stopping.
86+ self .finished = threading .Event ()
87+ # Signaled each time the queue is empty (no READY tasks).
8488 self .empty = threading .Event ()
8589 # Covers `self.tasks`, `self.seen_modules`, and `self.processed` access.
8690 self .lock = threading .Lock ()
8791 # Allows callers to block on a single task being completed.
8892 self .waiting : dict [str , threading .Event ] = {}
8993 self .periodic : dict [str , Periodic ] = {}
94+ self .should_init_periodic = init_periodic
9095 if retain := self .backend .options .get ("retain" ):
9196 # If the task backend specifies a retention period, schedule a periodic task
9297 # to delete finished tasks older than that period.
@@ -156,20 +161,50 @@ def task_done(
156161 except Exception as ex :
157162 logger .info (f"Task { task_path } ({ pk } ) raised { ex } " )
158163
164+ if was_periodic and (schedule := self .periodic .get (task_path )):
165+ after = timezone .make_aware (schedule .next ())
166+ # Since this can run in the task's thread, we need to clean up the
167+ # connection afterwards since it may not be closed at the end of `run`.
168+ with connection .temporary_connection ():
169+ t = ScheduledTask .objects .create (
170+ task_path = task_path ,
171+ args = schedule .args ,
172+ kwargs = schedule .kwargs ,
173+ backend = self .backend .alias ,
174+ run_after = after ,
175+ periodic = True ,
176+ )
177+ logger .info (f"Re-scheduled { t } for { after } " )
178+
179+ # If anyone is waiting on this task, wake them up.
159180 if event := self .waiting .get (pk ):
160181 event .set ()
161182
162- if was_periodic and (schedule := self .periodic .get (task_path )):
163- after = timezone .make_aware (schedule .next ())
164- t = ScheduledTask .objects .create (
165- task_path = task_path ,
166- args = schedule .args ,
167- kwargs = schedule .kwargs ,
168- backend = self .backend .alias ,
169- run_after = after ,
170- periodic = True ,
171- )
172- logger .info (f"Re-scheduled { t } for { after } " )
def submit_task(self, task: ScheduledTask, start: bool = True) -> TaskResult:
    """
    Submit a `ScheduledTask` to the worker pool for execution.

    When `start=True`, the task row is first transitioned to RUNNING: its
    `started_at` timestamp is stamped, this runner's worker id is recorded,
    and only those fields are persisted. Callers that have already marked
    the task as started (e.g. an atomic batch fetch) pass `start=False`.

    Returns the task's `TaskResult` so callers can wait on completion.
    """
    if start:
        task.status = TaskResultStatus.RUNNING
        task.started_at = timezone.now()
        task.worker_ids.append(self.worker_id)
        task.save(update_fields=["status", "started_at", "worker_ids"])

    logger.debug(f"Submitting {task} for execution")
    future = self.executor.submit(run_task, task.task_id)

    # Record the in-flight future and remember the task's module so it can
    # be reloaded later; `self.lock` guards both structures.
    with self.lock:
        module_name = task.task_path.rsplit(".", 1)[0]
        self.seen_modules.add(module_name)
        self.tasks[task.task_id] = future

    # Completion bookkeeping (re-scheduling periodics, waking waiters) runs
    # via `task_done` once the future resolves.
    done_handler = functools.partial(
        self.task_done,
        task.task_id,
        task.task_path,
        task.periodic,
    )
    future.add_done_callback(done_handler)

    return task.result
173208
174209 def schedule_tasks (self ) -> float :
175210 """
@@ -193,20 +228,8 @@ def schedule_tasks(self) -> float:
193228 self .empty .clear ()
194229
195230 for t in tasks :
196- logger .debug (f"Submitting { t } for execution" )
197- f = self .executor .submit (run_task , t .task_id )
198- with self .lock :
199- # Keep track of task modules we've seen, so we can reload them.
200- self .seen_modules .add (t .task_path .rsplit ("." , 1 )[0 ])
201- self .tasks [t .task_id ] = f
202- f .add_done_callback (
203- functools .partial (
204- self .task_done ,
205- t .task_id ,
206- t .task_path ,
207- t .periodic ,
208- ),
209- )
231+ # get_tasks starts all of the returned tasks atomically, no need to here.
232+ self .submit_task (t , start = False )
210233
211234 if len (tasks ) >= available :
212235 # We got a full batch, try again immediately.
@@ -244,7 +267,12 @@ def run(self):
244267 """
245268 logger .info (f"Starting task runner with { self .workers } workers" )
246269 self .processed = 0
247- self .init_periodic ()
270+ if self .should_init_periodic :
271+ with transaction .atomic (durable = True ):
272+ self .init_periodic ()
273+ transaction .on_commit (self .ready .set )
274+ else :
275+ self .ready .set ()
248276 try :
249277 while not self .stopsign .is_set ():
250278 delay = self .schedule_tasks ()
@@ -254,6 +282,7 @@ def run(self):
254282 finally :
255283 self .executor .shutdown ()
256284 connection .close ()
285+ self .finished .set ()
257286
258287 def wait_for (self , result : TaskResult , timeout : float | None = None ) -> bool :
259288 """
0 commit comments