Skip to content

Commit ca1bc5e

Browse files
committed
removed Subs from dv-2
1 parent 80fe6fe commit ca1bc5e

File tree

5 files changed

+24
-67
lines changed

5 files changed

+24
-67
lines changed

services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import AsyncIterator, Callable
55
from contextlib import asynccontextmanager
66
from dataclasses import dataclass
7+
from typing import Any
78

89
import arrow
910
from dask_task_models_library.container_tasks.errors import TaskCancelledError
@@ -35,7 +36,7 @@
3536
clean_task_output_and_log_files_if_invalid,
3637
parse_output_data,
3738
)
38-
from ...utils.dask_client_utils import TaskHandlers
39+
from ...utils.dask_client_utils import TaskHandlers, UnixTimestamp
3940
from ...utils.rabbitmq import (
4041
publish_service_progress,
4142
publish_service_resource_tracking_stopped,
@@ -344,9 +345,11 @@ async def _process_task_result(
344345
optional_stopped=arrow.utcnow().datetime,
345346
)
346347

347-
async def _task_progress_change_handler(self, event: str) -> None:
348+
async def _task_progress_change_handler(
349+
self, event: tuple[UnixTimestamp, Any]
350+
) -> None:
348351
with log_catch(_logger, reraise=False):
349-
task_progress_event = TaskProgressEvent.model_validate_json(event)
352+
task_progress_event = TaskProgressEvent.model_validate_json(event[1])
350353
_logger.debug("received task progress update: %s", task_progress_event)
351354
user_id = task_progress_event.task_owner.user_id
352355
project_id = task_progress_event.task_owner.project_id

services/director-v2/src/simcore_service_director_v2/modules/dask_client.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
99
"""
1010

11-
import asyncio
1211
import logging
1312
import traceback
1413
from collections.abc import Callable
1514
from copy import deepcopy
16-
from dataclasses import dataclass, field
15+
from dataclasses import dataclass
1716
from http.client import HTTPException
1817
from typing import Any, Final, cast
1918

@@ -23,6 +22,7 @@
2322
from common_library.json_serialization import json_dumps
2423
from dask_task_models_library.container_tasks.docker import DockerBasicAuth
2524
from dask_task_models_library.container_tasks.errors import TaskCancelledError
25+
from dask_task_models_library.container_tasks.events import TaskProgressEvent
2626
from dask_task_models_library.container_tasks.io import (
2727
TaskCancelEventName,
2828
TaskInputData,
@@ -53,7 +53,7 @@
5353
from models_library.users import UserID
5454
from pydantic import TypeAdapter, ValidationError
5555
from pydantic.networks import AnyUrl
56-
from servicelib.logging_utils import log_catch
56+
from servicelib.logging_utils import log_catch, log_context
5757
from settings_library.s3 import S3Settings
5858
from simcore_sdk.node_ports_common.exceptions import NodeportsException
5959
from simcore_sdk.node_ports_v2 import FileLinkType
@@ -123,8 +123,6 @@ class DaskClient:
123123
tasks_file_link_type: FileLinkType
124124
cluster_type: ClusterTypeInModel
125125

126-
_subscribed_tasks: list[asyncio.Task] = field(default_factory=list)
127-
128126
@classmethod
129127
async def create(
130128
cls,
@@ -177,24 +175,15 @@ async def create(
177175
raise ValueError(err_msg)
178176

179177
async def delete(self) -> None:
180-
_logger.debug("closing dask client...")
181-
for task in self._subscribed_tasks:
182-
task.cancel()
183-
await asyncio.gather(*self._subscribed_tasks, return_exceptions=True)
184-
await self.backend.close()
185-
_logger.info("dask client properly closed")
178+
with log_context(_logger, logging.INFO, msg="close dask client"):
179+
await self.backend.close()
186180

187181
def register_handlers(self, task_handlers: TaskHandlers) -> None:
188182
_event_consumer_map = [
189-
(self.backend.progress_sub, task_handlers.task_progress_handler),
190-
]
191-
self._subscribed_tasks = [
192-
asyncio.create_task(
193-
dask_utils.dask_sub_consumer_task(dask_sub, handler),
194-
name=f"{dask_sub.name}_dask_sub_consumer_task",
195-
)
196-
for dask_sub, handler in _event_consumer_map
183+
(TaskProgressEvent.topic_name(), task_handlers.task_progress_handler),
197184
]
185+
for topic_name, handler in _event_consumer_map:
186+
self.backend.client.subscribe_topic(topic_name, handler)
198187

199188
async def _publish_in_dask( # noqa: PLR0913 # pylint: disable=too-many-arguments
200189
self,

services/director-v2/src/simcore_service_director_v2/utils/dask.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
import asyncio
21
import collections
32
import logging
4-
from collections.abc import Awaitable, Callable, Coroutine, Generator
5-
from typing import Any, Final, NoReturn, ParamSpec, TypeVar, cast
3+
from collections.abc import Coroutine, Generator
4+
from typing import Any, ParamSpec, TypeVar, cast
65

76
import distributed
87
from common_library.json_serialization import json_dumps
@@ -30,7 +29,6 @@
3029
from models_library.users import UserID
3130
from models_library.wallets import WalletID
3231
from pydantic import AnyUrl, ByteSize, TypeAdapter, ValidationError
33-
from servicelib.logging_utils import log_catch, log_context
3432
from simcore_sdk import node_ports_v2
3533
from simcore_sdk.node_ports_common.exceptions import (
3634
NodeportsException,
@@ -438,40 +436,6 @@ async def clean_task_output_and_log_files_if_invalid(
438436
)
439437

440438

441-
async def _dask_sub_consumer(
442-
dask_sub: distributed.Sub,
443-
handler: Callable[[str], Awaitable[None]],
444-
) -> None:
445-
async for dask_event in dask_sub:
446-
_logger.debug(
447-
"received dask event '%s' of topic %s",
448-
dask_event,
449-
dask_sub.name,
450-
)
451-
await handler(dask_event)
452-
453-
454-
_REST_TIMEOUT_S: Final[int] = 1
455-
456-
457-
async def dask_sub_consumer_task(
458-
dask_sub: distributed.Sub,
459-
handler: Callable[[str], Awaitable[None]],
460-
) -> NoReturn:
461-
while True:
462-
with (
463-
log_catch(_logger, reraise=False),
464-
log_context(
465-
_logger,
466-
level=logging.DEBUG,
467-
msg=f"dask sub task for topic {dask_sub.name}",
468-
),
469-
):
470-
await _dask_sub_consumer(dask_sub, handler)
471-
# we sleep a bit before restarting
472-
await asyncio.sleep(_REST_TIMEOUT_S)
473-
474-
475439
def from_node_reqs_to_dask_resources(
476440
node_reqs: NodeRequirements,
477441
) -> dict[str, int | float]:

services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import socket
44
from collections.abc import Awaitable, Callable
55
from dataclasses import dataclass
6+
from typing import Any, TypeAlias
67

78
import distributed
89
from models_library.clusters import ClusterAuthentication, TLSAuthentication
@@ -11,10 +12,12 @@
1112
from ..core.errors import ConfigurationError
1213
from .dask import wrap_client_async_routine
1314

15+
UnixTimestamp: TypeAlias = float
16+
1417

1518
@dataclass
1619
class TaskHandlers:
17-
task_progress_handler: Callable[[str], Awaitable[None]]
20+
task_progress_handler: Callable[[tuple[UnixTimestamp, Any]], Awaitable[None]]
1821

1922

2023
logger = logging.getLogger(__name__)

services/director-v2/tests/unit/test_modules_dask_client.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,6 @@ async def factory() -> DaskClient:
163163
client.settings
164164
== minimal_app.state.settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND
165165
)
166-
assert not client._subscribed_tasks # noqa: SLF001
167166

168167
assert client.backend.client
169168
scheduler_infos = client.backend.client.scheduler_info() # type: ignore
@@ -181,7 +180,7 @@ async def factory() -> DaskClient:
181180
@pytest.fixture(params=["create_dask_client_from_scheduler"])
182181
async def dask_client(
183182
create_dask_client_from_scheduler: Callable[[], Awaitable[DaskClient]],
184-
request,
183+
request: pytest.FixtureRequest,
185184
) -> DaskClient:
186185
client: DaskClient = await {
187186
"create_dask_client_from_scheduler": create_dask_client_from_scheduler,
@@ -402,7 +401,7 @@ def comp_run_metadata(faker: Faker) -> RunMetadataDict:
402401
return RunMetadataDict(
403402
product_name=faker.pystr(),
404403
simcore_user_agent=faker.pystr(),
405-
) | cast(dict[str, str], faker.pydict(allowed_types=(str,)))
404+
) | cast(RunMetadataDict, faker.pydict(allowed_types=(str,)))
406405

407406

408407
@pytest.fixture
@@ -1102,8 +1101,7 @@ def fake_remote_fct(
11021101
log_file_url: LogFileUploadURL,
11031102
s3_settings: S3Settings | None,
11041103
) -> TaskOutputData:
1105-
progress_pub = distributed.Pub(TaskProgressEvent.topic_name())
1106-
progress_pub.put("my name is progress")
1104+
get_worker().log_event(TaskProgressEvent.topic_name(), "my name is progress")
11071105
# tell the client we are done
11081106
published_event = Event(name=_DASK_START_EVENT)
11091107
published_event.set()
@@ -1147,7 +1145,7 @@ def fake_remote_fct(
11471145
)
11481146
# we should have received data in our TaskHandlers
11491147
fake_task_handlers.task_progress_handler.assert_called_with(
1150-
"my name is progress"
1148+
(mock.ANY, "my name is progress")
11511149
)
11521150
await _assert_wait_for_cb_call(mocked_user_completed_cb)
11531151

0 commit comments

Comments
 (0)