Skip to content

Commit 7b3aeb5

Browse files
committed
fixed splitting of tasks over states where value comes from lazy field
1 parent 75a2983 commit 7b3aeb5

File tree

7 files changed

+106
-133
lines changed

7 files changed

+106
-133
lines changed

pydra/engine/core.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -916,47 +916,6 @@ def _create_graph(
916916
)
917917
return graph
918918

919-
def create_dotfile(self, type="simple", export=None, name=None, output_dir=None):
920-
"""creating a graph - dotfile and optionally exporting to other formats"""
921-
outdir = output_dir if output_dir is not None else self.cache_dir
922-
graph = self.graph
923-
if not name:
924-
name = f"graph_{self._node.name}"
925-
if type == "simple":
926-
for task in graph.nodes:
927-
self.create_connections(task)
928-
dotfile = graph.create_dotfile_simple(outdir=outdir, name=name)
929-
elif type == "nested":
930-
for task in graph.nodes:
931-
self.create_connections(task)
932-
dotfile = graph.create_dotfile_nested(outdir=outdir, name=name)
933-
elif type == "detailed":
934-
# create connections with detailed=True
935-
for task in graph.nodes:
936-
self.create_connections(task, detailed=True)
937-
# adding wf outputs
938-
for wf_out, lf in self._connections:
939-
graph.add_edges_description(
940-
(self._node.name, wf_out, lf._node.name, lf.field)
941-
)
942-
dotfile = graph.create_dotfile_detailed(outdir=outdir, name=name)
943-
else:
944-
raise Exception(
945-
f"type of the graph can be simple, detailed or nested, "
946-
f"but {type} provided"
947-
)
948-
if not export:
949-
return dotfile
950-
else:
951-
if export is True:
952-
export = ["png"]
953-
elif isinstance(export, str):
954-
export = [export]
955-
formatted_dot = []
956-
for ext in export:
957-
formatted_dot.append(graph.export_graph(dotfile=dotfile, ext=ext))
958-
return dotfile, formatted_dot
959-
960919

961920
def is_workflow(obj):
962921
"""Check whether an object is a :class:`Workflow` instance."""

pydra/engine/helpers.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from fileformats.core import FileSet
1919

2020
if ty.TYPE_CHECKING:
21-
from .specs import TaskDef, Result, WorkflowOutputs
21+
from .specs import TaskDef, Result, WorkflowOutputs, WorkflowDef
2222
from .core import Task
2323
from pydra.design.base import Field
2424

@@ -28,6 +28,61 @@
2828
DefType = ty.TypeVar("DefType", bound="TaskDef")
2929

3030

