2
2
import asyncio
3
3
import time
4
4
from uuid import uuid4
5
- from .workers import (
6
- SerialWorker ,
7
- ConcurrentFuturesWorker ,
8
- SlurmWorker ,
9
- DaskWorker ,
10
- SGEWorker ,
11
- )
5
+ from .workers import WORKERS
12
6
from .core import is_workflow
13
7
from .helpers import get_open_loop , load_and_run_async
14
8
@@ -35,61 +29,57 @@ def __init__(self, plugin="cf", **kwargs):
35
29
self .loop = get_open_loop ()
36
30
self ._own_loop = not self .loop .is_running ()
37
31
self .plugin = plugin
38
- if self .plugin == "serial" :
39
- self .worker = SerialWorker ()
40
- elif self .plugin == "cf" :
41
- self .worker = ConcurrentFuturesWorker (** kwargs )
42
- elif self .plugin == "slurm" :
43
- self .worker = SlurmWorker (** kwargs )
44
- elif self .plugin == "dask" :
45
- self .worker = DaskWorker (** kwargs )
46
- elif self .plugin == "sge" :
47
- self .worker = SGEWorker (** kwargs )
48
- else :
49
- raise Exception (f"plugin { self .plugin } not available" )
32
+ try :
33
+ self .worker = WORKERS [self .plugin ](** kwargs )
34
+ except KeyError :
35
+ raise NotImplementedError (f"No worker for { self .plugin } " )
50
36
self .worker .loop = self .loop
51
37
52
38
def __call__ (self , runnable , cache_locations = None , rerun = False ):
53
- """Submit ."""
39
+ """Submitter run function ."""
54
40
if cache_locations is not None :
55
41
runnable .cache_locations = cache_locations
56
- # creating all connections and calculating the checksum of the graph before running
57
- if is_workflow (runnable ):
58
- # TODO: no prepare state ?
59
- for nd in runnable .graph .nodes :
60
- runnable .create_connections (nd )
61
- if nd .allow_cache_override :
62
- nd .cache_dir = runnable .cache_dir
63
- if is_workflow (runnable ) and runnable .state is None :
64
- self .loop .run_until_complete (self .submit_workflow (runnable , rerun = rerun ))
65
- else :
66
- self .loop .run_until_complete (self .submit (runnable , wait = True , rerun = rerun ))
67
- if is_workflow (runnable ):
68
- # resetting all connections with LazyFields
69
- runnable ._reset ()
42
+ self .loop .run_until_complete (self .submit_from_call (runnable , rerun ))
70
43
return runnable .result ()
71
44
72
- async def submit_workflow (self , workflow , rerun = False ):
73
- """Distribute or initiate workflow execution."""
74
- if is_workflow (workflow ):
75
- if workflow .plugin and workflow .plugin != self .plugin :
76
- # dj: this is not tested!!! TODO
77
- await self .worker .run_el (workflow , rerun = rerun )
45
+ async def submit_from_call (self , runnable , rerun ):
46
+ """
47
+ This coroutine should only be called once per Submitter call,
48
+ and serves as the bridge between sync/async lands.
49
+
50
+ There are 3 potential paths based on the type of runnable:
51
+
52
+ 1) Workflow without State
53
+ 2) Task without State
54
+ 3) (Workflow or Task) with State
55
+
56
+ Once Python 3.10 is the minimum, this should probably be refactored into using
57
+ structural pattern matching.
58
+ """
59
+ if is_workflow (runnable ):
60
+ # connect and calculate the checksum of the graph before running
61
+ runnable ._connect_and_propagate_to_tasks (override_task_caches = True )
62
+ # 1
63
+ if runnable .state is None :
64
+ await runnable ._run (self , rerun = rerun )
65
+ # 3
78
66
else :
79
- await workflow ._run (self , rerun = rerun )
80
- else : # could be a tuple with paths to pickle files wiith tasks and inputs
81
- ind , wf_main_pkl , wf_orig = workflow
82
- if wf_orig .plugin and wf_orig .plugin != self .plugin :
83
- # dj: this is not tested!!! TODO
84
- await self .worker .run_el (workflow , rerun = rerun )
67
+ await self .expand_runnable (runnable , wait = True )
68
+ runnable ._reset ()
69
+ else :
70
+ # 2
71
+ if runnable .state is None :
72
+ # run_el should always return a coroutine
73
+ await self .worker .run_el (runnable , rerun = rerun )
74
+ # 3
85
75
else :
86
- await load_and_run_async (
87
- task_pkl = wf_main_pkl , ind = ind , submitter = self , rerun = rerun
88
- )
76
+ await self . expand_runnable ( runnable , wait = True )
77
+ return True
78
+
89
79
90
- async def submit (self , runnable , wait = False , rerun = False ):
80
+ async def expand_runnable (self , runnable , wait = False , rerun = False ):
91
81
"""
92
- Coroutine entrypoint for task submission .
82
+ This coroutine handles state expansion .
93
83
94
84
Removes any states from `runnable`. If `wait` is
95
85
set to False (default), aggregates all worker
@@ -110,41 +100,41 @@ async def submit(self, runnable, wait=False, rerun=False):
110
100
Coroutines for :class:`~pydra.engine.core.TaskBase` execution.
111
101
112
102
"""
103
+ if runnable .plugin and runnable .plugin != self .plugin :
104
+ raise NotImplementedError ()
105
+
113
106
futures = set ()
114
- if runnable .state :
115
- runnable .state .prepare_states (runnable .inputs , cont_dim = runnable .cont_dim )
116
- runnable .state .prepare_inputs ()
117
- logger .debug (
118
- f"Expanding { runnable } into { len (runnable .state .states_val )} states"
119
- )
120
- task_pkl = runnable .pickle_task ()
121
-
122
- for sidx in range (len (runnable .state .states_val )):
123
- job_tuple = (sidx , task_pkl , runnable )
124
- if is_workflow (runnable ):
125
- # job has no state anymore
126
- futures .add (self .submit_workflow (job_tuple , rerun = rerun ))
127
- else :
128
- # tasks are submitted to worker for execution
129
- futures .add (self .worker .run_el (job_tuple , rerun = rerun ))
130
- else :
107
+ if runnable .state is None :
108
+ raise Exception ("Only runnables with state should reach here" )
109
+
110
+ task_pkl = await prepare_runnable_with_state (runnable )
111
+
112
+ for sidx in range (len (runnable .state .states_val )):
131
113
if is_workflow (runnable ):
132
- await self ._run_workflow (runnable , rerun = rerun )
114
+ # job has no state anymore
115
+ futures .add (
116
+ # This unpickles and runs workflow - why are we pickling?
117
+ asyncio .create_task (load_and_run_async (task_pkl , sidx , self , rerun ))
118
+ )
133
119
else :
134
- # submit task to worker
135
- futures .add (self .worker .run_el (runnable , rerun = rerun ))
120
+ futures .add (
121
+ self .worker .run_el ((task_pkl , sidx , runnable ), rerun = rerun )
122
+ )
123
+
136
124
137
125
if wait and futures :
138
- # run coroutines concurrently and wait for execution
139
- # wait until all states complete or error
126
+ # if wait is True, we are at the end of the graph / state expansion.
127
+ # Once the remaining jobs end, we will exit `submit_from_call`
140
128
await asyncio .gather (* futures )
141
129
return
142
130
# pass along futures to be awaited independently
143
131
return futures
144
132
145
- async def _run_workflow (self , wf , rerun = False ):
133
+
134
+ async def expand_workflow (self , wf , rerun = False ):
146
135
"""
147
136
Expand and execute a stateless :class:`~pydra.engine.core.Workflow`.
137
+ This method is only reached by `Workflow._run_task`.
148
138
149
139
Parameters
150
140
----------
@@ -157,10 +147,6 @@ async def _run_workflow(self, wf, rerun=False):
157
147
The computed workflow
158
148
159
149
"""
160
- for nd in wf .graph .nodes :
161
- if nd .allow_cache_override :
162
- nd .cache_dir = wf .cache_dir
163
-
164
150
# creating a copy of the graph that will be modified
165
151
# the copy contains new lists with original runnable objects
166
152
graph_copy = wf .graph .copy ()
@@ -180,7 +166,8 @@ async def _run_workflow(self, wf, rerun=False):
180
166
while not tasks and graph_copy .nodes :
181
167
tasks , follow_err = get_runnable_tasks (graph_copy )
182
168
ii += 1
183
- time .sleep (1 )
169
+ # don't block the event loop!
170
+ await asyncio .sleep (1 )
184
171
if ii > 60 :
185
172
raise Exception (
186
173
"graph is not empty, but not able to get more tasks "
@@ -191,11 +178,15 @@ async def _run_workflow(self, wf, rerun=False):
191
178
logger .debug (f"Retrieving inputs for { task } " )
192
179
# TODO: add state idx to retrieve values to reduce waiting
193
180
task .inputs .retrieve_values (wf )
194
- if is_workflow (task ) and not task .state :
195
- await self .submit_workflow (task , rerun = rerun )
196
- else :
197
- for fut in await self .submit (task , rerun = rerun ):
181
+ if task .state :
182
+ for fut in await self .expand_runnable (task , rerun = rerun ):
198
183
task_futures .add (fut )
184
+ # expand that workflow
185
+ elif is_workflow (task ):
186
+ await task ._run (self , rerun = rerun )
187
+ # single task
188
+ else :
189
+ task_futures .add (self .worker .run_el (task , rerun = rerun ))
199
190
task_futures = await self .worker .fetch_finished (task_futures )
200
191
tasks , follow_err = get_runnable_tasks (graph_copy )
201
192
# updating tasks_errored
@@ -285,3 +276,13 @@ def is_runnable(graph, obj):
285
276
graph .remove_nodes_connections (nd )
286
277
287
278
return True
279
+
280
+
281
+
282
+ async def prepare_runnable_with_state (runnable ):
283
+ runnable .state .prepare_states (runnable .inputs , cont_dim = runnable .cont_dim )
284
+ runnable .state .prepare_inputs ()
285
+ logger .debug (
286
+ f"Expanding { runnable } into { len (runnable .state .states_val )} states"
287
+ )
288
+ return runnable .pickle_task ()
0 commit comments