Skip to content

Commit 4ef5273

Browse files
authored
Merge pull request #13 from djarecka/satra-enh-dask
allowing for limited dask testing
2 parents 28f8f64 + 97a3f3b commit 4ef5273

19 files changed

+860
-462
lines changed

.travis.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ env:
3333
# Useful for testing un-released upstream fixes
3434
matrix:
3535
include:
36+
- python: 3.7
37+
env:
38+
- INSTALL_TYPE="develop"
39+
- CHECK_TYPE="test_dask"
3640
- os: osx
3741
osx_image: xcode11.2
3842
language: generic

ci/none.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ function travis_before_script {
4040
# extras don't seem possible with setup.py install, so switch to pip
4141
pip install ".[test]"
4242
fi
43+
elif [ "$CHECK_TYPE" = "test_dask" ]; then
44+
if [ "$INSTALL_TYPE" = "develop" ]; then
45+
pip install -e ".[dask]"
46+
fi
4347
elif [ "$CHECK_TYPE" = "style" ]; then
4448
pip install black==19.3b0
4549
fi
@@ -48,6 +52,8 @@ function travis_before_script {
4852
function travis_script {
4953
if [ "$CHECK_TYPE" = "test" ]; then
5054
pytest -vs -n auto --cov pydra --cov-config .coveragerc --cov-report xml:cov.xml --doctest-modules pydra
55+
elif [ "$CHECK_TYPE" = "test_dask" ]; then
56+
pytest -vs -n auto --cov pydra --cov-config .coveragerc --cov-report xml:cov.xml --doctest-modules --dask pydra/engine
5157
elif [ "$CHECK_TYPE" = "style" ]; then
5258
black --check pydra tools setup.py
5359
fi

pydra/conftest.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import shutil
2+
3+
4+
def pytest_addoption(parser):
    """Register the ``--dask`` command-line switch with pytest.

    The switch is a boolean flag (``store_true``); ``pytest_generate_tests``
    in this module reads it back through ``config.getoption("dask")``.

    Parameters
    ----------
    parser :
        Object exposing ``addoption`` (pytest's command-line parser).
    """
    flag_settings = {"action": "store_true", "help": "run all combinations"}
    parser.addoption("--dask", **flag_settings)
6+
7+
8+
def pytest_generate_tests(metafunc):
    """Parametrize the ``plugin_dask_opt`` and ``plugin`` fixtures.

    ``plugin_dask_opt``
        Always gets ``"cf"``; ``"slurm"`` is added when an ``sbatch``
        executable is found on the PATH, and ``"dask"`` is added when the
        ``dask`` option is set.
    ``plugin``
        Gets an empty parameter list when the ``dask`` option is set;
        otherwise ``"cf"``, plus ``"slurm"`` when ``sbatch`` is found.

    Parameters
    ----------
    metafunc :
        Object exposing ``fixturenames``, ``config.getoption`` and
        ``parametrize`` (pytest's ``Metafunc``).
    """
    if "plugin_dask_opt" in metafunc.fixturenames:
        # sbatch on the PATH is what enables the slurm plugin
        workers = ["cf", "slurm"] if shutil.which("sbatch") else ["cf"]
        if metafunc.config.getoption("dask"):
            workers = workers + ["dask"]
        metafunc.parametrize("plugin_dask_opt", workers)

    if "plugin" in metafunc.fixturenames:
        if metafunc.config.getoption("dask"):
            # with the dask option set, "plugin" receives no parameters
            workers = []
        else:
            workers = ["cf", "slurm"] if shutil.which("sbatch") else ["cf"]
        metafunc.parametrize("plugin", workers)

pydra/engine/boutiques.py

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
import typing as ty
2+
import json
3+
import attr
4+
from urllib.request import urlretrieve
5+
from pathlib import Path
6+
from functools import reduce
7+
8+
from ..utils.messenger import AuditFlag
9+
from ..engine import ShellCommandTask
10+
from ..engine.specs import SpecInfo, ShellSpec, ShellOutSpec, File, attr_fields
11+
from .helpers_file import is_local_file
12+
13+
14+
class BoshTask(ShellCommandTask):
    """Shell Command Task based on the Boutiques descriptor"""

    def __init__(
        self,
        zenodo_id=None,
        bosh_file=None,
        audit_flags: AuditFlag = AuditFlag.NONE,
        cache_dir=None,
        input_spec_names: ty.Optional[ty.List] = None,
        messenger_args=None,
        messengers=None,
        name=None,
        output_spec_names: ty.Optional[ty.List] = None,
        rerun=False,
        strip=False,
        **kwargs,
    ):
        """
        Initialize this task.

        Exactly one of ``zenodo_id`` and ``bosh_file`` has to be provided.

        Parameters
        ----------
        zenodo_id : :obj:`str`
            Zenodo ID; the descriptor is looked up and downloaded
            to ``cache_dir``
        bosh_file : :obj:`str` or :obj:`os.pathlike`
            json file with the boutiques descriptors
        audit_flags : :obj:`pydra.utils.messenger.AuditFlag`
            Auditing configuration
        cache_dir : :obj:`os.pathlike`
            Cache directory
        input_spec_names : :obj:`list`
            Input names for input_spec; the list is not modified.
        messenger_args :
            TODO
        messengers :
            TODO
        name : :obj:`str`
            Name of this task.
        output_spec_names : :obj:`list`
            Output names for output_spec; the list is not modified.
        rerun : :obj:`bool`
            Forwarded unchanged to ``ShellCommandTask``.
        strip : :obj:`bool`
            TODO

        Raises
        ------
        Exception
            If both or neither of ``zenodo_id`` and ``bosh_file`` are given.
        """
        self.cache_dir = cache_dir
        if (bosh_file and zenodo_id) or not (bosh_file or zenodo_id):
            raise Exception("either bosh or zenodo_id has to be specified")
        elif zenodo_id:
            self.bosh_file = self._download_spec(zenodo_id)
        else:  # bosh_file
            # accept plain strings as well: .open() and .parent are used below
            self.bosh_file = Path(bosh_file)

        with self.bosh_file.open() as f:
            self.bosh_spec = json.load(f)

        # the input spec has to be built first: it fills _input_spec_keys,
        # which the output spec uses to rewrite the path templates
        self.input_spec = self._prepare_input_spec(names_subset=input_spec_names)
        self.output_spec = self._prepare_output_spec(names_subset=output_spec_names)
        self.bindings = ["-v", f"{self.bosh_file.parent}:{self.bosh_file.parent}:ro"]

        super(BoshTask, self).__init__(
            name=name,
            input_spec=self.input_spec,
            output_spec=self.output_spec,
            executable=["bosh", "exec", "launch"],
            args=["-s"],
            audit_flags=audit_flags,
            messengers=messengers,
            messenger_args=messenger_args,
            cache_dir=self.cache_dir,
            strip=strip,
            rerun=rerun,
            **kwargs,
        )
        self.strip = strip

    def _download_spec(self, zenodo_id):
        """
        Use the boutiques Searcher to find the url of the zenodo file
        for a specific id, and download the file to self.cache_dir.

        Returns the path of the downloaded file
        (``self.cache_dir / "zenodo.<zenodo_id>.json"``).
        Raises Exception if the search gives no hit or more than one hit.
        """
        # imported here, so boutiques is needed only when zenodo_id is used
        from boutiques.searcher import Searcher

        searcher = Searcher(zenodo_id, exact_match=True)
        hits = searcher.zenodo_search().json()["hits"]["hits"]
        if len(hits) == 0:
            raise Exception(f"can't find zenodo spec for {zenodo_id}")
        elif len(hits) > 1:
            raise Exception(f"too many hits for {zenodo_id}")
        else:
            zenodo_url = hits[0]["files"][0]["links"]["self"]
            zenodo_file = self.cache_dir / f"zenodo.{zenodo_id}.json"
            urlretrieve(zenodo_url, zenodo_file)
            return zenodo_file

    def _prepare_input_spec(self, names_subset=None):
        """ creating input spec from the zenodo file
        if names_subset provided, only names from the subset will be used in the spec

        Also fills ``self._input_spec_keys`` (value-key -> "{name}").
        Raises RuntimeError if a requested name is not in the descriptor.
        """
        binputs = self.bosh_spec["inputs"]
        self._input_spec_keys = {}
        fields = []
        # work on a copy, so the list provided by the caller is not emptied
        remaining = None if names_subset is None else list(names_subset)
        for binput in binputs:
            name = binput["id"]
            if remaining is None:
                pass
            elif name not in remaining:
                continue
            else:
                remaining.remove(name)
            if binput["type"] == "File":
                tp = File
            elif binput["type"] == "String":
                tp = str
            elif binput["type"] == "Number":
                tp = float
            elif binput["type"] == "Flag":
                tp = bool
            else:
                tp = None
            # adding list
            if tp and "list" in binput and binput["list"]:
                tp = ty.List[tp]

            mdata = {
                "help_string": binput.get("description", None) or binput["name"],
                "mandatory": not binput["optional"],
                "argstr": binput.get("command-line-flag", None),
            }
            fields.append((name, tp, mdata))
            self._input_spec_keys[binput["value-key"]] = "{" + f"{name}" + "}"
        if remaining:
            raise RuntimeError(f"{remaining} are not in the zenodo input spec")
        spec = SpecInfo(name="Inputs", fields=fields, bases=(ShellSpec,))
        return spec

    def _prepare_output_spec(self, names_subset=None):
        """ creating output spec from the zenodo file
        if names_subset provided, only names from the subset will be used in the spec

        Relies on ``self._input_spec_keys`` set by ``_prepare_input_spec``.
        Raises RuntimeError if a requested name is not in the descriptor.
        """
        boutputs = self.bosh_spec["output-files"]
        fields = []
        # work on a copy, so the list provided by the caller is not emptied
        remaining = None if names_subset is None else list(names_subset)
        for output in boutputs:
            name = output["id"]
            if remaining is None:
                pass
            elif name not in remaining:
                continue
            else:
                remaining.remove(name)
            # replacing every input value-key in the template with "{input_name}"
            path_template = reduce(
                lambda s, r: s.replace(*r),
                self._input_spec_keys.items(),
                output["path-template"],
            )
            mdata = {
                "help_string": output.get("description", None) or output["name"],
                "mandatory": not output["optional"],
                "output_file_template": path_template,
            }
            fields.append((name, attr.ib(type=File, metadata=mdata)))

        if remaining:
            raise RuntimeError(f"{remaining} are not in the zenodo output spec")
        spec = SpecInfo(name="Outputs", fields=fields, bases=(ShellOutSpec,))
        return spec

    def _command_args_single(self, state_ind, ind=None):
        """Get command line arguments for a single state"""
        # the invocation file has to be written first: it can extend self.bindings
        input_filepath = self._bosh_invocation_file(state_ind=state_ind, ind=ind)
        cmd_list = (
            self.inputs.executable
            + [str(self.bosh_file), input_filepath]
            + self.inputs.args
            + self.bindings
        )
        return cmd_list

    def _bosh_invocation_file(self, state_ind, ind=None):
        """creating bosh invocation file - json file with inputs values

        The file is written to ``self.cache_dir`` as ``<name>-<ind>.json``
        and its path is returned as a string.
        """
        input_json = {}
        for f in attr_fields(self.inputs):
            if f.name in ["executable", "args"]:
                continue
            if self.state and f"{self.name}.{f.name}" in state_ind:
                # input is part of the state: taking the element for this state
                value = getattr(self.inputs, f.name)[state_ind[f"{self.name}.{f.name}"]]
            else:
                value = getattr(self.inputs, f.name)
            # adding to the json file if specified by the user
            if value is not attr.NOTHING and value != "NOTHING":
                if is_local_file(f):
                    value = Path(value)
                    # NOTE(review): bindings is extended on every call, so
                    # the same directory can be added more than once
                    self.bindings.extend(["-v", f"{value.parent}:{value.parent}:ro"])
                    value = str(value)

                input_json[f.name] = value

        filename = self.cache_dir / f"{self.name}-{ind}.json"
        with open(filename, "w") as jsonfile:
            json.dump(input_json, jsonfile)

        return str(filename)

pydra/engine/core.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -532,29 +532,37 @@ def done(self):
532532
return True
533533
return False
534534

535-
def _combined_output(self):
535+
def _combined_output(self, return_inputs=False):
536536
combined_results = []
537537
for (gr, ind_l) in self.state.final_combined_ind_mapping.items():
538-
combined_results.append([])
538+
combined_results_gr = []
539539
for ind in ind_l:
540540
result = load_result(self.checksum_states(ind), self.cache_locations)
541541
if result is None:
542542
return None
543-
combined_results[gr].append(result)
543+
if return_inputs is True or return_inputs == "val":
544+
result = (self.state.states_val[ind], result)
545+
elif return_inputs == "ind":
546+
result = (self.state.states_ind[ind], result)
547+
combined_results_gr.append(result)
548+
combined_results.append(combined_results_gr)
544549
if len(combined_results) == 1 and self.state.splitter_rpn_final == []:
545550
# in case it's full combiner, removing the nested structure
546551
return combined_results[0]
547552
else:
548553
return combined_results
549554

550-
def result(self, state_index=None):
555+
def result(self, state_index=None, return_inputs=False):
551556
"""
552557
Retrieve the outcomes of this particular task.
553558
554559
Parameters
555560
----------
556-
state_index :
557-
TODO
561+
state_index : :obj: `int`
562+
index of the element for task with splitter and multiple states
563+
return_inputs : :obj: `bool`, :obj:`str`
564+
if True or "val" result is returned together with values of the input fields,
565+
if "ind" result is returned together with indices of the input fields
558566
559567
Returns
560568
-------
@@ -567,28 +575,50 @@ def result(self, state_index=None):
567575
if state_index is None:
568576
# if state_index=None, collecting all results
569577
if self.state.combiner:
570-
return self._combined_output()
578+
return self._combined_output(return_inputs=return_inputs)
571579
else:
572580
results = []
573581
for checksum in self.checksum_states():
574582
result = load_result(checksum, self.cache_locations)
575583
if result is None:
576584
return None
577585
results.append(result)
578-
return results
586+
if return_inputs is True or return_inputs == "val":
587+
return list(zip(self.state.states_val, results))
588+
elif return_inputs == "ind":
589+
return list(zip(self.state.states_ind, results))
590+
else:
591+
return results
579592
else: # state_index is not None
580593
if self.state.combiner:
581-
return self._combined_output()[state_index]
594+
return self._combined_output(return_inputs=return_inputs)[
595+
state_index
596+
]
582597
result = load_result(
583598
self.checksum_states(state_index), self.cache_locations
584599
)
585-
return result
600+
if return_inputs is True or return_inputs == "val":
601+
return (self.state.states_val[state_index], result)
602+
elif return_inputs == "ind":
603+
return (self.state.states_ind[state_index], result)
604+
else:
605+
return result
586606
else:
587607
if state_index is not None:
588608
raise ValueError("Task does not have a state")
589609
checksum = self.checksum
590610
result = load_result(checksum, self.cache_locations)
591-
return result
611+
if return_inputs is True or return_inputs == "val":
612+
inputs_val = {
613+
f"{self.name}.{inp}": getattr(self.inputs, inp)
614+
for inp in self.input_names
615+
}
616+
return (inputs_val, result)
617+
elif return_inputs == "ind":
618+
inputs_ind = {f"{self.name}.{inp}": None for inp in self.input_names}
619+
return (inputs_ind, result)
620+
else:
621+
return result
592622

593623
def _reset(self):
594624
"""Reset the connections between inputs and LazyFields."""

0 commit comments

Comments
 (0)