Skip to content

Commit d8321c0

Browse files
committed
attempting to debug state splitting/combining logic
1 parent ce860d2 commit d8321c0

File tree

3 files changed

+62
-101
lines changed

3 files changed

+62
-101
lines changed

pydra/engine/lazy.py

Lines changed: 21 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import typing as ty
22
import abc
33
import attrs
4+
from typing import Self
45
from pydra.utils.typing import StateArray
56
from pydra.utils.hash import hash_single
67
from . import node
@@ -68,11 +69,29 @@ def _get_value(
6869
"""
6970
raise NotImplementedError("LazyField is an abstract class")
7071

71-
def split(self) -> "LazyField":
72+
def split(self) -> Self:
7273
""" "Splits" the lazy field over an array of nodes by replacing the sequence type
7374
of the lazy field with StateArray to signify that it will be "split" across
7475
"""
75-
raise NotImplementedError("LazyField is an abstract class")
76+
from ..utils.typing import TypeParser # pylint: disable=import-outside-toplevel
77+
78+
# Modify the type of the lazy field to include the split across a state-array
79+
inner_type, prev_split_depth = TypeParser.strip_splits(self._type)
80+
assert prev_split_depth <= 1
81+
if inner_type is ty.Any:
82+
type_ = StateArray[ty.Any]
83+
elif TypeParser.matches_type(inner_type, list):
84+
item_type = TypeParser.get_item_type(inner_type)
85+
type_ = StateArray[item_type]
86+
else:
87+
raise TypeError(
88+
f"Cannot split non-sequence field {self} of type {inner_type}"
89+
)
90+
if prev_split_depth:
91+
type_ = StateArray[
92+
type_
93+
] # FIXME: This nesting of StateArray is probably unnecessary
94+
return attrs.evolve(self, type=type_)
7695

7796

7897
@attrs.define(kw_only=True)
@@ -122,25 +141,6 @@ def _get_value(
122141
value = self._apply_cast(value)
123142
return value
124143

125-
def split(self) -> "LazyField":
126-
""" "Splits" the lazy field over an array of nodes by replacing the sequence type
127-
of the lazy field with StateArray to signify that it will be "split" across
128-
"""
129-
from ..utils.typing import TypeParser # pylint: disable=import-outside-toplevel
130-
131-
assert not isinstance(self, LazyInField)
132-
133-
if not TypeParser.matches_type(self.type, list):
134-
raise TypeError(
135-
f"Cannot split non-sequence field {self} of type {self.type}"
136-
)
137-
138-
return type(self)(
139-
name=self.name,
140-
field=self.field,
141-
type=StateArray[TypeParser.get_item_type(self.type)],
142-
)
143-
144144

145145
@attrs.define(kw_only=True)
146146
class LazyOutField(LazyField[T]):
@@ -214,51 +214,8 @@ def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
214214
elif not state or not state.depth(before_combine=True):
215215
assert len(jobs) == 1
216216
return retrieve_from_job(jobs[0])
217-
# elif state.combiner and state.keys_final:
218-
# # We initialise it here rather than using a defaultdict to ensure the order
219-
# # of the keys matches how it is defined in the state so we can return the
220-
# # values in the correct order
221-
# sorted_values = {frozenset(i.items()): [] for i in state.states_ind_final}
222-
# # Iterate through the jobs and append the values to the correct final state
223-
# # key
224-
# for job in jobs:
225-
# state_key = frozenset(
226-
# (key, state.states_ind[job.state_index][key])
227-
# for key in state.keys_final
228-
# )
229-
# sorted_values[state_key].append(retrieve_from_job(job))
230-
# return StateArray(sorted_values.values())
231-
# else:
232217
return [retrieve_from_job(j) for j in jobs]
233218

234-
def split(self) -> "LazyField":
235-
""" "Splits" the lazy field over an array of nodes by replacing the sequence type
236-
of the lazy field with StateArray to signify that it will be "split" across
237-
"""
238-
from ..utils.typing import TypeParser # pylint: disable=import-outside-toplevel
239-
240-
# Modify the type of the lazy field to include the split across a state-array
241-
inner_type, prev_split_depth = TypeParser.strip_splits(self.type)
242-
assert prev_split_depth <= 1
243-
if inner_type is ty.Any:
244-
type_ = StateArray[ty.Any]
245-
elif TypeParser.matches_type(inner_type, list):
246-
item_type = TypeParser.get_item_type(inner_type)
247-
type_ = StateArray[item_type]
248-
else:
249-
raise TypeError(
250-
f"Cannot split non-sequence field {self} of type {inner_type}"
251-
)
252-
if prev_split_depth:
253-
type_ = StateArray[
254-
type_
255-
] # FIXME: This nesting of StateArray is probably unnecessary
256-
return type(self)[type_](
257-
name=self.name,
258-
field=self.field,
259-
type=type_,
260-
)
261-
262219
@property
263220
def _source(self):
264221
return self._node

pydra/engine/specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def split(
345345
split_inputs = {}
346346
for name, value in inputs.items():
347347
if isinstance(value, lazy.LazyField):
348-
split_val = value.split(splitter)
348+
split_val = value.split()
349349
elif isinstance(value, ty.Iterable) and not isinstance(
350350
value, (ty.Mapping, str)
351351
):

pydra/engine/submitter.py

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -536,17 +536,14 @@ def __init__(
536536
self.queued = {}
537537
self.running = {} # Not used in logic, but may be useful for progress tracking
538538
self.unrunnable = defaultdict(list)
539-
# Prepare the state to be run
540-
if node.state:
541-
self.state = node.state
542-
self.state.prepare_states(self.node.state_values)
543-
self.state.prepare_inputs()
544-
else:
545-
self.state = None
546539
self.state_names = self.node.state.names if self.node.state else []
547540
self.workflow = workflow
548541
self.graph = None
549542

543+
@property
544+
def state(self):
545+
return self.node.state
546+
550547
def __repr__(self):
551548
return (
552549
f"NodeExecution(name={self.name!r}, blocked={list(self.blocked)}, "
@@ -566,7 +563,7 @@ def _definition(self) -> "Node":
566563
@property
567564
def tasks(self) -> ty.Generator["Task[DefType]", None, None]:
568565
if self._tasks is None:
569-
self._tasks = {t.state_index: t for t in self._generate_tasks()}
566+
raise RuntimeError("Tasks have not been generated")
570567
return self._tasks.values()
571568

572569
def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]":
@@ -586,14 +583,11 @@ def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]":
586583
if not self.tasks: # No jobs, return empty state array
587584
return StateArray()
588585
if not self.node.state: # Return the singular job
589-
assert final_index is None
590-
task = self._tasks[None]
591-
return task
586+
return self._tasks[None]
592587
if final_index is None: # return all jobs in a state array
593588
return StateArray(self._tasks.values())
594589
if not self.node.state.combiner: # Select the job that matches the index
595-
task = self._tasks[final_index]
596-
return task
590+
return self._tasks[final_index]
597591
# Get a slice of the tasks that match the given index of the state array of the
598592
# combined values
599593
final_index = set(self.node.state.states_ind_final[final_index].items())
@@ -603,6 +597,38 @@ def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]":
603597
if set(ind.items()).issuperset(final_index)
604598
)
605599

600+
def start(self) -> None:
601+
"""Prepare the execution node so that it can be processed"""
602+
self._tasks = {}
603+
if self.state:
604+
values = {}
605+
for name, value in self.node.state_values.items():
606+
if name in self.node.state.names and isinstance(value, LazyField):
607+
values[name] = value._get_value(
608+
workflow=self.workflow, graph=self.graph
609+
)
610+
self.state.prepare_states(values)
611+
self.state.prepare_inputs()
612+
# Generate the tasks
613+
for index, split_defn in enumerate(self._split_definition()):
614+
self._tasks[index] = Task(
615+
definition=split_defn,
616+
submitter=self.submitter,
617+
environment=self.node._environment,
618+
name=self.node.name,
619+
hooks=self.node._hooks,
620+
state_index=index,
621+
)
622+
else:
623+
self._tasks[None] = Task(
624+
definition=self._resolve_lazy_inputs(task_def=self.node._definition),
625+
submitter=self.submitter,
626+
environment=self.node._environment,
627+
hooks=self.node._hooks,
628+
name=self.node.name,
629+
)
630+
self.blocked = copy(self._tasks)
631+
606632
@property
607633
def started(self) -> bool:
608634
return (
@@ -656,26 +682,6 @@ def all_failed(self) -> bool:
656682
self.successful or self.blocked or self.queued
657683
)
658684

659-
def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
660-
if not self.node.state:
661-
yield Task(
662-
definition=self._resolve_lazy_inputs(task_def=self.node._definition),
663-
submitter=self.submitter,
664-
environment=self.node._environment,
665-
hooks=self.node._hooks,
666-
name=self.node.name,
667-
)
668-
else:
669-
for index, split_defn in enumerate(self._split_definition()):
670-
yield Task(
671-
definition=split_defn,
672-
submitter=self.submitter,
673-
environment=self.node._environment,
674-
name=self.node.name,
675-
hooks=self.node._hooks,
676-
state_index=index,
677-
)
678-
679685
def _resolve_lazy_inputs(
680686
self,
681687
task_def: "TaskDef",
@@ -754,10 +760,8 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
754760
List of tasks that are ready to run
755761
"""
756762
runnable: list["Task[DefType]"] = []
757-
self.tasks # Ensure tasks are loaded
758763
if not self.started:
759-
assert self._tasks is not None
760-
self.blocked = copy(self._tasks)
764+
self.start()
761765
# Check to see if any blocked tasks are now runnable/unrunnable
762766
for index, task in list(self.blocked.items()):
763767
pred: NodeExecution

0 commit comments

Comments
 (0)