Skip to content

Commit 7e8232c

Browse files
authored
Merge pull request #524 from mgxd/rf/submitter
RF: Submitter logic.
2 parents 3a3e4bb + a2cc367 commit 7e8232c

File tree

3 files changed

+112
-92
lines changed

3 files changed

+112
-92
lines changed

pydra/engine/core.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,7 +1055,9 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
10551055
"Workflow output cannot be None, use set_output to define output(s)"
10561056
)
10571057
# creating connections that were defined after adding tasks to the wf
1058-
self._connect_and_propagate_to_tasks()
1058+
self._connect_and_propagate_to_tasks(
1059+
propagate_rerun=self.task_rerun and self.propagate_rerun
1060+
)
10591061

10601062
checksum = self.checksum
10611063
output_dir = self.output_dir
@@ -1097,8 +1099,11 @@ async def _run(self, submitter=None, rerun=False, **kwargs):
10971099
async def _run_task(self, submitter, rerun=False):
10981100
if not submitter:
10991101
raise Exception("Submitter should already be set.")
1102+
for nd in self.graph.nodes:
1103+
if nd.allow_cache_override:
1104+
nd.cache_dir = self.cache_dir
11001105
# at this point Workflow is stateless so this should be fine
1101-
await submitter._run_workflow(self, rerun=rerun)
1106+
await submitter.expand_workflow(self, rerun=rerun)
11021107

