1414
1515from dask .core import istask , ishashable
1616
17- from typing import List , Tuple
17+ from typing import List , Tuple , Union
1818from .utils import reduce
1919from ...remote import spawn
20+ from ...deploy .oscar .session import execute
2021
2122
22- def mars_scheduler (dsk : dict , keys : List [List [str ]]):
23+ def mars_scheduler (dsk : dict , keys : Union [ List [List [ str ]], List [str ]]):
2324 """
2425 A Dask-Mars scheduler
2526
@@ -30,22 +31,29 @@ def mars_scheduler(dsk: dict, keys: List[List[str]]):
3031 ----------
3132 dsk: Dict
3233 Dask graph, represented as a task DAG dictionary.
33- keys: List[List[str]]
34- 2d- list of Dask graph keys whose values we wish to compute and return.
34+ keys: Union[ List[List[str]], List[str]]
35+ 1d or 2d list of Dask graph keys whose values we wish to compute and return.
3536
3637 Returns
3738 -------
3839 Object
39- Computed values corresponding to the provided keys.
40+ Computed values corresponding to the provided keys with same dimension .
4041 """
41- res = reduce (mars_dask_get (dsk , keys )).execute ().fetch ()
42- if not isinstance (res , List ):
43- return [[res ]]
44- else :
45- return res
4642
43+ if isinstance (keys , List ) and not isinstance (keys [0 ], List ): # 1d keys
44+ task = execute (mars_dask_get (dsk , keys ))
45+ if not isinstance (task , List ):
46+ task = [task ]
47+ return map (lambda x : x .fetch (), task )
48+ else : # 2d keys
49+ res = execute (reduce (mars_dask_get (dsk , keys ))).fetch ()
50+ if not isinstance (res , List ):
51+ return [[res ]]
52+ else :
53+ return res
4754
48- def mars_dask_get (dsk : dict , keys : List [List ]):
55+
56+ def mars_dask_get (dsk : dict , keys : Union [List [List [str ]], List [str ]]):
4957 """
5058 A Dask-Mars convert function. This function will send the dask graph layers
5159 to Mars Remote API, generating mars objects correspond to the provided keys.
@@ -54,21 +62,21 @@ def mars_dask_get(dsk: dict, keys: List[List]):
5462 ----------
5563 dsk: Dict
5664 Dask graph, represented as a task DAG dictionary.
57- keys: List[List[str]]
58- 2d- list of Dask graph keys whose values we wish to compute and return.
65+ keys: Union[ List[List[str]], List[str]]
66+ 1d or 2d list of Dask graph keys whose values we wish to compute and return.
5967
6068 Returns
6169 -------
6270 Object
63- Spawned mars objects corresponding to the provided keys.
71+ Spawned mars objects corresponding to the provided keys with same dimension .
6472 """
6573
6674 def _get_arg (a ):
6775 # if arg contains layer index or callable objs, handle it
6876 if ishashable (a ) and a in dsk .keys ():
6977 while ishashable (a ) and a in dsk .keys ():
7078 a = dsk [a ]
71- return _execute_task (a )
79+ return _spawn_task (a )
7280 elif not isinstance (a , str ) and hasattr (a , "__getitem__" ):
7381 if istask (
7482 a
@@ -80,9 +88,14 @@ def _get_arg(a):
8088 return type (a )(_get_arg (i ) for i in a )
8189 return a
8290
83- def _execute_task (task : tuple ):
91+ def _spawn_task (task : tuple ):
8492 if not istask (task ):
8593 return _get_arg (task )
8694 return spawn (task [0 ], args = tuple (_get_arg (a ) for a in task [1 :]))
8795
88- return [[_execute_task (dsk [k ]) for k in keys_d ] for keys_d in keys ]
96+ return [
97+ [_spawn_task (dsk [k ]) for k in keys_d ]
98+ if isinstance (keys_d , List )
99+ else _spawn_task (dsk [keys_d ])
100+ for keys_d in keys
101+ ]
0 commit comments