Skip to content

Commit d056d53

Browse files
committed
fixing support for multInputObjs
1 parent 27a0404 commit d056d53

File tree

4 files changed

+64
-51
lines changed

4 files changed

+64
-51
lines changed

pydra/design/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ def extract_function_inputs_and_outputs(
833833
if isinstance(inpt, arg_type):
834834
if inpt.default is EMPTY:
835835
inpt.default = default
836-
elif inspect.isclass(inpt):
836+
elif inspect.isclass(inpt) or ty.get_origin(inpt):
837837
inputs[inpt_name] = arg_type(type=inpt, default=default)
838838
else:
839839
raise ValueError(

pydra/design/shell.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,25 @@ class arg(Arg):
100100
formatter: ty.Callable | None = None
101101

102102
@sep.validator
103-
def _validate_sep(self, attribute, value):
104-
if (
105-
value is not None
106-
and self.type is not ty.Any
107-
and ty.get_origin(self.type) is not MultiInputObj
108-
):
109-
tp = ty.get_origin(self.type) or self.type
110-
if not issubclass(tp, ty.Iterable):
103+
def _validate_sep(self, _, sep):
104+
if self.type is ty.Any:
105+
return
106+
if ty.get_origin(self.type) is MultiInputObj:
107+
tp = ty.get_args(self.type)[0]
108+
else:
109+
tp = self.type
110+
origin = ty.get_origin(tp) or tp
111+
if inspect.isclass(origin) and issubclass(origin, ty.Iterable):
112+
if sep is None:
111113
raise ValueError(
112-
f"sep ({value!r}) can only be provided when type is iterable"
114+
f"A value to 'sep' must be provided when type is iterable {tp} "
115+
f"for field {self.name!r}"
113116
)
117+
elif sep is not None:
118+
raise ValueError(
119+
f"sep ({sep!r}) can only be provided when type is iterable {tp} "
120+
f"for field {self.name!r}"
121+
)
114122

115123

116124
@attrs.define(kw_only=True)

pydra/engine/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def _check_for_hash_changes(self):
554554
f"- {changed}: the {field_type} object passed to the {field.type}"
555555
f"field appears to have an unstable hash. This could be due to "
556556
"a stochastic/non-thread-safe attribute(s) of the object\n\n"
557-
f"The {field.type}.__bytes_repr__() method can be implemented to "
557+
f'A "bytes_repr" method for {field.type!r} can be implemented to '
558558
"bespoke hashing methods based only on the stable attributes for "
559559
f"the `{field_type.__module__}.{field_type.__name__}` type. "
560560
f"See pydra/utils/hash.py for examples. Value: {val}\n"

pydra/engine/specs.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from . import helpers_state as hlpst
3232
from . import lazy
3333
from pydra.utils.hash import hash_function, Cache
34-
from pydra.utils.typing import StateArray
34+
from pydra.utils.typing import StateArray, MultiInputObj
3535
from pydra.design.base import Field, Arg, Out, RequirementSet, EMPTY
3636
from pydra.design import shell
3737
from pydra.engine.lazy import LazyInField, LazyOutField
@@ -1126,47 +1126,50 @@ def _command_pos_args(
11261126
# if False, nothing is added to the command.
11271127
if value is True:
11281128
cmd_add.append(field.argstr)
1129+
elif ty.get_origin(field.type) is MultiInputObj:
1130+
# if the field is MultiInputObj, it is used to create a list of arguments
1131+
for val in value or []:
1132+
cmd_add += self._format_arg(field, val)
11291133
else:
1130-
if (
1131-
field.argstr.endswith("...")
1132-
and isinstance(value, ty.Iterable)
1133-
and not isinstance(value, (str, bytes))
1134-
):
1135-
field.argstr = field.argstr.replace("...", "")
1136-
# if argstr has a more complex form, with "{input_field}"
1137-
if "{" in field.argstr and "}" in field.argstr:
1138-
argstr_formatted_l = []
1139-
for val in value:
1140-
argstr_f = argstr_formatting(
1141-
field.argstr, self, value_updates={field.name: val}
1142-
)
1143-
argstr_formatted_l.append(f" {argstr_f}")
1144-
cmd_el_str = field.sep.join(argstr_formatted_l)
1145-
else: # argstr has a simple form, e.g. "-f", or "--f"
1146-
cmd_el_str = field.sep.join(
1147-
[f" {field.argstr} {val}" for val in value]
1148-
)
1149-
else:
1150-
# in case there are ... when input is not a list
1151-
field.argstr = field.argstr.replace("...", "")
1152-
if isinstance(value, ty.Iterable) and not isinstance(
1153-
value, (str, bytes)
1154-
):
1155-
cmd_el_str = field.sep.join([str(val) for val in value])
1156-
value = cmd_el_str
1157-
# if argstr has a more complex form, with "{input_field}"
1158-
if "{" in field.argstr and "}" in field.argstr:
1159-
cmd_el_str = field.argstr.replace(f"{{{field.name}}}", str(value))
1160-
cmd_el_str = argstr_formatting(cmd_el_str, self.definition)
1161-
else: # argstr has a simple form, e.g. "-f", or "--f"
1162-
if value:
1163-
cmd_el_str = f"{field.argstr} {value}"
1164-
else:
1165-
cmd_el_str = ""
1166-
if cmd_el_str:
1167-
cmd_add += split_cmd(cmd_el_str)
1134+
cmd_add += self._format_arg(field, value)
11681135
return field.position, cmd_add
11691136

1137+
def _format_arg(self, field: shell.arg, value: ty.Any) -> list[str]:
1138+
"""Returning arguments used to specify the command args for a single inputs"""
1139+
if (
1140+
field.argstr.endswith("...")
1141+
and isinstance(value, ty.Iterable)
1142+
and not isinstance(value, (str, bytes))
1143+
):
1144+
field.argstr = field.argstr.replace("...", "")
1145+
# if argstr has a more complex form, with "{input_field}"
1146+
if "{" in field.argstr and "}" in field.argstr:
1147+
argstr_formatted_l = []
1148+
for val in value:
1149+
argstr_f = argstr_formatting(
1150+
field.argstr, self, value_updates={field.name: val}
1151+
)
1152+
argstr_formatted_l.append(f" {argstr_f}")
1153+
cmd_el_str = field.sep.join(argstr_formatted_l)
1154+
else: # argstr has a simple form, e.g. "-f", or "--f"
1155+
cmd_el_str = field.sep.join([f" {field.argstr} {val}" for val in value])
1156+
else:
1157+
# in case there are ... when input is not a list
1158+
field.argstr = field.argstr.replace("...", "")
1159+
if isinstance(value, ty.Iterable) and not isinstance(value, (str, bytes)):
1160+
cmd_el_str = field.sep.join([str(val) for val in value])
1161+
value = cmd_el_str
1162+
# if argstr has a more complex form, with "{input_field}"
1163+
if "{" in field.argstr and "}" in field.argstr:
1164+
cmd_el_str = field.argstr.replace(f"{{{field.name}}}", str(value))
1165+
cmd_el_str = argstr_formatting(cmd_el_str, self)
1166+
else: # argstr has a simple form, e.g. "-f", or "--f"
1167+
if value:
1168+
cmd_el_str = f"{field.argstr} {value}"
1169+
else:
1170+
cmd_el_str = ""
1171+
return split_cmd(cmd_el_str)
1172+
11701173
def _get_bindings(self, root: str | None = None) -> dict[str, tuple[str, str]]:
11711174
"""Return bindings necessary to run task in an alternative root.
11721175
@@ -1259,7 +1262,7 @@ def reset(self):
12591262
setattr(self, val, donothing)
12601263

12611264

1262-
def split_cmd(cmd: str):
1265+
def split_cmd(cmd: str | None):
12631266
"""Splits a shell command line into separate arguments respecting quotes
12641267
12651268
Parameters
@@ -1272,6 +1275,8 @@ def split_cmd(cmd: str):
12721275
str
12731276
the command line string split into process args
12741277
"""
1278+
if cmd is None:
1279+
return []
12751280
# Check whether running on posix or Windows system
12761281
on_posix = platform.system() != "Windows"
12771282
args = shlex.split(cmd, posix=on_posix)

0 commit comments

Comments
 (0)