|
38 | 38 | ) |
39 | 39 | from ...utils.dask_client_utils import TaskHandlers, UnixTimestamp |
40 | 40 | from ...utils.rabbitmq import ( |
41 | | - publish_service_progress, |
42 | 41 | publish_service_resource_tracking_stopped, |
43 | 42 | publish_service_stopped_metrics, |
44 | 43 | ) |
45 | 44 | from ..clusters_keeper import get_or_create_on_demand_cluster |
46 | 45 | from ..dask_client import DaskClient, PublishedComputationTask |
47 | 46 | from ..dask_clients_pool import DaskClientsPool |
48 | | -from ..db.repositories.comp_runs import CompRunsRepository |
49 | 47 | from ..db.repositories.comp_tasks import CompTasksRepository |
50 | 48 | from ._scheduler_base import BaseCompScheduler |
51 | 49 |
|
@@ -159,27 +157,28 @@ async def _get_tasks_status( |
159 | 157 | use_on_demand_clusters=comp_run.use_on_demand_clusters, |
160 | 158 | run_metadata=comp_run.metadata, |
161 | 159 | ) as client: |
162 | | - tasks_statuses = await client.get_tasks_status( |
163 | | - [f"{t.job_id}" for t in tasks] |
164 | | - ) |
165 | | - # process dask states |
166 | | - running_states: list[RunningState] = [] |
167 | | - for dask_task_state, task in zip(tasks_statuses, tasks, strict=True): |
168 | | - if dask_task_state is DaskClientTaskState.PENDING_OR_STARTED: |
169 | | - running_states += [ |
170 | | - ( |
171 | | - RunningState.STARTED |
172 | | - if task.progress is not None |
173 | | - else RunningState.PENDING |
174 | | - ) |
175 | | - ] |
176 | | - else: |
177 | | - running_states += [ |
178 | | - _DASK_CLIENT_TASK_STATE_TO_RUNNING_STATE_MAP.get( |
179 | | - dask_task_state, RunningState.UNKNOWN |
180 | | - ) |
181 | | - ] |
182 | | - return running_states |
| 160 | + return await client.get_tasks_status2([f"{t.job_id}" for t in tasks]) |
| 161 | + # tasks_statuses = await client.get_tasks_status( |
| 162 | + # [f"{t.job_id}" for t in tasks] |
| 163 | + # ) |
| 164 | + # # process dask states |
| 165 | + # running_states: list[RunningState] = [] |
| 166 | + # for dask_task_state, task in zip(tasks_statuses, tasks, strict=True): |
| 167 | + # if dask_task_state is DaskClientTaskState.PENDING_OR_STARTED: |
| 168 | + # running_states += [ |
| 169 | + # ( |
| 170 | + # RunningState.STARTED |
| 171 | + # if task.progress is not None |
| 172 | + # else RunningState.PENDING |
| 173 | + # ) |
| 174 | + # ] |
| 175 | + # else: |
| 176 | + # running_states += [ |
| 177 | + # _DASK_CLIENT_TASK_STATE_TO_RUNNING_STATE_MAP.get( |
| 178 | + # dask_task_state, RunningState.UNKNOWN |
| 179 | + # ) |
| 180 | + # ] |
| 181 | + # return running_states |
183 | 182 |
|
184 | 183 | except ComputationalBackendOnDemandNotReadyError: |
185 | 184 | _logger.info("The on demand computational backend is not ready yet...") |
@@ -351,30 +350,30 @@ async def _task_progress_change_handler( |
351 | 350 | with log_catch(_logger, reraise=False): |
352 | 351 | task_progress_event = TaskProgressEvent.model_validate_json(event[1]) |
353 | 352 | _logger.debug("received task progress update: %s", task_progress_event) |
354 | | - user_id = task_progress_event.task_owner.user_id |
355 | | - project_id = task_progress_event.task_owner.project_id |
356 | | - node_id = task_progress_event.task_owner.node_id |
357 | | - comp_tasks_repo = CompTasksRepository(self.db_engine) |
358 | | - task = await comp_tasks_repo.get_task(project_id, node_id) |
359 | | - if task.progress is None: |
360 | | - task.state = RunningState.STARTED |
361 | | - task.progress = task_progress_event.progress |
362 | | - run = await CompRunsRepository(self.db_engine).get(user_id, project_id) |
363 | | - await self._process_started_tasks( |
364 | | - [task], |
365 | | - user_id=user_id, |
366 | | - project_id=project_id, |
367 | | - iteration=run.iteration, |
368 | | - run_metadata=run.metadata, |
369 | | - ) |
370 | | - else: |
371 | | - await comp_tasks_repo.update_project_task_progress( |
372 | | - project_id, node_id, task_progress_event.progress |
373 | | - ) |
374 | | - await publish_service_progress( |
375 | | - self.rabbitmq_client, |
376 | | - user_id=user_id, |
377 | | - project_id=project_id, |
378 | | - node_id=node_id, |
379 | | - progress=task_progress_event.progress, |
380 | | - ) |
| 353 | + # user_id = task_progress_event.task_owner.user_id |
| 354 | + # project_id = task_progress_event.task_owner.project_id |
| 355 | + # node_id = task_progress_event.task_owner.node_id |
| 356 | + # comp_tasks_repo = CompTasksRepository(self.db_engine) |
| 357 | + # task = await comp_tasks_repo.get_task(project_id, node_id) |
| 358 | + # if task.progress is None: |
| 359 | + # task.state = RunningState.STARTED |
| 360 | + # task.progress = task_progress_event.progress |
| 361 | + # run = await CompRunsRepository(self.db_engine).get(user_id, project_id) |
| 362 | + # await self._process_started_tasks( |
| 363 | + # [task], |
| 364 | + # user_id=user_id, |
| 365 | + # project_id=project_id, |
| 366 | + # iteration=run.iteration, |
| 367 | + # run_metadata=run.metadata, |
| 368 | + # ) |
| 369 | + # else: |
| 370 | + # await comp_tasks_repo.update_project_task_progress( |
| 371 | + # project_id, node_id, task_progress_event.progress |
| 372 | + # ) |
| 373 | + # await publish_service_progress( |
| 374 | + # self.rabbitmq_client, |
| 375 | + # user_id=user_id, |
| 376 | + # project_id=project_id, |
| 377 | + # node_id=node_id, |
| 378 | + # progress=task_progress_event.progress, |
| 379 | + # ) |
0 commit comments