1717import functools
1818import logging
1919from abc import ABC , abstractmethod
20+ from collections .abc import Callable
2021from dataclasses import dataclass , field
21- from typing import Callable , Final
22+ from typing import Final , TypeAlias
2223
2324import arrow
2425import networkx as nx
8788 str
8889] = "computational-scheduler-{user_id}:{project_id}:{iteration}"
8990
91+ PipelineSchedulingTask : TypeAlias = asyncio .Task
92+ PipelineSchedulingWakeUpEvent : TypeAlias = asyncio .Event
93+
9094
9195@dataclass (frozen = True , slots = True )
9296class SortedTasks :
@@ -139,8 +143,8 @@ async def _triage_changed_tasks(
139143
140144@dataclass (kw_only = True )
141145class ScheduledPipelineParams :
142- scheduler_task : asyncio .Task | None = None
143- scheduler_waker : asyncio .Event = field ( default_factory = asyncio . Event )
146+ scheduler_task : asyncio .Task
147+ scheduler_waker : asyncio .Event
144148
145149 def wake_up (self ) -> None :
146150 self .scheduler_waker .set ()
@@ -169,31 +173,19 @@ async def restore_scheduling_from_db(self) -> None:
169173 filter_by_state = SCHEDULED_STATES
170174 )
171175
172- self ._scheduled_pipelines |= {
173- (
174- run .user_id ,
175- run .project_uuid ,
176- run .iteration ,
177- ): ScheduledPipelineParams ()
178- for run in comp_runs
179- }
180-
181- for (
182- user_id ,
183- project_id ,
184- iteration ,
185- ), params in self ._scheduled_pipelines .items ():
186- self ._start_scheduling (params , user_id , project_id , iteration )
187-
188- async def start_scheduling (self ) -> None :
189- await self .restore_scheduling_from_db ()
190-
191- for (
192- user_id ,
193- project_id ,
194- iteration ,
195- ), params in self ._scheduled_pipelines .items ():
196- self ._start_scheduling (params , user_id , project_id , iteration )
176+ for run in comp_runs :
177+ task , wake_up_event = self ._start_scheduling (
178+ run .user_id , run .project_uuid , run .iteration
179+ )
180+ self ._scheduled_pipelines |= {
181+ (
182+ run .user_id ,
183+ run .project_uuid ,
184+ run .iteration ,
185+ ): ScheduledPipelineParams (
186+ scheduler_task = task , scheduler_waker = wake_up_event
187+ )
188+ }
197189
198190 async def run_new_pipeline (
199191 self ,
@@ -224,9 +216,12 @@ async def run_new_pipeline(
224216 metadata = run_metadata ,
225217 use_on_demand_clusters = use_on_demand_clusters ,
226218 )
219+ task , wake_up_event = self ._start_scheduling (
220+ user_id , project_id , new_run .iteration
221+ )
227222 self ._scheduled_pipelines [
228223 (user_id , project_id , new_run .iteration )
229- ] = pipeline_params = ScheduledPipelineParams ()
224+ ] = ScheduledPipelineParams (scheduler_task = task , scheduler_waker = wake_up_event )
230225 await publish_project_log (
231226 self .rabbitmq_client ,
232227 user_id ,
@@ -235,8 +230,6 @@ async def run_new_pipeline(
235230 log_level = logging .INFO ,
236231 )
237232
238- self ._start_scheduling (pipeline_params , user_id , project_id , new_run .iteration )
239-
240233 async def stop_pipeline (
241234 self , user_id : UserID , project_id : ProjectID , iteration : int | None = None
242235 ) -> None :
@@ -293,11 +286,10 @@ def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iterati
293286
294287 def _start_scheduling (
295288 self ,
296- pipeline_params : ScheduledPipelineParams ,
297289 user_id : UserID ,
298290 project_id : ProjectID ,
299291 iteration : Iteration ,
300- ) -> None :
292+ ) -> tuple [ PipelineSchedulingTask , PipelineSchedulingWakeUpEvent ] :
301293 async def _exclusive_safe_schedule_pipeline (
302294 * ,
303295 user_id : UserID ,
@@ -313,20 +305,22 @@ async def _exclusive_safe_schedule_pipeline(
313305 wake_up_callback = wake_up_callback ,
314306 )
315307
316- pipeline_params .scheduler_task = start_periodic_task (
308+ pipeline_wake_up_event = asyncio .Event ()
309+ pipeline_task = start_periodic_task (
317310 functools .partial (
318311 _exclusive_safe_schedule_pipeline ,
319312 user_id = user_id ,
320313 project_id = project_id ,
321314 iteration = iteration ,
322- wake_up_callback = pipeline_params . wake_up ,
315+ wake_up_callback = pipeline_wake_up_event . set ,
323316 ),
324317 interval = _SCHEDULER_INTERVAL ,
325318 task_name = _TASK_NAME_TEMPLATE .format (
326319 user_id = user_id , project_id = project_id , iteration = iteration
327320 ),
328- early_wake_up_event = pipeline_params . scheduler_waker ,
321+ early_wake_up_event = pipeline_wake_up_event ,
329322 )
323+ return pipeline_task , pipeline_wake_up_event
330324
331325 async def _get_pipeline_dag (self , project_id : ProjectID ) -> nx .DiGraph :
332326 comp_pipeline_repo = CompPipelinesRepository .instance (self .db_engine )
0 commit comments