Skip to content

Commit b2034d5

Browse files
committed
moved resolved lazy inputs into NodeExecution class from TaskDef
1 parent f7021e5 commit b2034d5

File tree

7 files changed

+148
-170
lines changed

7 files changed

+148
-170
lines changed

pydra/engine/lazy.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
from . import node
77

88
if ty.TYPE_CHECKING:
9-
from .graph import DiGraph
109
from .submitter import NodeExecution
1110
from .core import Task, Workflow
12-
from .specs import TaskDef, WorkflowDef
11+
from .specs import TaskDef
1312
from .state import StateIndex
1413

1514

@@ -46,6 +45,27 @@ def _apply_cast(self, value):
4645
value = self._type(value)
4746
return value
4847

48+
def _get_value(
49+
self,
50+
node_exec: "NodeExecution",
51+
state_index: "StateIndex | None" = None,
52+
) -> ty.Any:
53+
"""Return the value of a lazy field.
54+
55+
Parameters
56+
----------
57+
node_exec: NodeExecution
58+
the object representing the execution state of the current node
59+
state_index : StateIndex, optional
60+
the state index of the field to access
61+
62+
Returns
63+
-------
64+
value : Any
65+
the resolved value of the lazy-field
66+
"""
67+
raise NotImplementedError("LazyField is an abstract class")
68+
4969

5070
@attrs.define(kw_only=True)
5171
class LazyInField(LazyField[T]):
@@ -70,23 +90,25 @@ def _source(self):
7090

7191
def _get_value(
7292
self,
73-
workflow_def: "WorkflowDef",
93+
node_exec: "NodeExecution",
94+
state_index: "StateIndex | None" = None,
7495
) -> ty.Any:
7596
"""Return the value of a lazy field.
7697
7798
Parameters
7899
----------
79-
wf : Workflow
80-
the workflow the lazy field references
81-
state_index : int, optional
82-
the state index of the field to access
100+
node_exec: NodeExecution
101+
the object representing the execution state of the current node
102+
state_index : StateIndex, optional
103+
the state index of the field to access (ignored, used for duck-typing with
104+
LazyOutField)
83105
84106
Returns
85107
-------
86108
value : Any
87109
the resolved value of the lazy-field
88110
"""
89-
value = workflow_def[self._field]
111+
value = node_exec.workflow_inputs[self._field]
90112
value = self._apply_cast(value)
91113
return value
92114

@@ -105,16 +127,16 @@ def __repr__(self):
105127