31+
def plot_workflow(
32+
workflow_task: "WorkflowDef",
33+
out_dir: Path,
34+
type="simple",
35+
export=None,
36+
name=None,
37+
output_dir=None,
38+
):
39+
"""creating a graph - dotfile and optionally exporting to other formats"""
40+
from .core import Workflow
41+
42+
# Create output directory
43+
out_dir.mkdir(parents=True, exist_ok=True)
44+
45+
# Construct the workflow object
46+
wf = Workflow.construct(workflow_task)
47+
graph = wf.graph
48+
if not name:
49+
name = f"graph_{wf._node.name}"
50+
if type == "simple":
51+
for task in graph.nodes:
52+
wf.create_connections(task)
53+
dotfile = graph.create_dotfile_simple(outdir=out_dir, name=name)
54+
elif type == "nested":
55+
for task in graph.nodes:
56+
wf.create_connections(task)
57+
dotfile = graph.create_dotfile_nested(outdir=out_dir, name=name)
58+
elif type == "detailed":
59+
# create connections with detailed=True
60+
for task in graph.nodes:
61+
wf.create_connections(task, detailed=True)
62+
# adding wf outputs
63+
for wf_out, lf in wf._connections:
64+
graph.add_edges_description(
65+
(wf._node.name, wf_out, lf._node.name, lf.field)
66+
)
67+
dotfile = graph.create_dotfile_detailed(outdir=out_dir, name=name)
68+
else:
69+
raise Exception(
70+
f"type of the graph can be simple, detailed or nested, "
71+
f"but {type} provided"
72+
)
73+
if not export:
74+
return dotfile
75+
else:
76+
if export is True:
77+
export = ["png"]
78+
elif isinstance(export, str):
79+
export = [export]
80+
formatted_dot = []
81+
for ext in export:
82+
formatted_dot.append(graph.export_graph(dotfile=dotfile, ext=ext))
83+
return dotfile, formatted_dot
84+
85+
3186
def attrs_fields(definition, exclude_names=()) -> list[attrs.Attribute]:
3287
"""Get the fields of a definition, excluding some names."""
3388
return [

pydra/engine/node.py

Lines changed: 2 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import typing as ty
2-
from copy import deepcopy, copy
2+
from copy import deepcopy
33
from enum import Enum
44
import attrs
55
from pydra.utils.typing import TypeParser, StateArray
66
from . import lazy
77
from pydra.engine.helpers import (
8-
ensure_list,
98
attrs_values,
109
is_lazy,
11-
create_checksum,
1210
)
13-
from pydra.utils.hash import hash_function
1411
from pydra.engine import helpers_state as hlpst
15-
from pydra.engine.state import State, StateIndex
12+
from pydra.engine.state import State
1613

1714
if ty.TYPE_CHECKING:
1815
from .core import Workflow
@@ -172,51 +169,6 @@ def combiner(self):
172169
return ()
173170
return self._state.combiner
174171

175-
def _checksum_states(self, state_index: StateIndex = StateIndex()):
176-
"""
177-
Calculate a checksum for the specific state or all of the states of the task.
178-
Replaces state-arrays in the inputs fields with a specific values for states.
179-
Used to recreate names of the task directories,
180-
181-
Parameters
182-
----------
183-
state_index :
184-
TODO
185-
186-
"""
187-
# if is_workflow(self) and self._definition._graph_checksums is attr.NOTHING:
188-
# self._definition._graph_checksums = {
189-
# nd.name: nd.checksum for nd in self.graph_sorted
190-
# }
191-
from pydra.engine.specs import WorkflowDef
192-
193-
if state_index:
194-
inputs_copy = copy(self._definition)
195-
for key, ind in self.state.inputs_ind[state_index].items():
196-
val = self._extract_input_el(
197-
inputs=self._definition, inp_nm=key.split(".")[1], ind=ind
198-
)
199-
setattr(inputs_copy, key.split(".")[1], val)
200-
# setting files_hash again in case it was cleaned by setting specific element
201-
# that might be important for outer splitter of input variable with big files
202-
# the file can be changed with every single index even if there are only two files
203-
input_hash = inputs_copy.hash
204-
if isinstance(self._definition, WorkflowDef):
205-
con_hash = hash_function(self._connections)
206-
# TODO: hash list is not used
207-
hash_list = [input_hash, con_hash] # noqa: F841
208-
checksum_ind = create_checksum(
209-
self.__class__.__name__, self._checksum_wf(input_hash)
210-
)
211-
else:
212-
checksum_ind = create_checksum(self.__class__.__name__, input_hash)
213-
return checksum_ind
214-
else:
215-
checksum_list = []
216-
for ind in range(len(self.state.inputs_ind)):
217-
checksum_list.append(self._checksum_states(state_index=ind))
218-
return checksum_list
219-
220172
def _check_if_outputs_have_been_used(self, msg):
221173
used = []
222174
if self._lzout:
@@ -287,24 +239,6 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
287239
upstream_states[node.name][1].append(inpt_name)
288240
return upstream_states
289241

290-
def _extract_input_el(self, inputs, inp_nm, ind):
291-
"""
292-
Extracting element of the inputs taking into account
293-
container dimension of the specific element that can be set in self.state.cont_dim.
294-
If input name is not in cont_dim, it is assumed that the input values has
295-
a container dimension of 1, so only the most outer dim will be used for splitting.
296-
If
297-
"""
298-
if f"{self.name}.{inp_nm}" in self.state.cont_dim:
299-
return list(
300-
hlpst.flatten(
301-
ensure_list(getattr(inputs, inp_nm)),
302-
max_depth=self.state.cont_dim[f"{self.name}.{inp_nm}"],
303-
)
304-
)[ind]
305-
else:
306-
return getattr(inputs, inp_nm)[ind]
307-
308242
# else:
309243
# # todo it never gets here
310244
# breakpoint()

pydra/engine/state.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,3 +1253,29 @@ def _single_op_splits(self, op_single):
12531253
val = op["*"](val_ind)
12541254
keys = [op_single]
12551255
return val, keys
1256+
1257+
def _get_element(self, value: ty.Any, field_name: str, ind: int):
1258+
"""
1259+
Extracting element of the inputs taking into account
1260+
container dimension of the specific element that can be set in self.state.cont_dim.
1261+
If input name is not in cont_dim, it is assumed that the input values has
1262+
a container dimension of 1, so only the most outer dim will be used for splitting.
1263+
1264+
Parameters
1265+
----------
1266+
value : Any
1267+
inputs of the task
1268+
field_name : str
1269+
name of the input field
1270+
ind : int
1271+
index of the element
1272+
"""
1273+
if f"{self.name}.{field_name}" in self.cont_dim:
1274+
return list(
1275+
hlpst.flatten(
1276+
ensure_list(value),
1277+
max_depth=self.cont_dim[f"{self.name}.{field_name}"],
1278+
)
1279+
)[ind]
1280+
else:
1281+
return value[ind]

pydra/engine/submitter.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -690,24 +690,23 @@ def _split_definition(self) -> dict[StateIndex, "TaskDef[OutputType]"]:
690690
return {None: self.node._definition}
691691
split_defs = {}
692692
for input_ind in self.node.state.inputs_ind:
693-
inputs_dict = {}
693+
resolved = {}
694694
for inp in set(self.node.input_names):
695+
value = getattr(self.node._definition, inp)
696+
if isinstance(value, LazyField):
697+
value = resolved[inp] = value._get_value(
698+
workflow=self.workflow,
699+
graph=self.graph,
700+
state_index=StateIndex(input_ind),
701+
)
695702
if f"{self.node.name}.{inp}" in input_ind:
696-
value = getattr(self.node._definition, inp)
697-
if isinstance(value, LazyField):
698-
inputs_dict[inp] = value._get_value(
699-
workflow=self.workflow,
700-
graph=self.graph,
701-
state_index=StateIndex(input_ind),
702-
)
703-
else:
704-
inputs_dict[inp] = self.node._extract_input_el(
705-
inputs=self.node._definition,
706-
inp_nm=inp,
707-
ind=input_ind[f"{self.node.name}.{inp}"],
708-
)
703+
resolved[inp] = self.node.state._get_element(
704+
value=value,
705+
field_name=inp,
706+
ind=input_ind[f"{self.node.name}.{inp}"],
707+
)
709708
split_defs[StateIndex(input_ind)] = attrs.evolve(
710-
self.node._definition, **inputs_dict
709+
self.node._definition, **resolved
711710
)
712711
return split_defs
713712

pydra/engine/tests/test_numpy_examples.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,7 @@ def test_task_numpyinput_1(tmp_path: Path):
8181
nn = Identity().split(x=[np.array([1, 2]), np.array([3, 4])])
8282
# checking the results
8383
outputs = nn(cache_dir=tmp_path)
84-
assert (outputs.out[0] == np.array([1, 2])).all()
85-
assert (outputs.out[1] == np.array([3, 4])).all()
84+
assert (np.array(outputs.out) == np.array([[1, 2], [3, 4]])).all()
8685

8786

8887
def test_task_numpyinput_2(tmp_path: Path):

pydra/engine/tests/test_workflow.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from pydra.engine.submitter import Submitter
3939
from pydra.design import python, workflow
40+
import pydra.engine.core
4041
from pydra.utils import exc_info_matches
4142

4243

@@ -959,8 +960,7 @@ def Workflow(x, y):
959960

960961
assert not results.errored, "\n".join(results.errors["error message"])
961962

962-
assert results.outputs.out[0] == [13, 24, 35]
963-
assert results.outputs.out[1] == [14, 26, 38]
963+
assert results.outputs.out == [[13, 24, 35], [14, 26, 38]]
964964

965965

966966
def test_wf_ndst_7(plugin, tmpdir):
@@ -3735,13 +3735,14 @@ def Workflow1(x, y):
37353735
def create_tasks():
37363736
@workflow.define
37373737
def Workflow(x):
3738-
t1 = workflow.add(Add2(x=x))
3739-
t2 = workflow.add(Multiply(x=t1.out, y=2))
3738+
t1 = workflow.add(Add2(x=x), name="t1")
3739+
t2 = workflow.add(Multiply(x=t1.out, y=2), name="t2")
37403740
return t2.out
37413741

37423742
wf = Workflow(x=1)
3743-
t1 = wf.name2obj["t1"]
3744-
t2 = wf.name2obj["t2"]
3743+
workflow_obj = pydra.engine.core.Workflow.construct(wf)
3744+
t1 = workflow_obj["t1"]
3745+
t2 = workflow_obj["t2"]
37453746
return wf, t1, t2
37463747

37473748

0 commit comments

Comments
 (0)