
Commit 021009b

Commit message: simplify
1 parent 8babcaf commit 021009b


services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py

Lines changed: 30 additions & 36 deletions
@@ -17,8 +17,9 @@
 import functools
 import logging
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from dataclasses import dataclass, field
-from typing import Callable, Final
+from typing import Final, TypeAlias
 
 import arrow
 import networkx as nx
@@ -87,6 +88,9 @@
     str
 ] = "computational-scheduler-{user_id}:{project_id}:{iteration}"
 
+PipelineSchedulingTask: TypeAlias = asyncio.Task
+PipelineSchedulingWakeUpEvent: TypeAlias = asyncio.Event
+
 
 @dataclass(frozen=True, slots=True)
 class SortedTasks:
@@ -139,8 +143,8 @@ async def _triage_changed_tasks(
 
 @dataclass(kw_only=True)
 class ScheduledPipelineParams:
-    scheduler_task: asyncio.Task | None = None
-    scheduler_waker: asyncio.Event = field(default_factory=asyncio.Event)
+    scheduler_task: asyncio.Task
+    scheduler_waker: asyncio.Event
 
     def wake_up(self) -> None:
         self.scheduler_waker.set()
@@ -169,31 +173,19 @@ async def restore_scheduling_from_db(self) -> None:
             filter_by_state=SCHEDULED_STATES
         )
 
-        self._scheduled_pipelines |= {
-            (
-                run.user_id,
-                run.project_uuid,
-                run.iteration,
-            ): ScheduledPipelineParams()
-            for run in comp_runs
-        }
-
-        for (
-            user_id,
-            project_id,
-            iteration,
-        ), params in self._scheduled_pipelines.items():
-            self._start_scheduling(params, user_id, project_id, iteration)
-
-    async def start_scheduling(self) -> None:
-        await self.restore_scheduling_from_db()
-
-        for (
-            user_id,
-            project_id,
-            iteration,
-        ), params in self._scheduled_pipelines.items():
-            self._start_scheduling(params, user_id, project_id, iteration)
+        for run in comp_runs:
+            task, wake_up_event = self._start_scheduling(
+                run.user_id, run.project_uuid, run.iteration
+            )
+            self._scheduled_pipelines |= {
+                (
+                    run.user_id,
+                    run.project_uuid,
+                    run.iteration,
+                ): ScheduledPipelineParams(
+                    scheduler_task=task, scheduler_waker=wake_up_event
+                )
+            }
 
     async def run_new_pipeline(
         self,
@@ -224,9 +216,12 @@ async def run_new_pipeline(
             metadata=run_metadata,
             use_on_demand_clusters=use_on_demand_clusters,
         )
+        task, wake_up_event = self._start_scheduling(
+            user_id, project_id, new_run.iteration
+        )
         self._scheduled_pipelines[
             (user_id, project_id, new_run.iteration)
-        ] = pipeline_params = ScheduledPipelineParams()
+        ] = ScheduledPipelineParams(scheduler_task=task, scheduler_waker=wake_up_event)
         await publish_project_log(
             self.rabbitmq_client,
             user_id,
@@ -235,8 +230,6 @@
             log_level=logging.INFO,
         )
 
-        self._start_scheduling(pipeline_params, user_id, project_id, new_run.iteration)
-
     async def stop_pipeline(
         self, user_id: UserID, project_id: ProjectID, iteration: int | None = None
     ) -> None:
@@ -293,11 +286,10 @@ def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iterati
 
     def _start_scheduling(
         self,
-        pipeline_params: ScheduledPipelineParams,
         user_id: UserID,
         project_id: ProjectID,
         iteration: Iteration,
-    ) -> None:
+    ) -> tuple[PipelineSchedulingTask, PipelineSchedulingWakeUpEvent]:
         async def _exclusive_safe_schedule_pipeline(
             *,
             user_id: UserID,
@@ -313,20 +305,22 @@ async def _exclusive_safe_schedule_pipeline(
                 wake_up_callback=wake_up_callback,
             )
 
-        pipeline_params.scheduler_task = start_periodic_task(
+        pipeline_wake_up_event = asyncio.Event()
+        pipeline_task = start_periodic_task(
             functools.partial(
                 _exclusive_safe_schedule_pipeline,
                 user_id=user_id,
                 project_id=project_id,
                 iteration=iteration,
-                wake_up_callback=pipeline_params.wake_up,
+                wake_up_callback=pipeline_wake_up_event.set,
             ),
             interval=_SCHEDULER_INTERVAL,
             task_name=_TASK_NAME_TEMPLATE.format(
                 user_id=user_id, project_id=project_id, iteration=iteration
            ),
-            early_wake_up_event=pipeline_params.scheduler_waker,
+            early_wake_up_event=pipeline_wake_up_event,
         )
+        return pipeline_task, pipeline_wake_up_event
 
     async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph:
         comp_pipeline_repo = CompPipelinesRepository.instance(self.db_engine)
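Note: below is a minimal, self-contained sketch of the pattern this commit moves to, using plain asyncio in place of the repo's start_periodic_task helper; the names start_scheduling_task, schedule_once and _periodic are hypothetical illustrations, not the actual implementation. The point it shows: the start function now creates its own wake-up event and periodic task and returns both, so the caller can construct ScheduledPipelineParams with both fields required up front instead of mutating a half-initialized instance afterwards.

import asyncio
import contextlib
from collections.abc import Awaitable, Callable
from dataclasses import dataclass


@dataclass(kw_only=True)
class ScheduledPipelineParams:
    # both fields are required at construction time, mirroring the diff
    scheduler_task: asyncio.Task
    scheduler_waker: asyncio.Event

    def wake_up(self) -> None:
        self.scheduler_waker.set()


def start_scheduling_task(
    schedule_once: Callable[[], Awaitable[None]], *, interval: float
) -> tuple[asyncio.Task, asyncio.Event]:
    # hypothetical stand-in for start_periodic_task: it creates the waker itself
    # and returns (task, event) so the caller owns both from the start
    wake_up_event = asyncio.Event()

    async def _periodic() -> None:
        while True:
            await schedule_once()
            # sleep up to `interval`, but resume early if the waker is set
            with contextlib.suppress(asyncio.TimeoutError):
                await asyncio.wait_for(wake_up_event.wait(), timeout=interval)
            wake_up_event.clear()

    return asyncio.create_task(_periodic()), wake_up_event


async def main() -> None:
    async def schedule_once() -> None:
        print("scheduling pass")

    task, waker = start_scheduling_task(schedule_once, interval=5.0)
    params = ScheduledPipelineParams(scheduler_task=task, scheduler_waker=waker)
    params.wake_up()  # request an early pass instead of waiting ~5 s
    await asyncio.sleep(0.1)
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task


asyncio.run(main())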

0 commit comments

Comments
 (0)