11031108
def set_output(self, connections):
11041109
"""
@@ -1227,21 +1232,31 @@ def create_dotfile(self, type="simple", export=None, name=None):
12271232
formatted_dot.append(self.graph.export_graph(dotfile=dotfile, ext=ext))
12281233
return dotfile, formatted_dot
12291234

1230-
def _connect_and_propagate_to_tasks(self):
1235+
def _connect_and_propagate_to_tasks(
1236+
self,
1237+
*,
1238+
propagate_rerun=False,
1239+
override_task_caches=False,
1240+
):
12311241
"""
12321242
Visit each node in the graph and create the connections.
12331243
Additionally checks if all tasks should be rerun.
12341244
"""
12351245
for task in self.graph.nodes:
1246+
self.create_connections(task)
12361247
# if workflow has task_rerun=True and propagate_rerun=True,
12371248
# it should be passed to the tasks
1238-
if self.task_rerun and self.propagate_rerun:
1239-
task.task_rerun = self.task_rerun
1249+
if propagate_rerun:
1250+
task.task_rerun = True
12401251
# if the task is a wf, then the propagate_rerun should also be set
12411252
if is_workflow(task):
1242-
task.propagate_rerun = self.propagate_rerun
1253+
task.propagate_rerun = True
1254+
1255+
# ported from Submitter.__call__
1256+
# TODO: no prepare state ?
1257+
if override_task_caches and task.allow_cache_override:
1258+
task.cache_dir = self.cache_dir
12431259
task.cache_locations = task._cache_locations + self.cache_locations
1244-
self.create_connections(task)
12451260

12461261

12471262
def is_task(obj):

pydra/engine/submitter.py

Lines changed: 80 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
"""Handle execution backends."""
22
import asyncio
3-
import time
43
from uuid import uuid4
5-
from .workers import (
6-
SerialWorker,
7-
ConcurrentFuturesWorker,
8-
SlurmWorker,
9-
DaskWorker,
10-
SGEWorker,
11-
)
4+
from .workers import WORKERS
125
from .core import is_workflow
136
from .helpers import get_open_loop, load_and_run_async
147

@@ -35,61 +28,60 @@ def __init__(self, plugin="cf", **kwargs):
3528
self.loop = get_open_loop()
3629
self._own_loop = not self.loop.is_running()
3730
self.plugin = plugin
38-
if self.plugin == "serial":
39-
self.worker = SerialWorker()
40-
elif self.plugin == "cf":
41-
self.worker = ConcurrentFuturesWorker(**kwargs)
42-
elif self.plugin == "slurm":
43-
self.worker = SlurmWorker(**kwargs)
44-
elif self.plugin == "dask":
45-
self.worker = DaskWorker(**kwargs)
46-
elif self.plugin == "sge":
47-
self.worker = SGEWorker(**kwargs)
48-
else:
49-
raise Exception(f"plugin {self.plugin} not available")
31+
try:
32+
self.worker = WORKERS[self.plugin](**kwargs)
33+
except KeyError:
34+
raise NotImplementedError(f"No worker for {self.plugin}")
5035
self.worker.loop = self.loop
5136

5237
def __call__(self, runnable, cache_locations=None, rerun=False):
53-
"""Submit."""
38+
"""Submitter run function."""
5439
if cache_locations is not None:
5540
runnable.cache_locations = cache_locations
56-
# creating all connections and calculating the checksum of the graph before running
57-
if is_workflow(runnable):
58-
# TODO: no prepare state ?
59-
for nd in runnable.graph.nodes:
60-
runnable.create_connections(nd)
61-
if nd.allow_cache_override:
62-
nd.cache_dir = runnable.cache_dir
63-
if is_workflow(runnable) and runnable.state is None:
64-
self.loop.run_until_complete(self.submit_workflow(runnable, rerun=rerun))
65-
else:
66-
self.loop.run_until_complete(self.submit(runnable, wait=True, rerun=rerun))
67-
if is_workflow(runnable):
68-
# resetting all connections with LazyFields
69-
runnable._reset()
41+
self.loop.run_until_complete(self.submit_from_call(runnable, rerun))
7042
return runnable.result()
7143

72-
async def submit_workflow(self, workflow, rerun=False):
73-
"""Distribute or initiate workflow execution."""
74-
if is_workflow(workflow):
75-
if workflow.plugin and workflow.plugin != self.plugin:
76-
# dj: this is not tested!!! TODO
77-
await self.worker.run_el(workflow, rerun=rerun)
44+
async def submit_from_call(self, runnable, rerun):
45+
"""
46+
This coroutine should only be called once per Submitter call,
47+
and serves as the bridge between sync/async lands.
48+
49+
There are 4 potential paths based on the type of runnable:
50+
0) Workflow has a different plugin than a submitter
51+
1) Workflow without State
52+
2) Task without State
53+
3) (Workflow or Task) with State
54+
55+
Once Python 3.10 is the minimum, this should probably be refactored into using
56+
structural pattern matching.
57+
"""
58+
if is_workflow(runnable):
59+
# connect and calculate the checksum of the graph before running
60+
runnable._connect_and_propagate_to_tasks(override_task_caches=True)
61+
# 0
62+
if runnable.plugin and runnable.plugin != self.plugin:
63+
# if workflow has a different plugin it's treated as a single element
64+
await self.worker.run_el(runnable, rerun=rerun)
65+
# 1
66+
if runnable.state is None:
67+
await runnable._run(self, rerun=rerun)
68+
# 3
7869
else:
79-
await workflow._run(self, rerun=rerun)
80-
else: # could be a tuple with paths to pickle files with tasks and inputs
81-
ind, wf_main_pkl, wf_orig = workflow
82-
if wf_orig.plugin and wf_orig.plugin != self.plugin:
83-
# dj: this is not tested!!! TODO
84-
await self.worker.run_el(workflow, rerun=rerun)
70+
await self.expand_runnable(runnable, wait=True, rerun=rerun)
71+
runnable._reset()
72+
else:
73+
# 2
74+
if runnable.state is None:
75+
# run_el should always return a coroutine
76+
await self.worker.run_el(runnable, rerun=rerun)
77+
# 3
8578
else:
86-
await load_and_run_async(
87-
task_pkl=wf_main_pkl, ind=ind, submitter=self, rerun=rerun
88-
)
79+
await self.expand_runnable(runnable, wait=True, rerun=rerun)
80+
return True
8981

90-
async def submit(self, runnable, wait=False, rerun=False):
82+
async def expand_runnable(self, runnable, wait=False, rerun=False):
9183
"""
92-
Coroutine entrypoint for task submission.
84+
This coroutine handles state expansion.
9385
9486
Removes any states from `runnable`. If `wait` is
9587
set to False (default), aggregates all worker
@@ -110,41 +102,37 @@ async def submit(self, runnable, wait=False, rerun=False):
110102
Coroutines for :class:`~pydra.engine.core.TaskBase` execution.
111103
112104
"""
105+
if runnable.plugin and runnable.plugin != self.plugin:
106+
raise NotImplementedError()
107+
113108
futures = set()
114-
if runnable.state:
115-
runnable.state.prepare_states(runnable.inputs, cont_dim=runnable.cont_dim)
116-
runnable.state.prepare_inputs()
117-
logger.debug(
118-
f"Expanding {runnable} into {len(runnable.state.states_val)} states"
119-
)
120-
task_pkl = runnable.pickle_task()
121-
122-
for sidx in range(len(runnable.state.states_val)):
123-
job_tuple = (sidx, task_pkl, runnable)
124-
if is_workflow(runnable):
125-
# job has no state anymore
126-
futures.add(self.submit_workflow(job_tuple, rerun=rerun))
127-
else:
128-
# tasks are submitted to worker for execution
129-
futures.add(self.worker.run_el(job_tuple, rerun=rerun))
130-
else:
109+
if runnable.state is None:
110+
raise Exception("Only runnables with state should reach here")
111+
112+
task_pkl = await prepare_runnable_with_state(runnable)
113+
114+
for sidx in range(len(runnable.state.states_val)):
131115
if is_workflow(runnable):
132-
await self._run_workflow(runnable, rerun=rerun)
116+
# job has no state anymore
117+
futures.add(
118+
# This unpickles and runs workflow - why are we pickling?
119+
asyncio.create_task(load_and_run_async(task_pkl, sidx, self, rerun))
120+
)
133121
else:
134-
# submit task to worker
135-
futures.add(self.worker.run_el(runnable, rerun=rerun))
122+
futures.add(self.worker.run_el((sidx, task_pkl, runnable), rerun=rerun))
136123

137124
if wait and futures:
138-
# run coroutines concurrently and wait for execution
139-
# wait until all states complete or error
125+
# if wait is True, we are at the end of the graph / state expansion.
126+
# Once the remaining jobs end, we will exit `submit_from_call`
140127
await asyncio.gather(*futures)
141128
return
142129
# pass along futures to be awaited independently
143130
return futures
144131

145-
async def _run_workflow(self, wf, rerun=False):
132+
async def expand_workflow(self, wf, rerun=False):
146133
"""
147134
Expand and execute a stateless :class:`~pydra.engine.core.Workflow`.
135+
This method is only reached by `Workflow._run_task`.
148136
149137
Parameters
150138
----------
@@ -157,10 +145,6 @@ async def _run_workflow(self, wf, rerun=False):
157145
The computed workflow
158146
159147
"""
160-
for nd in wf.graph.nodes:
161-
if nd.allow_cache_override:
162-
nd.cache_dir = wf.cache_dir
163-
164148
# creating a copy of the graph that will be modified
165149
# the copy contains new lists with original runnable objects
166150
graph_copy = wf.graph.copy()
@@ -180,7 +164,8 @@ async def _run_workflow(self, wf, rerun=False):
180164
while not tasks and graph_copy.nodes:
181165
tasks, follow_err = get_runnable_tasks(graph_copy)
182166
ii += 1
183-
time.sleep(1)
167+
# don't block the event loop!
168+
await asyncio.sleep(1)
184169
if ii > 60:
185170
raise Exception(
186171
"graph is not empty, but not able to get more tasks "
@@ -191,11 +176,15 @@ async def _run_workflow(self, wf, rerun=False):
191176
logger.debug(f"Retrieving inputs for {task}")
192177
# TODO: add state idx to retrieve values to reduce waiting
193178
task.inputs.retrieve_values(wf)
194-
if is_workflow(task) and not task.state:
195-
await self.submit_workflow(task, rerun=rerun)
196-
else:
197-
for fut in await self.submit(task, rerun=rerun):
179+
if task.state:
180+
for fut in await self.expand_runnable(task, rerun=rerun):
198181
task_futures.add(fut)
182+
# expand that workflow
183+
elif is_workflow(task):
184+
await task._run(self, rerun=rerun)
185+
# single task
186+
else:
187+
task_futures.add(self.worker.run_el(task, rerun=rerun))
199188
task_futures = await self.worker.fetch_finished(task_futures)
200189
tasks, follow_err = get_runnable_tasks(graph_copy)
201190
# updating tasks_errored
@@ -285,3 +274,10 @@ def is_runnable(graph, obj):
285274
graph.remove_nodes_connections(nd)
286275

287276
return True
277+
278+
279+
async def prepare_runnable_with_state(runnable):
280+
runnable.state.prepare_states(runnable.inputs, cont_dim=runnable.cont_dim)
281+
runnable.state.prepare_inputs()
282+
logger.debug(f"Expanding {runnable} into {len(runnable.state.states_val)} states")
283+
return runnable.pickle_task()

pydra/engine/workers.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def fetch_finished(self, futures):
119119
class SerialWorker(Worker):
120120
"""A worker to execute linearly."""
121121

122-
def __init__(self):
122+
def __init__(self, **kwargs):
123123
"""Initialize worker."""
124124
logger.debug("Initialize SerialWorker")
125125

@@ -876,3 +876,12 @@ async def exec_dask(self, runnable, rerun=False):
876876
def close(self):
877877
"""Finalize the internal pool of tasks."""
878878
pass
879+
880+
881+
WORKERS = {
882+
"serial": SerialWorker,
883+
"cf": ConcurrentFuturesWorker,
884+
"slurm": SlurmWorker,
885+
"dask": DaskWorker,
886+
"sge": SGEWorker,
887+
}

0 commit comments

Comments
 (0)