|
44 | 44 | from ..clusters_keeper import get_or_create_on_demand_cluster |
45 | 45 | from ..dask_client import DaskClient, PublishedComputationTask |
46 | 46 | from ..dask_clients_pool import DaskClientsPool |
| 47 | +from ..db.repositories.comp_runs import ( |
| 48 | + CompRunsRepository, |
| 49 | +) |
47 | 50 | from ..db.repositories.comp_tasks import CompTasksRepository |
48 | 51 | from ._scheduler_base import BaseCompScheduler |
49 | 52 |
|
@@ -373,30 +376,30 @@ async def _task_progress_change_handler( |
373 | 376 | with log_catch(_logger, reraise=False): |
374 | 377 | task_progress_event = TaskProgressEvent.model_validate_json(event[1]) |
375 | 378 | _logger.debug("received task progress update: %s", task_progress_event) |
376 | | - # user_id = task_progress_event.task_owner.user_id |
377 | | - # project_id = task_progress_event.task_owner.project_id |
378 | | - # node_id = task_progress_event.task_owner.node_id |
379 | | - # comp_tasks_repo = CompTasksRepository(self.db_engine) |
380 | | - # task = await comp_tasks_repo.get_task(project_id, node_id) |
381 | | - # if task.progress is None: |
382 | | - # task.state = RunningState.STARTED |
383 | | - # task.progress = task_progress_event.progress |
384 | | - # run = await CompRunsRepository(self.db_engine).get(user_id, project_id) |
385 | | - # await self._process_started_tasks( |
386 | | - # [task], |
387 | | - # user_id=user_id, |
388 | | - # project_id=project_id, |
389 | | - # iteration=run.iteration, |
390 | | - # run_metadata=run.metadata, |
391 | | - # ) |
392 | | - # else: |
393 | | - # await comp_tasks_repo.update_project_task_progress( |
394 | | - # project_id, node_id, task_progress_event.progress |
395 | | - # ) |
396 | | - # await publish_service_progress( |
397 | | - # self.rabbitmq_client, |
398 | | - # user_id=user_id, |
399 | | - # project_id=project_id, |
400 | | - # node_id=node_id, |
401 | | - # progress=task_progress_event.progress, |
402 | | - # ) |
| 379 | + user_id = task_progress_event.task_owner.user_id |
| 380 | + project_id = task_progress_event.task_owner.project_id |
| 381 | + node_id = task_progress_event.task_owner.node_id |
| 382 | + comp_tasks_repo = CompTasksRepository(self.db_engine) |
| 383 | + task = await comp_tasks_repo.get_task(project_id, node_id) |
| 384 | + if task.progress is None: |
| 385 | + task.state = RunningState.STARTED |
| 386 | + task.progress = task_progress_event.progress |
| 387 | + run = await CompRunsRepository(self.db_engine).get(user_id, project_id) |
| 388 | + await self._process_started_tasks( |
| 389 | + [task], |
| 390 | + user_id=user_id, |
| 391 | + project_id=project_id, |
| 392 | + iteration=run.iteration, |
| 393 | + run_metadata=run.metadata, |
| 394 | + ) |
| 395 | + else: |
| 396 | + await comp_tasks_repo.update_project_task_progress( |
| 397 | + project_id, node_id, task_progress_event.progress |
| 398 | + ) |
| 399 | + await publish_service_progress( |
| 400 | + self.rabbitmq_client, |
| 401 | + user_id=user_id, |
| 402 | + project_id=project_id, |
| 403 | + node_id=node_id, |
| 404 | + progress=task_progress_event.progress, |
| 405 | + ) |
0 commit comments