Skip to content

Commit 00d6ea4

Browse files
authored
Merge pull request #388 from djarecka/fix/call
fixing __call__ function
2 parents 5306585 + bf381f3 commit 00d6ea4

File tree

5 files changed

+102
-93
lines changed

5 files changed

+102
-93
lines changed

pydra/engine/core.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -392,26 +392,28 @@ def output_dir(self):
392392
return [self._cache_dir / checksum for checksum in self.checksum_states()]
393393
return self._cache_dir / self.checksum
394394

395-
def __call__(self, submitter=None, plugin=None, rerun=False, **kwargs):
395+
def __call__(
396+
self, submitter=None, plugin=None, plugin_kwargs=None, rerun=False, **kwargs
397+
):
396398
"""Make tasks callable themselves."""
397399
from .submitter import Submitter
398400

399401
if submitter and plugin:
400402
raise Exception("Specify submitter OR plugin, not both")
401-
plugin = plugin or self.plugin
402-
if plugin:
403-
submitter = Submitter(plugin=plugin)
404-
elif self.state:
405-
submitter = Submitter()
403+
elif submitter:
404+
pass
405+
# if there is plugin provided or the task is a Workflow or has a state,
406+
# the submitter will be created using provided plugin, self.plugin or "cf"
407+
elif plugin or self.state or is_workflow(self):
408+
plugin = plugin or self.plugin or "cf"
409+
if plugin_kwargs is None:
410+
plugin_kwargs = {}
411+
submitter = Submitter(plugin=plugin, **plugin_kwargs)
406412

407413
if submitter:
408414
with submitter as sub:
409415
res = sub(self)
410-
else:
411-
if is_workflow(self):
412-
raise NotImplementedError(
413-
"TODO: linear workflow execution - assign submitter or plugin for now"
414-
)
416+
else: # tasks without state could be run without a submitter
415417
res = self._run(rerun=rerun, **kwargs)
416418
return res
417419

pydra/engine/tests/test_numpy_examples.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,15 @@ def arrayout(val):
1717
return np.array([val, val])
1818

1919

20-
def test_multiout(plugin, tmpdir):
20+
def test_multiout(tmpdir):
2121
""" testing a simple function that returns a numpy array"""
2222
wf = Workflow("wf", input_spec=["val"], val=2)
2323
wf.add(arrayout(name="mo", val=wf.lzin.val))
2424

2525
wf.set_output([("array", wf.mo.lzout.b)])
2626
wf.cache_dir = tmpdir
2727

28-
with Submitter(plugin=plugin, n_procs=2) as sub:
28+
with Submitter(plugin="cf", n_procs=2) as sub:
2929
sub(runnable=wf)
3030

3131
results = wf.result(return_inputs=True)
@@ -34,7 +34,7 @@ def test_multiout(plugin, tmpdir):
3434
assert np.array_equal(results[1].output.array, np.array([2, 2]))
3535

3636

37-
def test_multiout_st(plugin, tmpdir):
37+
def test_multiout_st(tmpdir):
3838
""" testing a simple function that returns a numpy array, adding splitter"""
3939
wf = Workflow("wf", input_spec=["val"], val=[0, 1, 2])
4040
wf.add(arrayout(name="mo", val=wf.lzin.val))
@@ -43,7 +43,7 @@ def test_multiout_st(plugin, tmpdir):
4343
wf.set_output([("array", wf.mo.lzout.b)])
4444
wf.cache_dir = tmpdir
4545

46-
with Submitter(plugin=plugin, n_procs=2) as sub:
46+
with Submitter(plugin="cf", n_procs=2) as sub:
4747
sub(runnable=wf)
4848

4949
results = wf.result(return_inputs=True)

pydra/engine/tests/test_submitter.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,30 @@ def sleep_add_one(x):
2323

2424
def test_callable_wf(plugin, tmpdir):
2525
wf = gen_basic_wf()
26+
res = wf()
27+
assert res.output.out == 9
28+
del wf, res
2629

27-
with pytest.raises(NotImplementedError):
28-
wf()
29-
30+
# providing plugin
31+
wf = gen_basic_wf()
3032
res = wf(plugin="cf")
3133
assert res.output.out == 9
3234
del wf, res
3335

36+
# providing plugin_kwargs
3437
wf = gen_basic_wf()
35-
wf.cache_dir = tmpdir
38+
res = wf(plugin="cf", plugin_kwargs={"n_procs": 2})
39+
assert res.output.out == 9
40+
del wf, res
41+
42+
# providing wrong plugin_kwargs
43+
wf = gen_basic_wf()
44+
with pytest.raises(TypeError, match="an unexpected keyword argument"):
45+
wf(plugin="cf", plugin_kwargs={"sbatch_args": "-N2"})
3646

47+
# providing submitter
48+
wf = gen_basic_wf()
49+
wf.cache_dir = tmpdir
3750
sub = Submitter(plugin)
3851
res = wf(submitter=sub)
3952
assert res.output.out == 9

0 commit comments

Comments
 (0)