Skip to content

Commit f62fd0d

Browse files
committed
debugging test_shelltask, reworking resolve of lazy inputs
1 parent ee88000 commit f62fd0d

File tree

5 files changed

+101
-26
lines changed

5 files changed

+101
-26
lines changed

pydra/engine/helpers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
from filelock import SoftFileLock, Timeout
1717
import cloudpickle as cp
1818
from fileformats.core import FileSet
19+
from pydra.utils.typing import StateArray
20+
1921

2022
if ty.TYPE_CHECKING:
2123
from .specs import TaskDef, Result, WorkflowOutputs, WorkflowDef
2224
from .core import Task
2325
from pydra.design.base import Field
26+
from pydra.engine.lazy import LazyField
2427

2528

2629
PYDRA_ATTR_METADATA = "__PYDRA_METADATA__"
@@ -695,3 +698,27 @@ def is_lazy(obj):
695698
from pydra.engine.lazy import LazyField
696699

697700
return isinstance(obj, LazyField)
701+
702+
703+
T = ty.TypeVar("T")
704+
U = ty.TypeVar("U")
705+
706+
707+
def state_array_support(
    function: ty.Callable[[T], U],
) -> ty.Callable[[T | StateArray[T]], U | StateArray[U]]:
    """
    Decorator that allows a single-argument function to accept and return
    StateArray objects, with the function applied to each element of the
    StateArray.

    Parameters
    ----------
    function : Callable[[T], U]
        the single-argument function to wrap

    Returns
    -------
    Callable
        a wrapper around ``function`` that handles its argument as follows:

        * a lazy value (as reported by ``is_lazy``) is returned unchanged,
          without calling ``function``
        * a ``StateArray`` is mapped element-wise through ``function`` and the
          results are collected into a new ``StateArray``
        * any other value is passed straight to ``function``
    """
    # Imported locally so that this helper does not depend on a module-level
    # import of functools
    import functools

    # functools.wraps keeps the wrapped function's name and docstring on the
    # returned wrapper, which makes tracebacks and introspection readable
    @functools.wraps(function)
    def state_array_wrapper(
        value: "T | StateArray[T] | LazyField[T]",
    ) -> "U | StateArray[U] | LazyField[U]":
        if is_lazy(value):
            # Lazy values cannot be converted yet, pass them through untouched
            return value
        if isinstance(value, StateArray):
            return StateArray(function(v) for v in value)
        return function(value)

    return state_array_wrapper

