Skip to content

Commit 3de0fa0

Browse files
committed
debugged workflow graph plotting
1 parent e276608 commit 3de0fa0

File tree

6 files changed

+164
-108
lines changed

6 files changed

+164
-108
lines changed

pydra/engine/core.py

Lines changed: 64 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import attrs
1919
from filelock import SoftFileLock
2020
from pydra.engine.specs import TaskDef, WorkflowDef, TaskOutputs, WorkflowOutputs
21-
from pydra.engine.graph import DiGraph
21+
from pydra.engine.graph import DiGraph, INPUTS_NODE_NAME, OUTPUTS_NODE_NAME
2222
from pydra.engine import state
2323
from .lazy import LazyInField, LazyOutField
2424
from pydra.utils.hash import hash_function, Cache
@@ -598,16 +598,32 @@ def clear_cache(
598598

599599
@classmethod
600600
def construct(
601-
cls, definition: WorkflowDef[WorkflowOutputsType], dont_cache: bool = False
601+
cls,
602+
definition: WorkflowDef[WorkflowOutputsType],
603+
dont_cache: bool = False,
604+
lazy: ty.Sequence[str] = (),
602605
) -> Self:
603-
"""Construct a workflow from a definition, caching the constructed worklow"""
606+
"""Construct a workflow from a definition, caching the constructed worklow
607+
608+
Parameters
609+
----------
610+
definition : WorkflowDef
611+
The definition of the workflow to construct
612+
dont_cache : bool, optional
613+
Whether to cache the constructed workflow, by default False
614+
lazy : Sequence[str], optional
615+
The names of the inputs to the workflow to be considered lazy even if they
616+
have values in the given definition, by default ()
617+
"""
604618

605619
# Check the previously constructed workflows to see if a workflow has been
606620
# constructed for the given set of inputs, or a less-specific set (i.e. with a
607621
# super-set of lazy inputs), and use that if it exists
608622

609623
non_lazy_vals = {
610-
n: v for n, v in attrs_values(definition).items() if not is_lazy(v)
624+
n: v
625+
for n, v in attrs_values(definition).items()
626+
if not is_lazy(v) and n not in lazy
611627
}
612628
non_lazy_keys = frozenset(non_lazy_vals)
613629
hash_cache = Cache() # share the hash cache to avoid recalculations
@@ -821,56 +837,55 @@ def _create_graph(
821837
DiGraph
822838
The graph of the workflow
823839
"""
824-
graph: DiGraph = DiGraph()
840+
graph: DiGraph = DiGraph(name=self.name)
825841
for node in nodes:
826842
graph.add_nodes(node)
827843
# TODO: create connection is run twice
828844
for node in nodes:
829845
other_states = {}
830-
for field in attrs_fields(node.inputs):
846+
for field in list_fields(node._definition):
831847
lf = node._definition[field.name]
832848
if isinstance(lf, LazyOutField):
833849
# adding an edge to the graph if task is expecting output from a different task
834-
if lf._node.name != self.name:
835-
# checking if the connection is already in the graph
836-
if (graph.node(lf._node.name), node) not in graph.edges:
837-
graph.add_edges((graph.node(lf._node.name), node))
838-
if detailed:
839-
graph.add_edges_description(
840-
(node.name, field.name, lf._node.name, lf._field)
841-
)
842-
logger.debug("Connecting %s to %s", lf._node.name, node.name)
843-
# adding a state from the previous task to other_states
850+
851+
# checking if the connection is already in the graph
852+
if (graph.node(lf._node.name), node) not in graph.edges:
853+
graph.add_edges((graph.node(lf._node.name), node))
854+
if detailed:
855+
graph.add_edges_description(
856+
(node.name, field.name, lf._node.name, lf._field)
857+
)
858+
logger.debug("Connecting %s to %s", lf._node.name, node.name)
859+
# adding a state from the previous task to other_states
860+
if (
861+
graph.node(lf._node.name).state
862+
and graph.node(lf._node.name).state.splitter_rpn_final
863+
):
864+
# variables that are part of inner splitters should be
865+
# treated as a containers
844866
if (
845-
graph.node(lf._node.name).state
846-
and graph.node(lf._node.name).state.splitter_rpn_final
867+
node.state
868+
and f"{node.name}.{field.name}"
869+
in node.state._current_splitter_rpn
847870
):
848-
# variables that are part of inner splitters should be
849-
# treated as a containers
850-
if (
851-
node.state
852-
and f"{node.name}.{field.name}"
853-
in node.state._current_splitter_rpn
854-
):
855-
node.state._inner_cont_dim[
856-
f"{node.name}.{field.name}"
857-
] = 1
858-
# adding task_name: (task.state, [a field from the connection]
859-
if lf._node.name not in other_states:
860-
other_states[lf._node.name] = (
861-
graph.node(lf._node.name).state,
862-
[field.name],
863-
)
864-
else:
865-
# if the task already exist in other_state,
866-
# additional field name should be added to the list of fields
867-
other_states[lf._node.name][1].append(field.name)
868-
else: # LazyField with the wf input
869-
# connections with wf input should be added to the detailed graph description
870-
if detailed:
871-
graph.add_edges_description(
872-
(node.name, field.name, lf._node.name, lf._field)
871+
node.state._inner_cont_dim[f"{node.name}.{field.name}"] = 1
872+
# adding task_name: (task.state, [a field from the connection]
873+
if lf._node.name not in other_states:
874+
other_states[lf._node.name] = (
875+
graph.node(lf._node.name).state,
876+
[field.name],
873877
)
878+
else:
879+
# if the task already exist in other_state,
880+
# additional field name should be added to the list of fields
881+
other_states[lf._node.name][1].append(field.name)
882+
elif (
883+
isinstance(lf, LazyInField) and detailed
884+
): # LazyField with the wf input
885+
# connections with wf input should be added to the detailed graph description
886+
graph.add_edges_description(
887+
(node.name, field.name, INPUTS_NODE_NAME, lf._field)
888+
)
874889

875890
# if task has connections state has to be recalculated
876891
if other_states:
@@ -890,6 +905,12 @@ def _create_graph(
890905
other_states=other_states,
891906
combiner=combiner,
892907
)
908+
if detailed:
909+
lf: LazyOutField
910+
for outpt_name, lf in attrs_values(self.outputs).items():
911+
graph.add_edges_description(
912+
(OUTPUTS_NODE_NAME, outpt_name, lf._node.name, lf._field)
913+
)
893914
return graph
894915

895916

pydra/engine/graph.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010

1111
NodeType = ty.TypeVar("NodeType")
1212

13+
INPUTS_NODE_NAME = "__INPUTS__"
14+
OUTPUTS_NODE_NAME = "__OUTPUTS__"
15+
1316

1417
class DiGraph(ty.Generic[NodeType]):
1518
"""A simple Directed Graph object."""
@@ -395,7 +398,7 @@ def create_dotfile_simple(self, outdir, name="graph"):
395398

396399
dotstr = "digraph G {\n"
397400
for nd in self.nodes:
398-
if is_workflow(nd):
401+
if is_workflow(nd._definition):
399402
if nd.state:
400403
# adding color for wf with a state
401404
dotstr += f"{nd.name} [shape=box, color=blue]\n"
@@ -430,27 +433,29 @@ def create_dotfile_detailed(self, outdir, name="graph_det"):
430433
if not self._nodes_details:
431434
raise Exception("node_details is empty, detailed dotfile can't be created")
432435
for nd_nm, nd_det in self.nodes_details.items():
433-
if nd_nm == self.name: # the main workflow itself
436+
if nd_nm == INPUTS_NODE_NAME: # the main workflow itself
434437
# wf inputs
435438
wf_inputs_str = f'{{<{nd_det["outputs"][0]}> {nd_det["outputs"][0]}'
436439
for el in nd_det["outputs"][1:]:
437440
wf_inputs_str += f" | <{el}> {el}"
438441
wf_inputs_str += "}"
439-
dotstr += f'struct_{nd_nm} [color=red, label="{{WORKFLOW INPUT: | {wf_inputs_str}}}"];\n'
442+
dotstr += (
443+
f"struct_{self.name} [color=red, "
444+
f'label="{{WORKFLOW INPUT: | {wf_inputs_str}}}"];\n'
445+
)
446+
elif nd_nm == OUTPUTS_NODE_NAME:
440447
# wf outputs
441448
wf_outputs_str = f'{{<{nd_det["inputs"][0]}> {nd_det["inputs"][0]}'
442449
for el in nd_det["inputs"][1:]:
443450
wf_outputs_str += f" | <{el}> {el}"
444451
wf_outputs_str += "}"
445452
dotstr += (
446-
f"struct_{nd_nm}_out "
453+
f"struct_{self.name}_out "
447454
f'[color=red, label="{{WORKFLOW OUTPUT: | {wf_outputs_str}}}"];\n'
448455
)
449456
# connections to the wf outputs
450457
for con in nd_det["connections"]:
451-
dotstr += (
452-
f"struct_{con[1]}:{con[2]} -> struct_{nd_nm}_out:{con[0]};\n"
453-
)
458+
dotstr += f"struct_{con[1]}:{con[2]} -> struct_{self.name}_out:{con[0]};\n"
454459
else: # elements of the main workflow
455460
inputs_str = "{INPUT:"
456461
for inp in nd_det["inputs"]:
@@ -466,7 +471,11 @@ def create_dotfile_detailed(self, outdir, name="graph_det"):
466471
)
467472
# connections between elements
468473
for con in nd_det["connections"]:
469-
dotstr += f"struct_{con[1]}:{con[2]} -> struct_{nd_nm}:{con[0]};\n"
474+
in_conn = self.name if con[1] == INPUTS_NODE_NAME else con[1]
475+
out_conn = self.name if con[0] == OUTPUTS_NODE_NAME else con[0]
476+
dotstr += (
477+
f"struct_{in_conn}:{con[2]} -> struct_{nd_nm}:{out_conn};\n"
478+
)
470479
dotstr += "}"
471480
Path(outdir).mkdir(parents=True, exist_ok=True)
472481
dotfile = Path(outdir) / f"{name}.dot"
@@ -486,16 +495,17 @@ def create_dotfile_nested(self, outdir, name="graph"):
486495
def _create_dotfile_single_graph(self, nodes, edges):
487496
from .core import is_workflow
488497

489-
wf_asnd = []
498+
wf_asnd = {}
490499
dotstr = ""
491500
for nd in nodes:
492-
if is_workflow(nd):
493-
wf_asnd.append(nd.name)
494-
for task in nd.graph.nodes:
495-
nd.create_connections(task)
501+
if is_workflow(nd._definition):
502+
nd_graph = nd._definition.construct().graph()
503+
wf_asnd[nd.name] = nd_graph
504+
# for task in nd_graph.nodes:
505+
# nd.create_connections(task)
496506
dotstr += f"subgraph cluster_{nd.name} {{\n" f"label = {nd.name} \n"
497507
dotstr += self._create_dotfile_single_graph(
498-
nodes=nd.graph.nodes, edges=nd.graph.edges
508+
nodes=nd_graph.nodes, edges=nd_graph.edges
499509
)
500510
if nd.state:
501511
dotstr += "color=blue\n"
@@ -517,12 +527,14 @@ def _create_dotfile_single_graph(self, nodes, edges):
517527
f"lhead=cluster_{ed[1].name}]\n"
518528
)
519529
elif ed[0].name in wf_asnd:
520-
tail_nd = list(ed[0].nodes)[-1].name
530+
nd_nodes = wf_asnd[ed[0].name].nodes
531+
tail_nd = list(nd_nodes)[-1].name
521532
dotstr_edg += (
522533
f"{tail_nd} -> {ed[1].name} [ltail=cluster_{ed[0].name}]\n"
523534
)
524535
elif ed[1].name in wf_asnd:
525-
head_nd = list(ed[1].nodes)[0].name
536+
nd_nodes = wf_asnd[ed[1].name].nodes
537+
head_nd = list(nd_nodes)[0].name
526538
dotstr_edg += (
527539
f"{ed[0].name} -> {head_nd} [lhead=cluster_{ed[1].name}]\n"
528540
)

pydra/engine/helpers.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,35 +34,36 @@
3434
def plot_workflow(
3535
workflow_task: "WorkflowDef",
3636
out_dir: Path,
37-
type="simple",
38-
export=None,
39-
name=None,
40-
output_dir=None,
37+
plot_type: str = "simple",
38+
export: ty.Sequence[str] | None = None,
39+
name: str | None = None,
40+
output_dir: Path | None = None,
41+
lazy: ty.Sequence[str] | ty.Set[str] = (),
4142
):
4243
"""creating a graph - dotfile and optionally exporting to other formats"""
4344
from .core import Workflow
4445

4546
# Create output directory
4647
out_dir.mkdir(parents=True, exist_ok=True)
4748

48-
# Construct the workflow object
49-
wf = Workflow.construct(workflow_task)
49+
# Construct the workflow object with all of the fields lazy
50+
wf = Workflow.construct(workflow_task, lazy=lazy)
5051

5152
if not name:
5253
name = f"graph_{type(workflow_task).__name__}"
53-
if type == "simple":
54+
if plot_type == "simple":
5455
graph = wf.graph()
5556
dotfile = graph.create_dotfile_simple(outdir=out_dir, name=name)
56-
elif type == "nested":
57+
elif plot_type == "nested":
5758
graph = wf.graph()
5859
dotfile = graph.create_dotfile_nested(outdir=out_dir, name=name)
59-
elif type == "detailed":
60+
elif plot_type == "detailed":
6061
graph = wf.graph(detailed=True)
6162
dotfile = graph.create_dotfile_detailed(outdir=out_dir, name=name)
6263
else:
6364
raise Exception(
6465
f"type of the graph can be simple, detailed or nested, "
65-
f"but {type} provided"
66+
f"but {plot_type} provided"
6667
)
6768
if not export:
6869
return dotfile

pydra/engine/node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,11 +205,11 @@ def _set_state(self) -> None:
205205
if not_split := [
206206
c
207207
for c in combiner
208-
if not any(c in s for s in self.state.splitter_rpn)
208+
if not any(c in s for s in self.state.splitter_rpn) and "." not in c
209209
]:
210210
raise ValueError(
211211
f"Combiner fields {not_split} for Node {self.name!r} are not in the "
212-
f"splitter fields {self.state.splitter_rpn}"
212+
f"splitter {self.state.splitter_rpn}"
213213
)
214214
else:
215215
self._state = None

pydra/engine/state.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -855,13 +855,17 @@ def splitter_validation(self):
855855

856856
def combiner_validation(self):
857857
"""validating if the combiner is correct (after all states are connected)"""
858-
if self.combiner:
858+
if local_names := set(c for c in self.combiner if "." not in c):
859859
if not self.splitter:
860860
raise hlpst.PydraStateError(
861-
"splitter has to be set before setting combiner"
861+
"splitter has to be set before setting combiner with field names "
862+
f"in the current node {list(local_names)}"
863+
)
864+
if missing := local_names - set(self.splitter_rpn):
865+
raise hlpst.PydraStateError(
866+
"all field names from the current node referenced in the combiner, "
867+
f"{list(missing)} are missing in the splitter"
862868
)
863-
if set(self._combiner) - set(self.splitter_rpn):
864-
raise hlpst.PydraStateError("all combiners have to be in the splitter")
865869

866870
def prepare_states(
867871
self,

0 commit comments

Comments
 (0)