Skip to content

Commit dd38b8c

Browse files
committed
add dask client, use it for scheduler.create_job
1 parent 8eaba42 commit dd38b8c

File tree

3 files changed

+46
-22
lines changed

3 files changed

+46
-22
lines changed

jupyter_scheduler/extension.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import asyncio
2-
1+
from dask.distributed import Client as DaskClient
32
from jupyter_core.paths import jupyter_data_dir
43
from jupyter_server.extension.application import ExtensionApp
54
from jupyter_server.transutils import _i18n
@@ -73,11 +72,15 @@ def initialize_settings(self):
7372

7473
environments_manager = self.environment_manager_class()
7574

75+
asyncio_loop = self.serverapp.io_loop.asyncio_loop
76+
dask_client_future = asyncio_loop.create_task(self._get_dask_client())
77+
7678
scheduler = self.scheduler_class(
7779
root_dir=self.serverapp.root_dir,
7880
environments_manager=environments_manager,
7981
db_url=self.db_url,
8082
config=self.config,
83+
dask_client_future=dask_client_future,
8184
)
8285

8386
job_files_manager = self.job_files_manager_class(scheduler=scheduler)
@@ -86,8 +89,28 @@ def initialize_settings(self):
8689
environments_manager=environments_manager,
8790
scheduler=scheduler,
8891
job_files_manager=job_files_manager,
92+
dask_client_future=dask_client_future,
8993
)
9094

9195
if scheduler.task_runner:
92-
loop = asyncio.get_event_loop()
93-
loop.create_task(scheduler.task_runner.start())
96+
asyncio_loop.create_task(scheduler.task_runner.start())
97+
98+
async def _get_dask_client(self):
99+
"""Creates and configures a Dask client."""
100+
return DaskClient(processes=False, asynchronous=True)
101+
102+
async def stop_extension(self):
103+
"""Called by the Jupyter Server when stopping to cleanup resources."""
104+
try:
105+
await self._stop_extension()
106+
except Exception as e:
107+
self.log.error("Error while stopping Jupyter Scheduler:")
108+
self.log.exception(e)
109+
110+
async def _stop_extension(self):
111+
"""Closes the Dask client if it exists."""
112+
if "dask_client_future" in self.settings:
113+
dask_client: DaskClient = await self.settings["dask_client_future"]
114+
self.log.info("Closing Dask client.")
115+
await dask_client.close()
116+
self.log.info("Dask client closed.")

jupyter_scheduler/scheduler.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import multiprocessing as mp
21
import os
32
import random
43
import shutil
5-
from typing import Dict, List, Optional, Type, Union
4+
from typing import Awaitable, Dict, List, Optional, Type, Union
65

76
import fsspec
87
import psutil
8+
from dask.distributed import Client as DaskClient
99
from jupyter_core.paths import jupyter_data_dir
1010
from jupyter_server.transutils import _i18n
1111
from jupyter_server.utils import to_os_path
@@ -96,11 +96,17 @@ def _default_staging_path(self):
9696
)
9797

9898
def __init__(
99-
self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs
99+
self,
100+
root_dir: str,
101+
environments_manager: Type[EnvironmentManager],
102+
dask_client_future: Awaitable[DaskClient],
103+
config=None,
104+
**kwargs,
100105
):
101106
super().__init__(config=config, **kwargs)
102107
self.root_dir = root_dir
103108
self.environments_manager = environments_manager
109+
self.dask_client_future = dask_client_future
104110

105111
def create_job(self, model: CreateJob) -> str:
106112
"""Creates a new job record, may trigger execution of the job.
@@ -437,7 +443,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]:
437443
destination_dir=staging_dir,
438444
)
439445

440-
def create_job(self, model: CreateJob) -> str:
446+
async def create_job(self, model: CreateJob) -> str:
441447
if not model.job_definition_id and not self.file_exists(model.input_uri):
442448
raise InputUriError(model.input_uri)
443449

@@ -478,25 +484,17 @@ def create_job(self, model: CreateJob) -> str:
478484
else:
479485
self.copy_input_file(model.input_uri, staging_paths["input"])
480486

481-
# The MP context forces new processes to not be forked on Linux.
482-
# This is necessary because `asyncio.get_event_loop()` is bugged in
483-
# forked processes in Python versions below 3.12. This method is
484-
# called by `jupyter_core` by `nbconvert` in the default executor.
485-
#
486-
# See: https://github.com/python/cpython/issues/66285
487-
# See also: https://github.com/jupyter/jupyter_core/pull/362
488-
mp_ctx = mp.get_context("spawn")
489-
p = mp_ctx.Process(
490-
target=self.execution_manager_class(
487+
dask_client: DaskClient = await self.dask_client_future
488+
future = dask_client.submit(
489+
self.execution_manager_class(
491490
job_id=job.job_id,
492491
staging_paths=staging_paths,
493492
root_dir=self.root_dir,
494493
db_url=self.db_url,
495494
).process
496495
)
497-
p.start()
498496

499-
job.pid = p.pid
497+
job.pid = future.key
500498
session.commit()
501499

502500
job_id = job.job_id
@@ -749,14 +747,16 @@ def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinit
749747

750748
return list_response
751749

752-
def create_job_from_definition(self, job_definition_id: str, model: CreateJobFromDefinition):
750+
async def create_job_from_definition(
751+
self, job_definition_id: str, model: CreateJobFromDefinition
752+
):
753753
job_id = None
754754
definition = self.get_job_definition(job_definition_id)
755755
if definition:
756756
input_uri = self.get_staging_paths(definition)["input"]
757757
attributes = definition.dict(exclude={"schedule", "timezone"}, exclude_none=True)
758758
attributes = {**attributes, **model.dict(exclude_none=True), "input_uri": input_uri}
759-
job_id = self.create_job(CreateJob(**attributes))
759+
job_id = await self.create_job(CreateJob(**attributes))
760760

761761
return job_id
762762

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
"pydantic>=1.10,<3",
3636
"sqlalchemy>=2.0,<3",
3737
"croniter~=1.4",
38+
"dask[distributed]",
3839
"pytz==2023.3",
3940
"fsspec==2023.6.0",
4041
"psutil~=5.9"

0 commit comments

Comments
 (0)