1
1
import os
2
2
import random
3
3
import tarfile
4
- from multiprocessing import Process
5
- from typing import Dict , List , Optional , Type
4
+ from typing import Awaitable , Dict , List , Optional , Type
6
5
7
6
import fsspec
7
+ from dask .distributed import Client as DaskClient
8
8
from jupyter_server .utils import ensure_async
9
9
10
10
from jupyter_scheduler .exceptions import SchedulerError
14
14
class JobFilesManager :
15
15
scheduler = None
16
16
17
- def __init__ (self , scheduler : Type [BaseScheduler ]):
17
+ def __init__ (
18
+ self ,
19
+ scheduler : Type [BaseScheduler ],
20
+ dask_client_future : Awaitable [DaskClient ],
21
+ ):
18
22
self .scheduler = scheduler
23
+ self .dask_client_future = dask_client_future
19
24
20
25
async def copy_from_staging (self , job_id : str , redownload : Optional [bool ] = False ):
21
26
job = await ensure_async (self .scheduler .get_job (job_id , False ))
22
27
staging_paths = await ensure_async (self .scheduler .get_staging_paths (job ))
23
28
output_filenames = self .scheduler .get_job_filenames (job )
24
29
output_dir = self .scheduler .get_local_output_path (model = job , root_dir_relative = True )
25
30
26
- p = Process (
27
- target = Downloader (
31
+ dask_client : DaskClient = await self .dask_client_future
32
+ dask_client .submit (
33
+ Downloader (
28
34
output_formats = job .output_formats ,
29
35
output_filenames = output_filenames ,
30
36
staging_paths = staging_paths ,
@@ -33,7 +39,6 @@ async def copy_from_staging(self, job_id: str, redownload: Optional[bool] = Fals
33
39
include_staging_files = job .package_input_folder ,
34
40
).download
35
41
)
36
- p .start ()
37
42
38
43
39
44
class Downloader :
0 commit comments