1
1
"""Handle execution backends."""
2
2
import asyncio
3
- import time
4
3
from uuid import uuid4
5
- from .workers import (
6
- SerialWorker ,
7
- ConcurrentFuturesWorker ,
8
- SlurmWorker ,
9
- DaskWorker ,
10
- SGEWorker ,
11
- )
4
+ from .workers import WORKERS
12
5
from .core import is_workflow
13
6
from .helpers import get_open_loop , load_and_run_async
14
7
@@ -35,61 +28,60 @@ def __init__(self, plugin="cf", **kwargs):
35
28
self .loop = get_open_loop ()
36
29
self ._own_loop = not self .loop .is_running ()
37
30
self .plugin = plugin
38
- if self .plugin == "serial" :
39
- self .worker = SerialWorker ()
40
- elif self .plugin == "cf" :
41
- self .worker = ConcurrentFuturesWorker (** kwargs )
42
- elif self .plugin == "slurm" :
43
- self .worker = SlurmWorker (** kwargs )
44
- elif self .plugin == "dask" :
45
- self .worker = DaskWorker (** kwargs )
46
- elif self .plugin == "sge" :
47
- self .worker = SGEWorker (** kwargs )
48
- else :
49
- raise Exception (f"plugin { self .plugin } not available" )
31
+ try :
32
+ self .worker = WORKERS [self .plugin ](** kwargs )
33
+ except KeyError :
34
+ raise NotImplementedError (f"No worker for { self .plugin } " )
50
35
self .worker .loop = self .loop
51
36
52
37
def __call__(self, runnable, cache_locations=None, rerun=False):
    """Run *runnable* to completion and return its result.

    This is the synchronous entry point of the submitter: it bridges the
    calling code into the submitter's event loop and blocks until the
    runnable (task or workflow) has finished.

    Parameters
    ----------
    runnable :
        Task or workflow to execute.
    cache_locations :
        Optional cache locations to assign to the runnable before running.
    rerun : bool
        Force re-execution even when cached results exist.
    """
    if cache_locations is not None:
        runnable.cache_locations = cache_locations
    # drive the single async entrypoint to completion on our own loop
    self.loop.run_until_complete(self.submit_from_call(runnable, rerun))
    return runnable.result()
71
43
72
async def submit_from_call(self, runnable, rerun):
    """
    This coroutine should only be called once per Submitter call,
    and serves as the bridge between sync/async lands.

    There are 4 potential paths based on the type of runnable:
    0) Workflow has a different plugin than a submitter
    1) Workflow without State
    2) Task without State
    3) (Workflow or Task) with State

    Once Python 3.10 is the minimum, this should probably be refactored into using
    structural pattern matching.

    Parameters
    ----------
    runnable :
        Task or workflow to execute.
    rerun : bool
        Force re-execution even when cached results exist.

    Returns
    -------
    bool
        Always True once the runnable has completed.
    """
    if is_workflow(runnable):
        # connect and calculate the checksum of the graph before running
        runnable._connect_and_propagate_to_tasks(override_task_caches=True)
        # 0
        if runnable.plugin and runnable.plugin != self.plugin:
            # if workflow has a different plugin it's treated as a single element
            await self.worker.run_el(runnable, rerun=rerun)
        # 1
        # BUGFIX: this branch was a separate `if`, so a workflow with a
        # different plugin (path 0) would fall through and be executed a
        # second time locally.  The four paths are mutually exclusive, so
        # chain with `elif`.
        elif runnable.state is None:
            await runnable._run(self, rerun=rerun)
        # 3
        else:
            await self.expand_runnable(runnable, wait=True, rerun=rerun)
        # resetting all connections with LazyFields
        runnable._reset()
    else:
        # 2
        if runnable.state is None:
            # run_el should always return a coroutine
            await self.worker.run_el(runnable, rerun=rerun)
        # 3
        else:
            await self.expand_runnable(runnable, wait=True, rerun=rerun)
    return True
