Skip to content

Commit 6d23a36

Browse files
committed
RF: Clean up task/workflow submission
1 parent 0520af5 commit 6d23a36

File tree

3 files changed

+110
-87
lines changed

3 files changed

+110
-87
lines changed

pydra/engine/core.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,8 +1097,11 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
10971097
async def _run_task(self, submitter, rerun=False):
10981098
if not submitter:
10991099
raise Exception("Submitter should already be set.")
1100+
for nd in self.graph.nodes:
1101+
if nd.allow_cache_override:
1102+
nd.cache_dir = self.cache_dir
11001103
# at this point Workflow is stateless so this should be fine
1101-
await submitter._run_workflow(self, rerun=rerun)
1104+
await submitter.expand_workflow(self, rerun=rerun)
11021105

11031106
def set_output(self, connections):
11041107
"""
@@ -1227,21 +1230,31 @@ def create_dotfile(self, type="simple", export=None, name=None):
12271230
formatted_dot.append(self.graph.export_graph(dotfile=dotfile, ext=ext))
12281231
return dotfile, formatted_dot
12291232

1230-
def _connect_and_propagate_to_tasks(self):
1233+
def _connect_and_propagate_to_tasks(
1234+
self,
1235+
*,
1236+
propagate_rerun=False,
1237+
override_task_caches=False
1238+
):
12311239
"""
12321240
Visit each node in the graph and create the connections.
12331241
Additionally checks if all tasks should be rerun.
12341242
"""
12351243
for task in self.graph.nodes:
1244+
self.create_connections(task)
12361245
# if workflow has task_rerun=True and propagate_rerun=True,
12371246
# it should be passed to the tasks
12381247
if self.task_rerun and self.propagate_rerun:
12391248
task.task_rerun = self.task_rerun
12401249
# if the task is a wf, than the propagate_rerun should be also set
12411250
if is_workflow(task):
12421251
task.propagate_rerun = self.propagate_rerun
1252+
1253+
# ported from Submitter.__call__
1254+
# TODO: no prepare state ?
1255+
if override_task_caches and task.allow_cache_override:
1256+
task.cache_dir = self.cache_dir
12431257
task.cache_locations = task._cache_locations + self.cache_locations
1244-
self.create_connections(task)
12451258

12461259

12471260
def is_task(obj):

pydra/engine/submitter.py

Lines changed: 84 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,7 @@
22
import asyncio
33
import time
44
from uuid import uuid4
5-
from .workers import (
6-
SerialWorker,
7-
ConcurrentFuturesWorker,
8-
SlurmWorker,
9-
DaskWorker,
10-
SGEWorker,
11-
)
5+
from .workers import WORKERS
126
from .core import is_workflow
137
from .helpers import get_open_loop, load_and_run_async
148

@@ -35,61 +29,57 @@ def __init__(self, plugin="cf", **kwargs):
3529
self.loop = get_open_loop()
3630
self._own_loop = not self.loop.is_running()
3731
self.plugin = plugin
38-
if self.plugin == "serial":
39-
self.worker = SerialWorker()
40-
elif self.plugin == "cf":
41-
self.worker = ConcurrentFuturesWorker(**kwargs)
42-
elif self.plugin == "slurm":
43-
self.worker = SlurmWorker(**kwargs)
44-
elif self.plugin == "dask":
45-
self.worker = DaskWorker(**kwargs)
46-
elif self.plugin == "sge":
47-
self.worker = SGEWorker(**kwargs)
48-
else:
49-
raise Exception(f"plugin {self.plugin} not available")
32+
try:
33+
self.worker = WORKERS[self.plugin](**kwargs)
34+
except KeyError:
35+
raise NotImplementedError(f"No worker for {self.plugin}")
5036
self.worker.loop = self.loop
5137

5238
def __call__(self, runnable, cache_locations=None, rerun=False):
53-
"""Submit."""
39+
"""Submitter run function."""
5440
if cache_locations is not None:
5541
runnable.cache_locations = cache_locations
56-
# creating all connections and calculating the checksum of the graph before running
57-
if is_workflow(runnable):
58-
# TODO: no prepare state ?
59-
for nd in runnable.graph.nodes:
60-
runnable.create_connections(nd)
61-
if nd.allow_cache_override:
62-
nd.cache_dir = runnable.cache_dir
63-
if is_workflow(runnable) and runnable.state is None:
64-
self.loop.run_until_complete(self.submit_workflow(runnable, rerun=rerun))
65-
else:
66-
self.loop.run_until_complete(self.submit(runnable, wait=True, rerun=rerun))
67-
if is_workflow(runnable):
68-
# resetting all connections with LazyFields
69-
runnable._reset()
42+
self.loop.run_until_complete(self.submit_from_call(runnable, rerun))
7043
return runnable.result()
7144

72-
async def submit_workflow(self, workflow, rerun=False):
73-
"""Distribute or initiate workflow execution."""
74-
if is_workflow(workflow):
75-
if workflow.plugin and workflow.plugin != self.plugin:
76-
# dj: this is not tested!!! TODO
77-
await self.worker.run_el(workflow, rerun=rerun)
45+
async def submit_from_call(self, runnable, rerun):
46+
"""
47+
This coroutine should only be called once per Submitter call,
48+
and serves as the bridge between sync/async lands.
49+
50+
There are 3 potential paths based on the type of runnable:
51+
52+
1) Workflow without State
53+
2) Task without State
54+
3) (Workflow or Task) with State
55+
56+
Once Python 3.10 is the minimum, this should probably be refactored into using
57+
structural pattern matching.
58+
"""
59+
if is_workflow(runnable):
60+
# connect and calculate the checksum of the graph before running
61+
runnable._connect_and_propagate_to_tasks(override_task_caches=True)
62+
# 1
63+
if runnable.state is None:
64+
await runnable._run(self, rerun=rerun)
65+
# 3
7866
else:
79-
await workflow._run(self, rerun=rerun)
80-
else: # could be a tuple with paths to pickle files wiith tasks and inputs
81-
ind, wf_main_pkl, wf_orig = workflow
82-
if wf_orig.plugin and wf_orig.plugin != self.plugin:
83-
# dj: this is not tested!!! TODO
84-
await self.worker.run_el(workflow, rerun=rerun)
67+
await self.expand_runnable(runnable, wait=True)
68+
runnable._reset()
69+
else:
70+
# 2
71+
if runnable.state is None:
72+
# run_el should always return a coroutine
73+
await self.worker.run_el(runnable, rerun=rerun)
74+
# 3
8575
else:
86-
await load_and_run_async(
87-
task_pkl=wf_main_pkl, ind=ind, submitter=self, rerun=rerun
88-
)
76+
await self.expand_runnable(runnable, wait=True)
77+
return True
78+
8979

90-
async def submit(self, runnable, wait=False, rerun=False):
80+
async def expand_runnable(self, runnable, wait=False, rerun=False):
9181
"""
92-
Coroutine entrypoint for task submission.
82+
This coroutine handles state expansion.
9383
9484
Removes any states from `runnable`. If `wait` is
9585
set to False (default), aggregates all worker
@@ -110,41 +100,41 @@ async def submit(self, runnable, wait=False, rerun=False):
110100
Coroutines for :class:`~pydra.engine.core.TaskBase` execution.
111101
112102
"""
103+
if runnable.plugin and runnable.plugin != self.plugin:
104+
raise NotImplementedError()
105+
113106
futures = set()
114-
if runnable.state:
115-
runnable.state.prepare_states(runnable.inputs, cont_dim=runnable.cont_dim)
116-
runnable.state.prepare_inputs()
117-
logger.debug(
118-
f"Expanding {runnable} into {len(runnable.state.states_val)} states"
119-
)
120-
task_pkl = runnable.pickle_task()
121-
122-
for sidx in range(len(runnable.state.states_val)):
123-
job_tuple = (sidx, task_pkl, runnable)
124-
if is_workflow(runnable):
125-
# job has no state anymore
126-
futures.add(self.submit_workflow(job_tuple, rerun=rerun))
127-
else:
128-
# tasks are submitted to worker for execution
129-
futures.add(self.worker.run_el(job_tuple, rerun=rerun))
130-
else:
107+
if runnable.state is None:
108+
raise Exception("Only runnables with state should reach here")
109+
110+
task_pkl = await prepare_runnable_with_state(runnable)
111+
112+
for sidx in range(len(runnable.state.states_val)):
131113
if is_workflow(runnable):
132-
await self._run_workflow(runnable, rerun=rerun)
114+
# job has no state anymore
115+
futures.add(
116+
# This unpickles and runs workflow - why are we pickling?
117+
asyncio.create_task(load_and_run_async(task_pkl, sidx, self, rerun))
118+
)
133119
else:
134-
# submit task to worker
135-
futures.add(self.worker.run_el(runnable, rerun=rerun))
120+
futures.add(
121+
self.worker.run_el((task_pkl, sidx, runnable), rerun=rerun)
122+
)
123+
136124

137125
if wait and futures:
138-
# run coroutines concurrently and wait for execution
139-
# wait until all states complete or error
126+
# if wait is True, we are at the end of the graph / state expansion.
127+
# Once the remaining jobs end, we will exit `submit_from_call`
140128
await asyncio.gather(*futures)
141129
return
142130
# pass along futures to be awaited independently
143131
return futures
144132

145-
async def _run_workflow(self, wf, rerun=False):
133+
134+
async def expand_workflow(self, wf, rerun=False):
146135
"""
147136
Expand and execute a stateless :class:`~pydra.engine.core.Workflow`.
137+
This method is only reached by `Workflow._run_task`.
148138
149139
Parameters
150140
----------
@@ -157,10 +147,6 @@ async def _run_workflow(self, wf, rerun=False):
157147
The computed workflow
158148
159149
"""
160-
for nd in wf.graph.nodes:
161-
if nd.allow_cache_override:
162-
nd.cache_dir = wf.cache_dir
163-
164150
# creating a copy of the graph that will be modified
165151
# the copy contains new lists with original runnable objects
166152
graph_copy = wf.graph.copy()
@@ -180,7 +166,8 @@ async def _run_workflow(self, wf, rerun=False):
180166
while not tasks and graph_copy.nodes:
181167
tasks, follow_err = get_runnable_tasks(graph_copy)
182168
ii += 1
183-
time.sleep(1)
169+
# don't block the event loop!
170+
await asyncio.sleep(1)
184171
if ii > 60:
185172
raise Exception(
186173
"graph is not empty, but not able to get more tasks "
@@ -191,11 +178,15 @@ async def _run_workflow(self, wf, rerun=False):
191178
logger.debug(f"Retrieving inputs for {task}")
192179
# TODO: add state idx to retrieve values to reduce waiting
193180
task.inputs.retrieve_values(wf)
194-
if is_workflow(task) and not task.state:
195-
await self.submit_workflow(task, rerun=rerun)
196-
else:
197-
for fut in await self.submit(task, rerun=rerun):
181+
if task.state:
182+
for fut in await self.expand_runnable(task, rerun=rerun):
198183
task_futures.add(fut)
184+
# expand that workflow
185+
elif is_workflow(task):
186+
await task._run(self, rerun=rerun)
187+
# single task
188+
else:
189+
task_futures.add(self.worker.run_el(task, rerun=rerun))
199190
task_futures = await self.worker.fetch_finished(task_futures)
200191
tasks, follow_err = get_runnable_tasks(graph_copy)
201192
# updating tasks_errored
@@ -285,3 +276,13 @@ def is_runnable(graph, obj):
285276
graph.remove_nodes_connections(nd)
286277

287278
return True
279+
280+
281+
282+
async def prepare_runnable_with_state(runnable):
283+
runnable.state.prepare_states(runnable.inputs, cont_dim=runnable.cont_dim)
284+
runnable.state.prepare_inputs()
285+
logger.debug(
286+
f"Expanding {runnable} into {len(runnable.state.states_val)} states"
287+
)
288+
return runnable.pickle_task()

pydra/engine/workers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def fetch_finished(self, futures):
119119
class SerialWorker(Worker):
120120
"""A worker to execute linearly."""
121121

122-
def __init__(self):
122+
def __init__(self, **kwargs):
123123
"""Initialize worker."""
124124
logger.debug("Initialize SerialWorker")
125125

@@ -876,3 +876,12 @@ async def exec_dask(self, runnable, rerun=False):
876876
def close(self):
877877
"""Finalize the internal pool of tasks."""
878878
pass
879+
880+
881+
WORKERS = {
882+
"serial": SerialWorker,
883+
"cf": ConcurrentFuturesWorker,
884+
"slurm": SlurmWorker,
885+
"dask": DaskWorker,
886+
"sge": SGEWorker
887+
}

0 commit comments

Comments
 (0)