Skip to content

Commit abedadf

Browse files
committed
initialize dask cluster
1 parent ca11504 commit abedadf

File tree

2 files changed

+52
-0
lines changed

2 files changed

+52
-0
lines changed

jupyter_scheduler/extension.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,25 @@ def initialize_settings(self):
9191
if scheduler.task_runner:
9292
loop = asyncio.get_event_loop()
9393
loop.create_task(scheduler.task_runner.start())
94+
95+
async def stop_extension(self):
96+
"""
97+
Public method called by Jupyter Server when the server is stopping.
98+
This calls the cleanup code defined in `self._stop_exception()` inside
99+
an exception handler, as the server halts if this method raises an
100+
exception.
101+
"""
102+
try:
103+
await self._stop_extension()
104+
except Exception as e:
105+
self.log.error("Jupyter Scheduler raised an exception while stopping:")
106+
self.log.exception(e)
107+
108+
async def _stop_extension(self):
109+
"""
110+
Private method that defines the cleanup code to run when the server is
111+
stopping.
112+
"""
113+
if "scheduler" in self.settings:
114+
scheduler: SchedulerApp = self.settings["scheduler"]
115+
await scheduler.stop_extension()

jupyter_scheduler/scheduler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import fsspec
88
import psutil
9+
from dask.distributed import Client as DaskClient
10+
from distributed import LocalCluster
911
from jupyter_core.paths import jupyter_data_dir
1012
from jupyter_server.transutils import _i18n
1113
from jupyter_server.utils import to_os_path
@@ -381,6 +383,12 @@ def get_local_output_path(
381383
else:
382384
return os.path.join(self.root_dir, self.output_directory, output_dir_name)
383385

386+
async def stop_extension(self):
387+
"""
388+
Placeholder method for a cleanup code to run when the server is stopping.
389+
"""
390+
pass
391+
384392

385393
class Scheduler(BaseScheduler):
386394
_db_session = None
@@ -395,6 +403,12 @@ class Scheduler(BaseScheduler):
395403
),
396404
)
397405

406+
dask_cluster_url = Unicode(
407+
allow_none=True,
408+
config=True,
409+
help="URL of the Dask cluster to connect to.",
410+
)
411+
398412
db_url = Unicode(help=_i18n("Scheduler database url"))
399413

400414
task_runner = Instance(allow_none=True, klass="jupyter_scheduler.task_runner.BaseTaskRunner")
@@ -414,6 +428,15 @@ def __init__(
414428
if self.task_runner_class:
415429
self.task_runner = self.task_runner_class(scheduler=self, config=config)
416430

431+
self.dask_client: DaskClient = self._get_dask_client()
432+
433+
def _get_dask_client(self):
434+
"""Creates and configures a Dask client."""
435+
if self.dask_cluster_url:
436+
return DaskClient(self.dask_cluster_url)
437+
cluster = LocalCluster(processes=True)
438+
return DaskClient(cluster)
439+
417440
@property
418441
def db_session(self):
419442
if not self._db_session:
@@ -777,6 +800,13 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
777800

778801
return staging_paths
779802

803+
async def stop_extension(self):
804+
"""
805+
Cleanup code to run when the server is stopping.
806+
"""
807+
if self.dask_client:
808+
await self.dask_client.close()
809+
780810

781811
class ArchivingScheduler(Scheduler):
782812
"""Scheduler that captures all files in output directory in an archive."""

0 commit comments

Comments
 (0)