pydra/engine/specs.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616
import attrs
1717
from attrs.converters import default_if_none
1818
import cloudpickle as cp
19-
from fileformats.generic import FileSet
19+
from fileformats.generic import FileSet, File
2020
from pydra.utils.messenger import AuditFlag, Messenger
21-
from pydra.utils.typing import is_optional, optional_type, MultiInputObj
21+
from pydra.utils.typing import is_optional, optional_type
2222
from .helpers import (
2323
attrs_fields,
2424
attrs_values,
@@ -28,6 +28,7 @@
2828
ensure_list,
2929
parse_format_string,
3030
fields_in_formatter,
31+
state_array_support,
3132
)
3233
from .helpers_file import template_update, template_update_single
3334
from . import helpers_state as hlpst
@@ -1032,17 +1033,29 @@ def _resolve_value(
10321033
ShellOutputsType = ty.TypeVar("OutputType", bound=ShellOutputs)
10331034

10341035

1036+
@state_array_support
def additional_args_converter(value: ty.Any) -> list[str]:
    """Normalise the value given for additional arguments into a list.

    * a ``str`` is tokenised with ``shlex.split``
    * any other sequence is copied into a new list
    * anything else is wrapped in a single-element list

    Note that the elements themselves are not coerced to ``str``.
    """
    if isinstance(value, str):
        args = shlex.split(value)
    elif isinstance(value, ty.Sequence):
        args = list(value)
    else:
        args = [value]
    return args
1044+
1045+
10351046
@attrs.define(kw_only=True, auto_attribs=False, eq=False)
10361047
class ShellDef(TaskDef[ShellOutputsType]):
10371048

10381049
_task_type = "shell"
10391050

10401051
BASE_NAMES = ["additional_args"]
10411052

1042-
additional_args: MultiInputObj[str] = shell.arg(
1053+
additional_args: list[str | File] = shell.arg(
10431054
name="additional_args",
10441055
default=attrs.Factory(list),
1045-
type=MultiInputObj[str],
1056+
converter=additional_args_converter,
1057+
type=list[str | File],
1058+
sep=" ",
10461059
help="Additional free-form arguments to append to the end of the command.",
10471060
)
10481061

pydra/engine/state.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,18 @@ class StateIndex:
2727

2828
indices: OrderedDict[str, int]
2929

30-
def __init__(self, indices: dict[str, int] | None = None):
30+
def __init__(
31+
self, indices: dict[str, int] | ty.Sequence[tuple[str, int]] | None = None
32+
):
3133
# We use an ordered dict here to ensure the keys are always in the same order
3234
# while OrderedDict is not strictly necessary for CPython 3.7+, we use it to
3335
# signal that the order of the keys is important
3436
if indices is None:
3537
self.indices = OrderedDict()
3638
else:
37-
self.indices = OrderedDict(sorted(indices.items()))
39+
if isinstance(indices, dict):
40+
indices = indices.items()
41+
self.indices = OrderedDict(sorted(indices))
3842

3943
def __len__(self) -> int:
4044
return len(self.indices)

pydra/engine/submitter.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pydra.utils.hash import PersistentCache
2121
from .state import StateIndex
2222
from pydra.utils.typing import StateArray
23-
from pydra.engine.lazy import LazyField
23+
from pydra.engine.lazy import LazyField, LazyOutField
2424
from .audit import Audit
2525
from .core import Task
2626
from pydra.utils.messenger import AuditFlag, Messenger
@@ -537,6 +537,10 @@ def __init__(
537537
self.queued = {}
538538
self.running = {} # Not used in logic, but may be useful for progress tracking
539539
self.unrunnable = defaultdict(list)
540+
# Prepare the state to be run
541+
if self.state:
542+
self.state.prepare_states(self.node.state_values)
543+
self.state.prepare_inputs()
540544
self.state_names = self.node.state.names if self.node.state else []
541545
self.workflow = workflow
542546
self.graph = None
@@ -567,6 +571,21 @@ def tasks(self) -> ty.Iterable["Task[DefType]"]:
567571
self._tasks = {t.state_index: t for t in self._generate_tasks()}
568572
return self._tasks.values()
569573

574+
def translate_index(self, index: StateIndex, lf: LazyOutField):
    """Build the StateIndex associated with a lazy output field.

    The key ``"<node name>.<field name>"`` is formed from ``lf._node.name`` and
    ``lf._field`` and looked up in ``self.state.inner_inputs``.

    * If the key is not present, ``index`` is wrapped in a new ``StateIndex``.
    * Otherwise a ``StateIndex`` is built by pairing the ``keys_final`` of the
      entry found with the element of its ``ind_l_final`` selected by
      ``index[state_key]``.

    NOTE(review): presumably this maps an index in this node's state onto the
    state of the upstream node that produces ``lf`` -- confirm against callers.
    """
    state_key = f"{lf._node.name}.{lf._field}"
    try:
        upstream_state = self.state.inner_inputs[state_key]
    except KeyError:
        # No inner-input entry for this field
        return StateIndex(index)
    final_keys = upstream_state.keys_final
    final_indices = upstream_state.ind_l_final[index[state_key]]
    return StateIndex(zip(final_keys, final_indices))
588+
570589
def matching_jobs(self, index: StateIndex = StateIndex()) -> "StateArray[Task]":
571590
"""Get the jobs that match a given state index.
572591
@@ -702,23 +721,22 @@ def _split_definition(self) -> dict[StateIndex, "TaskDef[OutputType]"]:
702721
if not self.node.state:
703722
return {None: self.node._definition}
704723
split_defs = {}
705-
self.state.prepare_states(self.node.state_values)
706-
self.state.prepare_inputs()
707724
for input_ind in self.node.state.inputs_ind:
708725
resolved = {}
709726
for inpt_name in set(self.node.input_names):
710727
value = getattr(self._definition, inpt_name)
728+
state_key = f"{self.node.name}.{inpt_name}"
711729
if isinstance(value, LazyField):
712730
value = resolved[inpt_name] = value._get_value(
713731
workflow=self.workflow,
714732
graph=self.graph,
715-
state_index=StateIndex(input_ind),
733+
state_index=input_ind,
716734
)
717-
if f"{self.node.name}.{inpt_name}" in input_ind:
735+
elif state_key in input_ind:
718736
resolved[inpt_name] = self.node.state._get_element(
719737
value=value,
720738
field_name=inpt_name,
721-
ind=input_ind[f"{self.node.name}.{inpt_name}"],
739+
ind=input_ind[state_key],
722740
)
723741
split_defs[StateIndex(input_ind)] = attrs.evolve(
724742
self.node._definition, **resolved

pydra/engine/tests/test_shelltask.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1674,9 +1674,11 @@ def test_wf_shell_cmd_3(plugin, tmp_path):
16741674
class Shelly1(ShellDef["Shelly1.Outputs"]):
16751675
executable = "shelly"
16761676

1677+
arg: str = shell.arg(argstr=None)
1678+
16771679
class Outputs(ShellOutputs):
16781680
file: File = shell.outarg(
1679-
path_template="{args}",
1681+
path_template="{arg}",
16801682
help="output file",
16811683
)
16821684

@@ -1700,12 +1702,12 @@ class Outputs(ShellOutputs):
17001702
)
17011703

17021704
@workflow.define(outputs=["touch_file", "out1", "cp_file", "out2"])
1703-
def Workflow(cmd1, cmd2, args):
1705+
def Workflow(cmd1, cmd2, arg):
17041706

17051707
shelly1 = workflow.add(
17061708
Shelly1(
17071709
executable=cmd1,
1708-
additional_args=args,
1710+
arg=arg,
17091711
)
17101712
)
17111713
shelly2 = workflow.add(
@@ -1717,9 +1719,9 @@ def Workflow(cmd1, cmd2, args):
17171719

17181720
return shelly1.file, shelly1.stdout, shelly2.out_file, shelly2.stdout
17191721

1720-
wf = Workflow(cmd1="touch", cmd2="cp", args=File.mock("newfile.txt"))
1722+
wf = Workflow(cmd1="touch", cmd2="cp", arg="newfile.txt")
17211723

1722-
with Submitter(plugin="debug") as sub:
1724+
with Submitter(plugin="debug", cache_dir=tmp_path) as sub:
17231725
res = sub(wf)
17241726

17251727
assert res.outputs.out1 == ""
@@ -1738,14 +1740,19 @@ def test_wf_shell_cmd_3a(plugin, tmp_path):
17381740

17391741
@shell.define
17401742
class Shelly1(ShellDef["Shelly1.Outputs"]):
1743+
executable = "shelly"
1744+
arg: str = shell.outarg(argstr=None)
1745+
17411746
class Outputs(ShellOutputs):
1747+
17421748
file: File = shell.outarg(
1743-
path_template="{args}",
1749+
path_template="{arg}",
17441750
help="output file",
17451751
)
17461752

17471753
@shell.define
17481754
class Shelly2(ShellDef["Shelly2.Outputs"]):
1755+
executable = "shelly2"
17491756
orig_file: str = shell.arg(
17501757
position=1,
17511758
help="output file",
@@ -1761,12 +1768,12 @@ class Outputs(ShellOutputs):
17611768
)
17621769

17631770
@workflow.define(outputs=["touch_file", "out1", "cp_file", "out2"])
1764-
def Workflow(cmd1, cmd2, args):
1771+
def Workflow(cmd1, cmd2, arg):
17651772

17661773
shelly1 = workflow.add(
17671774
Shelly1(
17681775
executable=cmd1,
1769-
additional_args=args,
1776+
arg=arg,
17701777
)
17711778
)
17721779
shelly2 = workflow.add(
@@ -1778,7 +1785,7 @@ def Workflow(cmd1, cmd2, args):
17781785

17791786
return shelly1.file, shelly1.stdout, shelly2.out_file, shelly2.stdout
17801787

1781-
wf = Workflow(cmd1="touch", cmd2="cp", args=File.mock("newfile.txt"))
1788+
wf = Workflow(cmd1="touch", cmd2="cp", arg="newfile.txt")
17821789

17831790
with Submitter(plugin="debug") as sub:
17841791
res = sub(wf)
@@ -1861,14 +1868,20 @@ def test_wf_shell_cmd_ndst_1(plugin, tmp_path):
18611868

18621869
@shell.define
18631870
class Shelly1(ShellDef["Shelly1.Outputs"]):
1871+
executable = "shelly"
1872+
1873+
arg: str = shell.arg(argstr=None)
1874+
18641875
class Outputs(ShellOutputs):
18651876
file: File = shell.outarg(
1866-
path_template="{args}",
1877+
path_template="{arg}",
18671878
help="output file",
18681879
)
18691880

18701881
@shell.define
18711882
class Shelly2(ShellDef["Shelly2.Outputs"]):
1883+
executable = "shelly2"
1884+
18721885
orig_file: str = shell.arg(
18731886
position=1,
18741887
help="output file",
@@ -1889,7 +1902,7 @@ def Workflow(cmd1, cmd2, args):
18891902
shelly1 = workflow.add(
18901903
Shelly1(
18911904
executable=cmd1,
1892-
).split("args", args=args)
1905+
).split("arg", arg=args)
18931906
)
18941907
shelly2 = workflow.add(
18951908
Shelly2(
@@ -1903,10 +1916,10 @@ def Workflow(cmd1, cmd2, args):
19031916
wf = Workflow(
19041917
cmd1="touch",
19051918
cmd2="cp",
1906-
args=[File.mock("newfile_1.txt"), File.mock("newfile_2.txt")],
1919+
args=["newfile_1.txt", "newfile_2.txt"],
19071920
)
19081921

1909-
with Submitter(plugin="debug") as sub:
1922+
with Submitter(plugin="debug", cache_dir=tmp_path) as sub:
19101923
res = sub(wf)
19111924

19121925
assert res.outputs.out1 == ["", ""]
@@ -3288,7 +3301,7 @@ class Outputs(ShellOutputs):
32883301
# An exception should be raised because the second mandatory output does not exist
32893302
with pytest.raises(
32903303
ValueError,
3291-
match=r"file system path provided to mandatory field .* does not exist",
3304+
match=r"file system path\(s\) provided to mandatory field .* does not exist",
32923305
):
32933306
shelly(cache_dir=tmp_path)
32943307
# checking if the first output was created

0 commit comments

Comments
 (0)