Skip to content

Commit 32931d3

Browse files
committed
debugging state preparation and lazy value resolution
1 parent d8321c0 commit 32931d3

File tree

3 files changed

+299
-266
lines changed

3 files changed

+299
-266
lines changed

pydra/engine/lazy.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,6 @@ def _get_value(
176176
value : Any
177177
the resolved value of the lazy-field
178178
"""
179-
state = self._node.state
180-
jobs = graph.node(self._node.name).get_jobs(state_index)
181179

182180
def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
183181
if job.errored:
@@ -209,12 +207,42 @@ def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
209207
val = self._apply_cast(val)
210208
return val
211209

212-
if not isinstance(jobs, StateArray): # single job
213-
return retrieve_from_job(jobs)
214-
elif not state or not state.depth(before_combine=True):
215-
assert len(jobs) == 1
216-
return retrieve_from_job(jobs[0])
217-
return [retrieve_from_job(j) for j in jobs]
210+
# Get the execution node that the value is coming from
211+
upstream_node = graph.node(self._node.name)
212+
213+
if not upstream_node._tasks: # No jobs, return empty state array
214+
return StateArray()
215+
if not upstream_node.state: # Return the singular job
216+
value = retrieve_from_job(upstream_node._tasks[None])
217+
if state_index is not None:
218+
return value[state_index]
219+
return value
220+
if upstream_node.state.combiner:
221+
222+
# No state remains after the combination, return all values in a list
223+
if not upstream_node.state.ind_l_final:
224+
return [retrieve_from_job(j) for j in upstream_node.tasks]
225+
226+
# Group the values of the tasks into list before returning
227+
def group_values(index: int) -> list:
228+
# Get a slice of the tasks that match the given index of the state array of the
229+
# combined values
230+
final_index = set(upstream_node.state.states_ind_final[index].items())
231+
return [
232+
retrieve_from_job(upstream_node._tasks[i])
233+
for i, ind in enumerate(upstream_node.state.states_ind)
234+
if set(ind.items()).issuperset(final_index)
235+
]
236+
237+
if state_index is None: # return all groups if no index is given
238+
return StateArray(
239+
group_values(i) for i in range(len(upstream_node.state.ind_l_final))
240+
)
241+
return group_values(state_index) # select the group that matches the index
242+
if state_index is None: # return all jobs in a state array
243+
return StateArray(retrieve_from_job(j) for j in upstream_node.tasks)
244+
# Select the job that matches the index
245+
return retrieve_from_job(upstream_node._tasks[state_index])
218246

219247
@property
220248
def _source(self):

pydra/engine/submitter.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
attrs_values,
1919
)
2020
from pydra.utils.hash import PersistentCache
21-
from pydra.utils.typing import StateArray
2221
from pydra.engine.lazy import LazyField
2322
from .audit import Audit
2423
from .core import Task
@@ -536,7 +535,6 @@ def __init__(
536535
self.queued = {}
537536
self.running = {} # Not used in logic, but may be useful for progress tracking
538537
self.unrunnable = defaultdict(list)
539-
self.state_names = self.node.state.names if self.node.state else []
540538
self.workflow = workflow
541539
self.graph = None
542540

@@ -566,47 +564,19 @@ def tasks(self) -> ty.Generator["Task[DefType]", None, None]:
566564
raise RuntimeError("Tasks have not been generated")
567565
return self._tasks.values()
568566

569-
def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]":
570-
"""Get the jobs that match a given state index.
571-
572-
Parameters
573-
----------
574-
final_index : int, optional
575-
The index of the output state array (i.e. after any combinations) of the
576-
job to get, by default None
577-
578-
Returns
579-
-------
580-
matching : Task | StateArray[Task]
581-
The task or tasks that match the given index
582-
"""
583-
if not self.tasks: # No jobs, return empty state array
584-
return StateArray()
585-
if not self.node.state: # Return the singular job
586-
return self._tasks[None]
587-
if final_index is None: # return all jobs in a state array
588-
return StateArray(self._tasks.values())
589-
if not self.node.state.combiner: # Select the job that matches the index
590-
return self._tasks[final_index]
591-
# Get a slice of the tasks that match the given index of the state array of the
592-
# combined values
593-
final_index = set(self.node.state.states_ind_final[final_index].items())
594-
return StateArray(
595-
self._tasks[i]
596-
for i, ind in enumerate(self.node.state.states_ind)
597-
if set(ind.items()).issuperset(final_index)
598-
)
599-
600567
def start(self) -> None:
601568
"""Prepare the execution node so that it can be processed"""
602569
self._tasks = {}
603570
if self.state:
604571
values = {}
605572
for name, value in self.node.state_values.items():
606-
if name in self.node.state.names and isinstance(value, LazyField):
607-
values[name] = value._get_value(
608-
workflow=self.workflow, graph=self.graph
609-
)
573+
if name in self.node.state.names:
574+
if isinstance(value, LazyField):
575+
values[name] = value._get_value(
576+
workflow=self.workflow, graph=self.graph
577+
)
578+
else:
579+
values[name] = value
610580
self.state.prepare_states(values)
611581
self.state.prepare_inputs()
612582
# Generate the tasks

0 commit comments

Comments
 (0)