|
1 |
| -import multiprocessing as mp |
2 | 1 | import os
|
3 | 2 | import random
|
4 | 3 | import shutil
|
5 |
| -from typing import Dict, List, Optional, Type, Union |
| 4 | +from typing import Awaitable, Dict, List, Optional, Type, Union |
6 | 5 |
|
7 | 6 | import fsspec
|
8 | 7 | import psutil
|
| 8 | +from dask.distributed import Client as DaskClient |
9 | 9 | from jupyter_core.paths import jupyter_data_dir
|
10 | 10 | from jupyter_server.transutils import _i18n
|
11 | 11 | from jupyter_server.utils import to_os_path
|
@@ -96,11 +96,17 @@ def _default_staging_path(self):
|
96 | 96 | )
|
97 | 97 |
|
98 | 98 | def __init__(
|
99 |
| - self, root_dir: str, environments_manager: Type[EnvironmentManager], config=None, **kwargs |
| 99 | + self, |
| 100 | + root_dir: str, |
| 101 | + environments_manager: Type[EnvironmentManager], |
| 102 | + dask_client_future: Awaitable[DaskClient], |
| 103 | + config=None, |
| 104 | + **kwargs, |
100 | 105 | ):
|
101 | 106 | super().__init__(config=config, **kwargs)
|
102 | 107 | self.root_dir = root_dir
|
103 | 108 | self.environments_manager = environments_manager
|
| 109 | + self.dask_client_future = dask_client_future |
104 | 110 |
|
105 | 111 | def create_job(self, model: CreateJob) -> str:
|
106 | 112 | """Creates a new job record, may trigger execution of the job.
|
@@ -437,7 +443,7 @@ def copy_input_folder(self, input_uri: str, nb_copy_to_path: str) -> List[str]:
|
437 | 443 | destination_dir=staging_dir,
|
438 | 444 | )
|
439 | 445 |
|
440 |
| - def create_job(self, model: CreateJob) -> str: |
| 446 | + async def create_job(self, model: CreateJob) -> str: |
441 | 447 | if not model.job_definition_id and not self.file_exists(model.input_uri):
|
442 | 448 | raise InputUriError(model.input_uri)
|
443 | 449 |
|
@@ -478,25 +484,17 @@ def create_job(self, model: CreateJob) -> str:
|
478 | 484 | else:
|
479 | 485 | self.copy_input_file(model.input_uri, staging_paths["input"])
|
480 | 486 |
|
481 |
| - # The MP context forces new processes to not be forked on Linux. |
482 |
| - # This is necessary because `asyncio.get_event_loop()` is bugged in |
483 |
| - # forked processes in Python versions below 3.12. This method is |
484 |
| - # called by `jupyter_core` by `nbconvert` in the default executor. |
485 |
| - # |
486 |
| - # See: https://github.com/python/cpython/issues/66285 |
487 |
| - # See also: https://github.com/jupyter/jupyter_core/pull/362 |
488 |
| - mp_ctx = mp.get_context("spawn") |
489 |
| - p = mp_ctx.Process( |
490 |
| - target=self.execution_manager_class( |
| 487 | + dask_client: DaskClient = await self.dask_client_future |
| 488 | + future = dask_client.submit( |
| 489 | + self.execution_manager_class( |
491 | 490 | job_id=job.job_id,
|
492 | 491 | staging_paths=staging_paths,
|
493 | 492 | root_dir=self.root_dir,
|
494 | 493 | db_url=self.db_url,
|
495 | 494 | ).process
|
496 | 495 | )
|
497 |
| - p.start() |
498 | 496 |
|
499 |
| - job.pid = p.pid |
| 497 | + job.pid = future.key |
500 | 498 | session.commit()
|
501 | 499 |
|
502 | 500 | job_id = job.job_id
|
@@ -749,14 +747,16 @@ def list_job_definitions(self, query: ListJobDefinitionsQuery) -> ListJobDefinit
|
749 | 747 |
|
750 | 748 | return list_response
|
751 | 749 |
|
752 |
| - def create_job_from_definition(self, job_definition_id: str, model: CreateJobFromDefinition): |
| 750 | + async def create_job_from_definition( |
| 751 | + self, job_definition_id: str, model: CreateJobFromDefinition |
| 752 | + ): |
753 | 753 | job_id = None
|
754 | 754 | definition = self.get_job_definition(job_definition_id)
|
755 | 755 | if definition:
|
756 | 756 | input_uri = self.get_staging_paths(definition)["input"]
|
757 | 757 | attributes = definition.dict(exclude={"schedule", "timezone"}, exclude_none=True)
|
758 | 758 | attributes = {**attributes, **model.dict(exclude_none=True), "input_uri": input_uri}
|
759 |
| - job_id = self.create_job(CreateJob(**attributes)) |
| 759 | + job_id = await self.create_job(CreateJob(**attributes)) |
760 | 760 |
|
761 | 761 | return job_id
|
762 | 762 |
|
|
0 commit comments