Skip to content

Commit 118b9f3

Browse files
authored
Merge pull request #477 from djarecka/fix/cmdline_tmpl_s
[fix] fixing cmdline for shell tasks with templates and spliter (continuation of #475)
2 parents fdff182 + b2cd945 commit 118b9f3

File tree

3 files changed

+49
-26
lines changed

3 files changed

+49
-26
lines changed

pydra/engine/helpers_file.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -525,54 +525,68 @@ def copyfile_input(inputs, output_dir):
525525

526526

527527
# not sure if this might be useful for Function Task
528-
def template_update(inputs, output_dir, map_copyfiles=None):
528+
def template_update(inputs, output_dir, state_ind=None, map_copyfiles=None):
529529
"""
530530
Update all templates that are present in the input spec.
531531
532532
Should be run when all inputs used in the templates are already set.
533533
534534
"""
535-
dict_ = attr.asdict(inputs)
535+
536+
inputs_dict_st = attr.asdict(inputs)
536537
if map_copyfiles is not None:
537-
dict_.update(map_copyfiles)
538+
inputs_dict_st.update(map_copyfiles)
539+
540+
if state_ind is not None:
541+
for k, v in state_ind.items():
542+
k = k.split(".")[1]
543+
inputs_dict_st[k] = inputs_dict_st[k][v]
538544

539545
from .specs import attr_fields
540546

541547
fields_templ = [
542548
fld for fld in attr_fields(inputs) if fld.metadata.get("output_file_template")
543549
]
550+
dict_mod = {}
544551
for fld in fields_templ:
545552
if fld.type not in [str, ty.Union[str, bool]]:
546553
raise Exception(
547554
f"fields with output_file_template"
548555
" has to be a string or Union[str, bool]"
549556
)
550-
dict_[fld.name] = template_update_single(
551-
field=fld, inputs=inputs, output_dir=output_dir
557+
dict_mod[fld.name] = template_update_single(
558+
field=fld,
559+
inputs=inputs,
560+
inputs_dict_st=inputs_dict_st,
561+
output_dir=output_dir,
552562
)
553-
# using is and == so it covers list and numpy arrays
554-
updated_templ_dict = {
555-
k: v
556-
for k, v in dict_.items()
557-
if not (getattr(inputs, k) is v or getattr(inputs, k) == v)
558-
}
559-
return updated_templ_dict
563+
# adding elements from map_copyfiles to fields with templates
564+
if map_copyfiles:
565+
dict_mod.update(map_copyfiles)
566+
return dict_mod
560567

561568

562-
def template_update_single(field, inputs, output_dir=None, spec_type="input"):
569+
def template_update_single(
570+
field, inputs, inputs_dict_st=None, output_dir=None, spec_type="input"
571+
):
563572
"""Update a single template from the input_spec or output_spec
564573
based on the value from inputs_dict
565574
(checking the types of the fields, that have "output_file_template)"
566575
"""
567576
from .specs import File, MultiOutputFile, Directory
568577

578+
# if input_dict_st with state specific value is not available,
579+
# the dictionary will be created from inputs object
580+
if inputs_dict_st is None:
581+
inputs_dict_st = attr.asdict(inputs)
582+
569583
if spec_type == "input":
570584
if field.type not in [str, ty.Union[str, bool]]:
571585
raise Exception(
572586
f"fields with output_file_template"
573587
"has to be a string or Union[str, bool]"
574588
)
575-
inp_val_set = getattr(inputs, field.name)
589+
inp_val_set = inputs_dict_st[field.name]
576590
if inp_val_set is not attr.NOTHING and not isinstance(inp_val_set, (str, bool)):
577591
raise Exception(
578592
f"{field.name} has to be str or bool, but {inp_val_set} set"
@@ -589,13 +603,13 @@ def template_update_single(field, inputs, output_dir=None, spec_type="input"):
589603
else:
590604
raise Exception(f"spec_type can be input or output, but {spec_type} provided")
591605
# for inputs that the value is set (so the template is ignored)
592-
if spec_type == "input" and isinstance(getattr(inputs, field.name), str):
593-
return getattr(inputs, field.name)
594-
elif spec_type == "input" and getattr(inputs, field.name) is False:
606+
if spec_type == "input" and isinstance(inputs_dict_st[field.name], str):
607+
return inputs_dict_st[field.name]
608+
elif spec_type == "input" and inputs_dict_st[field.name] is False:
595609
# if input fld is set to False, the fld shouldn't be used (setting NOTHING)
596610
return attr.NOTHING
597611
else: # inputs_dict[field.name] is True or spec_type is output
598-
value = _template_formatting(field, inputs)
612+
value = _template_formatting(field, inputs, inputs_dict_st)
599613
# changing path so it is in the output_dir
600614
if output_dir and value is not attr.NOTHING:
601615
# should be converted to str, it is also used for input fields that should be str
@@ -607,7 +621,7 @@ def template_update_single(field, inputs, output_dir=None, spec_type="input"):
607621
return value
608622

609623

610-
def _template_formatting(field, inputs):
624+
def _template_formatting(field, inputs, inputs_dict_st):
611625
"""Formatting the field template based on the values from inputs.
612626
Taking into account that the field with a template can be a MultiOutputFile
613627
and the field values needed in the template can be a list -
@@ -633,7 +647,7 @@ def _template_formatting(field, inputs):
633647

634648
for fld in inp_fields:
635649
fld_name = fld[1:-1] # extracting the name form {field_name}
636-
fld_value = getattr(inputs, fld_name)
650+
fld_value = inputs_dict_st[fld_name]
637651
if fld_value is attr.NOTHING:
638652
# if value is NOTHING, nothing should be added to the command
639653
return attr.NOTHING

pydra/engine/task.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -306,14 +306,18 @@ def command_args(self):
306306
"""Get command line arguments, returns a list if task has a state"""
307307
if is_lazy(self.inputs):
308308
raise Exception("can't return cmdline, self.inputs has LazyFields")
309+
orig_inputs = attr.asdict(self.inputs)
309310
if self.state:
310311
command_args_list = []
311312
self.state.prepare_states(self.inputs)
312313
for ii, el in enumerate(self.state.states_ind):
313314
command_args_list.append(self._command_args_single(el, index=ii))
315+
self.inputs = attr.evolve(self.inputs, **orig_inputs)
314316
return command_args_list
315317
else:
316-
return self._command_args_single()
318+
command_args = self._command_args_single()
319+
self.inputs = attr.evolve(self.inputs, **orig_inputs)
320+
return command_args
317321

318322
def _command_args_single(self, state_ind=None, index=None):
319323
"""Get command line arguments for a single state
@@ -327,7 +331,7 @@ def _command_args_single(self, state_ind=None, index=None):
327331
"""
328332
if index is not None:
329333
modified_inputs = template_update(
330-
self.inputs, output_dir=self.output_dir[index]
334+
self.inputs, output_dir=self.output_dir[index], state_ind=state_ind
331335
)
332336
else:
333337
modified_inputs = template_update(self.inputs, output_dir=self.output_dir)
@@ -476,7 +480,6 @@ def cmdline(self):
476480
raise Exception("can't return cmdline, self.inputs has LazyFields")
477481
# checking the inputs fields before returning the command line
478482
self.inputs.check_fields_input_spec()
479-
orig_inputs = attr.asdict(self.inputs)
480483
if isinstance(self, ContainerTask):
481484
if self.state:
482485
cmdline = []
@@ -492,7 +495,6 @@ def cmdline(self):
492495
else:
493496
cmdline = " ".join(self.command_args)
494497

495-
self.inputs = attr.evolve(self.inputs, **orig_inputs)
496498
return cmdline
497499

498500
def _run_task(self):
@@ -616,6 +618,7 @@ def bind_paths(self, index=None):
616618
else:
617619
output_dir = self.output_dir[index]
618620
for binding in self.inputs.bindings:
621+
binding = list(binding)
619622
if len(binding) == 3:
620623
lpath, cpath, mode = binding
621624
elif len(binding) == 2:

pydra/engine/tests/test_shelltask_inputspec.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import attr
22
import typing as ty
3+
from pathlib import Path
34
import pytest
45

56
from ..task import ShellCommandTask
@@ -1805,14 +1806,19 @@ def test_shell_cmd_inputs_template_1_st():
18051806
bases=(ShellSpec,),
18061807
)
18071808

1809+
inpA = ["inpA_1", "inpA_2"]
18081810
shelly = ShellCommandTask(
18091811
name="f",
18101812
executable="executable",
18111813
input_spec=my_input_spec,
1812-
inpA=["inpA", "inpB"],
1814+
inpA=inpA,
18131815
).split("inpA")
18141816

1815-
assert shelly.cmdline
1817+
cmdline_list = shelly.cmdline
1818+
assert len(cmdline_list) == 2
1819+
for i in range(2):
1820+
path_out = Path(shelly.output_dir[i]) / f"{inpA[i]}_out"
1821+
assert cmdline_list[i] == f"executable {inpA[i]} -o {str(path_out)}"
18161822

18171823

18181824
# TODO: after deciding how we use requires/templates

0 commit comments

Comments
 (0)