90
async def expand_runnable(self, runnable, wait=False, rerun=False):
    """
    This coroutine handles state expansion.

    Expands *runnable* into one job per state index and submits every job
    to the worker (tasks) or schedules it on the event loop (workflows).

    Parameters
    ----------
    runnable :
        Task or workflow with a state to expand.
    wait : bool
        If True, block until all expanded jobs finish.  If False
        (default), return the pending futures so the caller can await
        them independently.
    rerun : bool
        Force re-execution even when cached results exist.

    Returns
    -------
    set or None
        Pending futures when ``wait`` is False; ``None`` after all jobs
        have been gathered when ``wait`` is True.

    Raises
    ------
    NotImplementedError
        If the runnable requests a different plugin than the submitter.
    ValueError
        If the runnable has no state (it should not reach this method).
    """
    if runnable.plugin and runnable.plugin != self.plugin:
        # per-runnable plugin override is only handled at the workflow
        # level in `submit_from_call`
        raise NotImplementedError(
            f"runnable plugin {runnable.plugin!r} differs from "
            f"submitter plugin {self.plugin!r}"
        )

    if runnable.state is None:
        # ValueError is a subclass of the Exception previously raised,
        # so existing `except Exception` callers still work
        raise ValueError("Only runnables with state should reach here")

    futures = set()
    task_pkl = await prepare_runnable_with_state(runnable)
    # hoisted out of the loop: the runnable's type is loop-invariant
    runnable_is_wf = is_workflow(runnable)

    for sidx in range(len(runnable.state.states_val)):
        if runnable_is_wf:
            # job has no state anymore
            futures.add(
                # This unpickles and runs workflow - why are we pickling?
                asyncio.create_task(load_and_run_async(task_pkl, sidx, self, rerun))
            )
        else:
            futures.add(self.worker.run_el((sidx, task_pkl, runnable), rerun=rerun))

    if wait and futures:
        # if wait is True, we are at the end of the graph / state expansion.
        # Once the remaining jobs end, we will exit `submit_from_call`
        await asyncio.gather(*futures)
        return
    # pass along futures to be awaited independently
    return futures
131
145
- async def _run_workflow (self , wf , rerun = False ):
132
+ async def expand_workflow (self , wf , rerun = False ):
146
133
"""
147
134
Expand and execute a stateless :class:`~pydra.engine.core.Workflow`.
135
+ This method is only reached by `Workflow._run_task`.
148
136
149
137
Parameters
150
138
----------
@@ -157,10 +145,6 @@ async def _run_workflow(self, wf, rerun=False):
157
145
The computed workflow
158
146
159
147
"""
160
- for nd in wf .graph .nodes :
161
- if nd .allow_cache_override :
162
- nd .cache_dir = wf .cache_dir
163
-
164
148
# creating a copy of the graph that will be modified
165
149
# the copy contains new lists with original runnable objects
166
150
graph_copy = wf .graph .copy ()
@@ -180,7 +164,8 @@ async def _run_workflow(self, wf, rerun=False):
180
164
while not tasks and graph_copy .nodes :
181
165
tasks , follow_err = get_runnable_tasks (graph_copy )
182
166
ii += 1
183
- time .sleep (1 )
167
+ # don't block the event loop!
168
+ await asyncio .sleep (1 )
184
169
if ii > 60 :
185
170
raise Exception (
186
171
"graph is not empty, but not able to get more tasks "
@@ -191,11 +176,15 @@ async def _run_workflow(self, wf, rerun=False):
191
176
logger .debug (f"Retrieving inputs for { task } " )
192
177
# TODO: add state idx to retrieve values to reduce waiting
193
178
task .inputs .retrieve_values (wf )
194
- if is_workflow (task ) and not task .state :
195
- await self .submit_workflow (task , rerun = rerun )
196
- else :
197
- for fut in await self .submit (task , rerun = rerun ):
179
+ if task .state :
180
+ for fut in await self .expand_runnable (task , rerun = rerun ):
198
181
task_futures .add (fut )
182
+ # expand that workflow
183
+ elif is_workflow (task ):
184
+ await task ._run (self , rerun = rerun )
185
+ # single task
186
+ else :
187
+ task_futures .add (self .worker .run_el (task , rerun = rerun ))
199
188
task_futures = await self .worker .fetch_finished (task_futures )
200
189
tasks , follow_err = get_runnable_tasks (graph_copy )
201
190
# updating tasks_errored
@@ -285,3 +274,10 @@ def is_runnable(graph, obj):
285
274
graph .remove_nodes_connections (nd )
286
275
287
276
return True
277
+
278
+
279
async def prepare_runnable_with_state(runnable):
    """Prepare a stateful runnable for expansion and pickle it.

    Computes the state values and the per-state inputs, then returns
    whatever ``runnable.pickle_task()`` produces (the pickled task to be
    dispatched once per state index).
    """
    state = runnable.state
    state.prepare_states(runnable.inputs, cont_dim=runnable.cont_dim)
    state.prepare_inputs()
    logger.debug(f"Expanding {runnable} into {len(runnable.state.states_val)} states")
    return runnable.pickle_task()
0 commit comments