106128
def _get_value(
107129
self,
108-
graph: "DiGraph[NodeExecution]",
130+
node_exec: "NodeExecution",
109131
state_index: "StateIndex | None" = None,
110132
) -> ty.Any:
111133
"""Return the value of a lazy field.
112134
113135
Parameters
114136
----------
115-
wf : Workflow
116-
the workflow the lazy field references
117-
state_index : int, optional
137+
node_exec: NodeExecution
138+
the object representing the execution state of the current node
139+
state_index : StateIndex, optional
118140
the state index of the field to access
119141
120142
Returns
@@ -130,7 +152,7 @@ def _get_value(
130152
if state_index is None:
131153
state_index = StateIndex()
132154

133-
task = graph.node(self._node.name).task(state_index)
155+
task = node_exec.graph.node(self._node.name).task(state_index)
134156
_, split_depth = TypeParser.strip_splits(self._type)
135157

136158
def get_nested(task: "Task[DefType]", depth: int):

pydra/engine/node.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,28 @@ def lzout(self) -> OutputType:
121121
type=field.type,
122122
)
123123
outputs = self.inputs.Outputs(**lazy_fields)
124-
# Flag the output lazy fields as being not typed checked (i.e. assigned to another
125-
# node's inputs) yet
124+
126125
outpt: lazy.LazyOutField
127126
for outpt in attrs_values(outputs).values():
128-
outpt._type_checked = False
127+
# Assign the current node to the lazy fields so they can access the state
129128
outpt._node = self
129+
# If the node has a non-empty state, wrap the type of the lazy field in
130+
# a combination of an optional list and a number of nested StateArrays
131+
# types based on the number of states the node is split over and whether
132+
# it has a combiner
133+
if self._state:
134+
type_, _ = TypeParser.strip_splits(outpt._type)
135+
if self._state.combiner:
136+
type_ = list[type_]
137+
for _ in range(self._state.depth - int(bool(self._state.combiner))):
138+
type_ = StateArray[type_]
139+
outpt._type = type_
140+
# Flag the output lazy fields as being not typed checked (i.e. assigned to
141+
# another node's inputs) yet. This is used to prevent the user from changing
142+
# the type of the output after it has been accessed by connecting it to an
143+
# output of an upstream node with additional state variables.
144+
outpt._type_checked = False
130145
self._lzout = outputs
131-
self._wrap_lzout_types_in_state_arrays()
132146
return outputs
133147

134148
@property
@@ -217,20 +231,6 @@ def _check_if_outputs_have_been_used(self, msg):
217231
+ msg
218232
)
219233

220-
def _wrap_lzout_types_in_state_arrays(self) -> None:
221-
"""Wraps a types of the lazy out fields in a number of nested StateArray types
222-
based on the number of states the node is split over"""
223-
# Unwrap StateArray types from the output types
224-
if not self.state:
225-
return
226-
outpt_lf: lazy.LazyOutField
227-
for outpt_lf in attrs_values(self.lzout).values():
228-
assert not outpt_lf._type_checked
229-
type_, _ = TypeParser.strip_splits(outpt_lf._type)
230-
for _ in range(self._state.depth):
231-
type_ = StateArray[type_]
232-
outpt_lf._type = type_
233-
234234
def _set_state(self) -> None:
235235
# Add node name to state's splitter, combiner and cont_dim loaded from the def
236236
splitter = self._definition._splitter

pydra/engine/specs.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,12 @@
3535
from pydra.utils.typing import StateArray, MultiInputObj
3636
from pydra.design.base import Field, Arg, Out, RequirementSet, NO_DEFAULT
3737
from pydra.design import shell
38-
from pydra.engine.lazy import LazyInField, LazyOutField
3938

4039
if ty.TYPE_CHECKING:
4140
from pydra.engine.core import Task
4241
from pydra.engine.graph import DiGraph
4342
from pydra.engine.submitter import NodeExecution
4443
from pydra.engine.core import Workflow
45-
from pydra.engine.state import StateIndex
4644
from pydra.engine.environments import Environment
4745
from pydra.engine.workers import Worker
4846

@@ -476,41 +474,6 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
476474
}
477475
return hash_function(sorted(field_hashes.items())), field_hashes
478476

479-
def _resolve_lazy_inputs(
480-
self,
481-
workflow_inputs: "WorkflowDef",
482-
graph: "DiGraph[NodeExecution]",
483-
state_index: "StateIndex | None" = None,
484-
) -> Self:
485-
"""Resolves lazy fields in the task definition by replacing them with their
486-
actual values.
487-
488-
Parameters
489-
----------
490-
workflow : Workflow
491-
The workflow the task is part of
492-
graph : DiGraph[NodeExecution]
493-
The execution graph of the workflow
494-
state_index : StateIndex, optional
495-
The state index for the workflow, by default None
496-
497-
Returns
498-
-------
499-
Self
500-
The task definition with all lazy fields resolved
501-
"""
502-
from pydra.engine.state import StateIndex
503-
504-
if state_index is None:
505-
state_index = StateIndex()
506-
resolved = {}
507-
for name, value in attrs_values(self).items():
508-
if isinstance(value, LazyInField):
509-
resolved[name] = value._get_value(workflow_inputs)
510-
elif isinstance(value, LazyOutField):
511-
resolved[name] = value._get_value(graph, state_index)
512-
return attrs.evolve(self, **resolved)
513-
514477
def _check_rules(self):
515478
"""Check if all rules are satisfied."""
516479

pydra/engine/submitter.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,18 @@
99
from copy import copy
1010
from datetime import datetime
1111
from collections import defaultdict
12+
import attrs
1213
from .workers import Worker, WORKERS
1314
from .graph import DiGraph
1415
from .helpers import (
1516
get_open_loop,
1617
list_fields,
18+
attrs_values,
1719
)
1820
from pydra.utils.hash import PersistentCache
1921
from .state import StateIndex
2022
from pydra.utils.typing import StateArray
23+
from pydra.engine.lazy import LazyField
2124
from .audit import Audit
2225
from .core import Task
2326
from pydra.utils.messenger import AuditFlag, Messenger
@@ -607,10 +610,7 @@ def all_failed(self) -> bool:
607610
def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
608611
if not self.node.state:
609612
yield Task(
610-
definition=self.node._definition._resolve_lazy_inputs(
611-
workflow_inputs=self.workflow_inputs,
612-
graph=self.graph,
613-
),
613+
definition=self._resolve_lazy_inputs(task_def=self.node._definition),
614614
submitter=self.submitter,
615615
environment=self.node._environment,
616616
hooks=self.node._hooks,
@@ -619,9 +619,8 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
619619
else:
620620
for index, split_defn in self.node._split_definition().items():
621621
yield Task(
622-
definition=split_defn._resolve_lazy_inputs(
623-
workflow_inputs=self.workflow_inputs,
624-
graph=self.graph,
622+
definition=self._resolve_lazy_inputs(
623+
task_def=split_defn,
625624
state_index=index,
626625
),
627626
submitter=self.submitter,
@@ -631,6 +630,32 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
631630
state_index=index,
632631
)
633632

633+
def _resolve_lazy_inputs(
634+
self,
635+
task_def: "TaskDef",
636+
state_index: "StateIndex | None" = None,
637+
) -> "TaskDef":
638+
"""Resolves lazy fields in the task definition by replacing them with their
639+
actual values calculated by upstream jobs.
640+
641+
Parameters
642+
----------
643+
task_def : TaskDef
644+
The definition to resolve the lazy fields of
645+
state_index : StateIndex, optional
646+
The state index for the workflow, by default None
647+
648+
Returns
649+
-------
650+
TaskDef
651+
The task definition with all lazy fields resolved
652+
"""
653+
resolved = {}
654+
for name, value in attrs_values(self).items():
655+
if isinstance(value, LazyField):
656+
resolved[name] = value._get_value(self, state_index)
657+
return attrs.evolve(self, **resolved)
658+
634659
def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
635660
"""For a given node, check to see which tasks have been successfully run, are ready
636661
to run, can't be run due to upstream errors, or are blocked on other tasks to complete.

0 commit comments

Comments
 (0)