11import typing
22from pathlib import Path
3+ from typing import Any
34
45import dask .distributed
56
@@ -9,6 +10,7 @@ class DummyFuture(dask.distributed.Future):
910 A class that mimics a distributed Future, the outcome of
1011 performing submit on a distributed client.
1112 """
13+
1214 def __init__ (self , result : typing .Any ) -> None :
1315 self ._result = result # type: typing.Any
1416
@@ -33,13 +35,24 @@ class SingleThreadedClient(dask.distributed.Client):
3335 A class to Mock the Distributed Client class, in case
3436 Auto-Sklearn is meant to run in the current Thread.
3537 """
38+
3639 def __init__ (self ) -> None :
3740
3841 # Raise a not implemented error if using a method from Client
39- implemented_methods = ['submit' , 'close' , 'shutdown' , 'write_scheduler_file' ,
40- '_get_scheduler_info' , 'nthreads' ]
41- method_list = [func for func in dir (dask .distributed .Client ) if callable (
42- getattr (dask .distributed .Client , func )) and not func .startswith ('__' )]
42+ implemented_methods = [
43+ "submit" ,
44+ "close" ,
45+ "shutdown" ,
46+ "write_scheduler_file" ,
47+ "_get_scheduler_info" ,
48+ "nthreads" ,
49+ ]
50+ method_list = [
51+ func
52+ for func in dir (dask .distributed .Client )
53+ if callable (getattr (dask .distributed .Client , func ))
54+ and not func .startswith ("__" )
55+ ]
4356 for method in method_list :
4457 if method in implemented_methods :
4558 continue
@@ -54,8 +67,24 @@ def submit(
5467 func : typing .Callable ,
5568 * args : typing .List ,
5669 priority : int = 0 ,
57- ** kwargs : typing .Dict ,
70+ key : Any = None ,
71+ workers : Any = None ,
72+ resources : Any = None ,
73+ retries : Any = None ,
74+ fifo_timeout : Any = "100 ms" ,
75+ allow_other_workers : Any = False ,
76+ actor : Any = False ,
77+ actors : Any = False ,
78+ pure : Any = None ,
79+ ** kwargs : Any ,
5880 ) -> typing .Any :
81+ """
82+ Note
83+ ----
84+ The keyword arguments caught in `dask.distributed.Client` need to
85+ be specified here so they don't get passed in as ``**kwargs`` to the
86+ ``func``.
87+ """
5988 return DummyFuture (func (* args , ** kwargs ))
6089
6190 def close (self ) -> None :
@@ -70,17 +99,17 @@ def write_scheduler_file(self, scheduler_file: str) -> None:
7099
71100 def _get_scheduler_info (self ) -> typing .Dict :
72101 return {
73- ' workers' : [' 127.0.0.1' ],
74- ' type' : ' Scheduler' ,
102+ " workers" : [" 127.0.0.1" ],
103+ " type" : " Scheduler" ,
75104 }
76105
77106 def nthreads (self ) -> typing .Dict :
78107 return {
79- ' 127.0.0.1' : 1 ,
108+ " 127.0.0.1" : 1 ,
80109 }
81110
82111 def __repr__ (self ) -> str :
83- return ' SingleThreadedClient()'
112+ return " SingleThreadedClient()"
84113
85114 def __del__ (self ) -> None :
86115 pass
0 commit comments