Skip to content

Commit 4be0945

Browse files
committed
added in "modify|" syntax as discussed with @satra
1 parent 0cf7d35 commit 4be0945

File tree

4 files changed

+82
-18
lines changed

4 files changed

+82
-18
lines changed

pydra/design/shell.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def parse_command_line_template(
461461
return template, inputs, outputs
462462
executable, args_str = parts
463463
tokens = re.split(r"\s+", args_str.strip())
464-
arg_pattern = r"<([:a-zA-Z0-9_,\|\-\.\/\+]+\??)>"
464+
arg_pattern = r"<([:a-zA-Z0-9_,\|\-\.\/\+]+(?:\?|=[^>]+)?)>"
465465
opt_pattern = r"--?[a-zA-Z0-9_]+"
466466
arg_re = re.compile(arg_pattern)
467467
opt_re = re.compile(opt_pattern)
@@ -470,10 +470,8 @@ def parse_command_line_template(
470470
arguments = []
471471
option = None
472472

473-
def add_arg(name, field_type, kwds, is_option=False):
473+
def add_arg(name, field_type, kwds):
474474
"""Merge the typing information with an existing field if it exists"""
475-
if is_option and kwds["type"] is not bool:
476-
kwds["type"] |= None
477475
if issubclass(field_type, Out):
478476
dct = outputs
479477
else:
@@ -497,7 +495,8 @@ def add_arg(name, field_type, kwds, is_option=False):
497495
for k, v in kwds.items():
498496
setattr(field, k, v)
499497
dct[name] = field
500-
arguments.append(field)
498+
if issubclass(field_type, Arg):
499+
arguments.append(field)
501500

502501
def from_type_str(type_str) -> type:
503502
types = []
@@ -528,9 +527,14 @@ def from_type_str(type_str) -> type:
528527
for token in tokens:
529528
if match := arg_re.match(token):
530529
name = match.group(1)
530+
modify = False
531531
if name.startswith("out|"):
532532
name = name[4:]
533533
field_type = outarg
534+
elif name.startswith("modify|"):
535+
name = name[7:]
536+
field_type = arg
537+
modify = True
534538
else:
535539
field_type = arg
536540
# Identify type after ':' symbols
@@ -539,14 +543,22 @@ def from_type_str(type_str) -> type:
539543
optional = True
540544
else:
541545
optional = False
546+
kwds = {}
547+
if "=" in name:
548+
name, default = name.split("=")
549+
kwds["default"] = eval(default)
542550
if ":" in name:
543551
name, type_str = name.split(":")
544552
type_ = from_type_str(type_str)
545553
else:
546554
type_ = generic.FsObject if option is None else str
547555
if optional:
548556
type_ |= None # Make the arguments optional
549-
kwds = {"type": type_}
557+
kwds["type"] = type_
558+
if modify:
559+
kwds["copy_mode"] = generic.File.CopyMode.copy
560+
# Add field to outputs with the same name as the input
561+
add_arg(name, out, {"type": type_, "callable": _InputPassThrough(name)})
550562
# If name contains a '.', treat it as a file template and strip it from the name
551563
if field_type is outarg:
552564
path_template = name
@@ -566,6 +578,7 @@ def from_type_str(type_str) -> type:
566578
kwds["argstr"] = option
567579
add_arg(name, field_type, kwds)
568580
option = None
581+
569582
elif match := bool_arg_re.match(token):
570583
argstr, var = match.groups()
571584
add_arg(var, arg, {"type": bool, "argstr": argstr, "default": False})
@@ -626,3 +639,13 @@ def remaining_positions(
626639
f"Multiple fields have the overlapping positions: {multiple_positions}"
627640
)
628641
return [i for i in range(start, num_args) if i not in positions]
642+
643+
644+
@attrs.define
645+
class _InputPassThrough:
646+
"""A class that can be used to pass through an input to the output"""
647+
648+
name: str
649+
650+
def __call__(self, inputs: ShellSpec) -> ty.Any:
651+
return getattr(inputs, self.name)

pydra/design/tests/test_shell.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,34 @@ def test_interface_template_w_types_and_path_template_ext():
7171
SampleInterface.Outputs(out_image=image.Png.mock())
7272

7373

74+
def test_interface_template_w_modify():
75+
76+
SampleInterface = shell.define("trim-png <modify|image:image/png>")
77+
78+
assert issubclass(SampleInterface, ShellSpec)
79+
assert sorted_fields(SampleInterface) == [
80+
shell.arg(
81+
name="executable",
82+
default="trim-png",
83+
type=str | ty.Sequence[str],
84+
position=0,
85+
help_string=shell.EXECUTABLE_HELP_STRING,
86+
),
87+
shell.arg(
88+
name="image", type=image.Png, position=1, copy_mode=File.CopyMode.copy
89+
),
90+
]
91+
assert sorted_fields(SampleInterface.Outputs) == [
92+
shell.out(
93+
name="image",
94+
type=image.Png,
95+
callable=shell._InputPassThrough("image"),
96+
)
97+
]
98+
SampleInterface(image=image.Png.mock())
99+
SampleInterface.Outputs(image=image.Png.mock())
100+
101+
74102
def test_interface_template_more_complex():
75103

76104
SampleInterface = shell.define(

pydra/engine/state.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def __str__(self):
111111
f"and combiner: {self.combiner}"
112112
)
113113

114+
@property
115+
def depth(self):
116+
"""Return the number of uncombined splits of the state."""
117+
return len(self.states_ind)
118+
114119
@property
115120
def splitter(self):
116121
"""Get the splitter of the state."""

pydra/engine/workflow/node.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,26 @@ def inputs(self) -> Inputs:
8181

8282
@property
8383
def state(self):
84+
"""Initialise the state of the node just after it has been created (i.e. before
85+
it has been split or combined) based on the upstream connections
86+
"""
8487
if self._state is not NOT_SET:
8588
return self._state
89+
upstream_states = self._upstream_states()
90+
if upstream_states:
91+
state = State(
92+
self.name,
93+
splitter=None,
94+
other_states=upstream_states,
95+
combiner=None,
96+
)
97+
else:
98+
state = None
99+
self._state = state
100+
return state
101+
102+
def _upstream_states(self):
103+
"""Get the states of the upstream nodes that are connected to this node"""
86104
upstream_states = {}
87105
for inpt_name, val in self.input_values:
88106
if isinstance(val, lazy.LazyOutField) and val.node.state:
@@ -97,17 +115,7 @@ def state(self):
97115
# if the task already exist in other_state,
98116
# additional field name should be added to the list of fields
99117
upstream_states[node.name][1].append(inpt_name)
100-
if upstream_states:
101-
state = State(
102-
node.name,
103-
splitter=None,
104-
other_states=upstream_states,
105-
combiner=None,
106-
)
107-
else:
108-
state = None
109-
self._state = state
110-
return state
118+
return upstream_states
111119

112120
@property
113121
def input_values(self) -> tuple[tuple[str, ty.Any]]:
@@ -248,7 +256,7 @@ def combine(
248256
raise Exception("combiner has to be a string or a list")
249257
combiner = hlpst.add_name_combiner(ensure_list(combiner), self.name)
250258
if not_split := [
251-
c for c in combiner if not any(c in s for s in self.state.splitter)
259+
c for c in combiner if not any(c in s for s in self.state.splitter_rpn)
252260
]:
253261
raise ValueError(
254262
f"Combiner fields {not_split} for Node {self.name!r} are not in the "

0 commit comments

Comments
 (0)