Skip to content

Commit 8967b58

Browse files
authored
Merge pull request #614 from ghisvail/fix/argstr-any-iterable
ENH: Add support for argstr formatting of any iterable
2 parents 0f909cb + 549c018 commit 8967b58

File tree

2 files changed

+52
-12
lines changed

2 files changed

+52
-12
lines changed

pydra/engine/task.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -371,15 +371,15 @@ def _field_value(self, field, check_file=False):
371371
return value
372372

373373
def _command_shelltask_executable(self, field):
374-
"""Returining position and value for executable ShellTask input"""
374+
"""Returning position and value for executable ShellTask input"""
375375
pos = 0 # executable should be the first el. of the command
376376
value = self._field_value(field)
377377
if value is None:
378378
raise ValueError("executable has to be set")
379379
return pos, ensure_list(value, tuple2list=True)
380380

381381
def _command_shelltask_args(self, field):
382-
"""Returining position and value for args ShellTask input"""
382+
"""Returning position and value for args ShellTask input"""
383383
pos = -1 # assuming that args is the last el. of the command
384384
value = self._field_value(field, check_file=True)
385385
if value is None:
@@ -396,7 +396,7 @@ def _command_pos_args(self, field):
396396
argstr = field.metadata.get("argstr", None)
397397
formatter = field.metadata.get("formatter", None)
398398
if argstr is None and formatter is None:
399-
# assuming that input that has no arstr is not used in the command,
399+
# assuming that input that has no argstr is not used in the command,
400400
# or a formatter is not provided too.
401401
return None
402402
pos = field.metadata.get("position", None)
@@ -429,7 +429,7 @@ def _command_pos_args(self, field):
429429

430430
cmd_add = []
431431
# formatter that creates a custom command argument
432-
# it can thake the value of the filed, all inputs, or the value of other fields.
432+
# it can take the value of the field, all inputs, or the value of other fields.
433433
if "formatter" in field.metadata:
434434
call_args = inspect.getfullargspec(field.metadata["formatter"])
435435
call_args_val = {}
@@ -453,12 +453,16 @@ def _command_pos_args(self, field):
453453
cmd_add += split_cmd(cmd_el_str)
454454
elif field.type is bool:
455455
# if value is simply True the original argstr is used,
456-
# if False, nothing is added to the command
456+
# if False, nothing is added to the command.
457457
if value is True:
458458
cmd_add.append(argstr)
459459
else:
460460
sep = field.metadata.get("sep", " ")
461-
if argstr.endswith("...") and isinstance(value, list):
461+
if (
462+
argstr.endswith("...")
463+
and isinstance(value, ty.Iterable)
464+
and not isinstance(value, (str, bytes))
465+
):
462466
argstr = argstr.replace("...", "")
463467
# if argstr has a more complex form, with "{input_field}"
464468
if "{" in argstr and "}" in argstr:
@@ -474,7 +478,9 @@ def _command_pos_args(self, field):
474478
else:
475479
# in case there are ... when input is not a list
476480
argstr = argstr.replace("...", "")
477-
if isinstance(value, list):
481+
if isinstance(value, ty.Iterable) and not isinstance(
482+
value, (str, bytes)
483+
):
478484
cmd_el_str = sep.join([str(val) for val in value])
479485
value = cmd_el_str
480486
# if argstr has a more complex form, with "{input_field}"
@@ -505,10 +511,10 @@ def cmdline(self):
505511
command_args = self.container_args + self.command_args
506512
else:
507513
command_args = self.command_args
508-
# Skip the executable, which can be a multi-part command, e.g. 'docker run'.
514+
# Skip the executable, which can be a multipart command, e.g. 'docker run'.
509515
cmdline = command_args[0]
510516
for arg in command_args[1:]:
511-
# If there are spaces in the arg and it is not enclosed by matching
517+
# If there are spaces in the arg, and it is not enclosed by matching
512518
# quotes, add quotes to escape the space. Not sure if this should
513519
# be expanded to include other special characters apart from spaces
514520
if " " in arg:
@@ -600,7 +606,7 @@ def __init__(
600606
def _field_value(self, field, check_file=False):
601607
"""
602608
Checking value of the specific field, if value is not set, None is returned.
603-
If check_file is True, checking if field is a a local file
609+
If check_file is True, checking if field is a local file
604610
and settings bindings if needed.
605611
"""
606612
value = super()._field_value(field)
@@ -854,12 +860,12 @@ def split_cmd(cmd: str):
854860
str
855861
the command line string split into process args
856862
"""
857-
# Check whether running on posix or windows system
863+
# Check whether running on posix or Windows system
858864
on_posix = platform.system() != "Windows"
859865
args = shlex.split(cmd, posix=on_posix)
860866
cmd_args = []
861867
for arg in args:
862-
match = re.match("('|\")(.*)\\1$", arg)
868+
match = re.match("(['\"])(.*)\\1$", arg)
863869
if match:
864870
cmd_args.append(match.group(2))
865871
else:

pydra/engine/tests/test_shelltask.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,6 +1674,40 @@ def template_function(inputs):
16741674
assert shelly.output_dir == res.output.file_copy.parent
16751675

16761676

1677+
def test_shell_cmd_inputspec_with_iterable():
1678+
"""Test formatting of argstr with different iterable types."""
1679+
1680+
input_spec = SpecInfo(
1681+
name="Input",
1682+
fields=[
1683+
(
1684+
"iterable_1",
1685+
ty.Iterable[int],
1686+
{
1687+
"help_string": "iterable input 1",
1688+
"argstr": "--in1",
1689+
},
1690+
),
1691+
(
1692+
"iterable_2",
1693+
ty.Iterable[str],
1694+
{
1695+
"help_string": "iterable input 2",
1696+
"argstr": "--in2...",
1697+
},
1698+
),
1699+
],
1700+
bases=(ShellSpec,),
1701+
)
1702+
1703+
task = ShellCommandTask(name="test", input_spec=input_spec, executable="test")
1704+
1705+
for iterable_type in (list, tuple):
1706+
task.inputs.iterable_1 = iterable_type(range(3))
1707+
task.inputs.iterable_2 = iterable_type(["bar", "foo"])
1708+
assert task.cmdline == "test --in1 0 1 2 --in2 bar --in2 foo"
1709+
1710+
16771711
@pytest.mark.parametrize("results_function", [result_no_submitter, result_submitter])
16781712
def test_shell_cmd_inputspec_copyfile_1(plugin, results_function, tmpdir):
16791713
"""shelltask changes a file in place,

0 commit comments

Comments
 (0)