Skip to content

Commit da92150

Browse files
committed
dropped StateIndex object in favour of simple int index
1 parent f62fd0d commit da92150

File tree

4 files changed

+47
-178
lines changed

4 files changed

+47
-178
lines changed

pydra/engine/core.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class Task(ty.Generic[DefType]):
8989
definition: DefType
9090
submitter: "Submitter | None"
9191
environment: "Environment | None"
92-
state_index: state.StateIndex
92+
state_index: int
9393
bindings: dict[str, ty.Any] | None = None # Bindings for the task environment
9494

9595
_inputs: dict[str, ty.Any] | None = None
@@ -100,7 +100,7 @@ def __init__(
100100
submitter: "Submitter",
101101
name: str,
102102
environment: "Environment | None" = None,
103-
state_index: "state.StateIndex | None" = None,
103+
state_index: int | None = None,
104104
hooks: TaskHooks | None = None,
105105
):
106106
"""
@@ -120,9 +120,6 @@ def __init__(
120120
4. Two or more concurrent new processes get to start
121121
"""
122122

123-
if state_index is None:
124-
state_index = state.StateIndex()
125-
126123
if not isinstance(definition, TaskDef):
127124
raise ValueError(
128125
f"Task definition ({definition!r}) must be a TaskDef, not {type(definition)}"

pydra/engine/lazy.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from .submitter import DiGraph, NodeExecution
1010
from .core import Task, Workflow
1111
from .specs import TaskDef
12-
from .state import StateIndex
1312

1413

1514
T = ty.TypeVar("T")
@@ -49,7 +48,7 @@ def _get_value(
4948
self,
5049
workflow: "Workflow",
5150
graph: "DiGraph[NodeExecution]",
52-
state_index: "StateIndex | None" = None,
51+
state_index: int | None = None,
5352
) -> ty.Any:
5453
"""Return the value of a lazy field.
5554
@@ -59,7 +58,7 @@ def _get_value(
5958
the workflow object
6059
graph: DiGraph[NodeExecution]
6160
the graph representing the execution state of the workflow
62-
state_index : StateIndex, optional
61+
state_index : int, optional
6362
the state index of the field to access
6463
6564
Returns
@@ -95,7 +94,7 @@ def _get_value(
9594
self,
9695
workflow: "Workflow",
9796
graph: "DiGraph[NodeExecution]",
98-
state_index: "StateIndex | None" = None,
97+
state_index: int | None = None,
9998
) -> ty.Any:
10099
"""Return the value of a lazy field.
101100
@@ -105,7 +104,7 @@ def _get_value(
105104
the workflow object
106105
graph: DiGraph[NodeExecution]
107106
the graph representing the execution state of the workflow
108-
state_index : StateIndex, optional
107+
state_index : int, optional
109108
the state index of the field to access
110109
111110
Returns
@@ -134,7 +133,7 @@ def _get_value(
134133
self,
135134
workflow: "Workflow",
136135
graph: "DiGraph[NodeExecution]",
137-
state_index: "StateIndex | None" = None,
136+
state_index: int | None = None,
138137
) -> ty.Any:
139138
"""Return the value of a lazy field.
140139
@@ -144,18 +143,14 @@ def _get_value(
144143
the workflow object
145144
graph: DiGraph[NodeExecution]
146145
the graph representing the execution state of the workflow
147-
state_index : StateIndex, optional
146+
state_index : int, optional
148147
the state index of the field to access
149148
150149
Returns
151150
-------
152151
value : Any
153152
the resolved value of the lazy-field
154153
"""
155-
from pydra.engine.state import StateIndex
156-
157-
if state_index is None:
158-
state_index = StateIndex()
159154

160155
jobs = graph.node(self._node.name).matching_jobs(state_index)
161156

pydra/engine/state.py

Lines changed: 6 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from copy import deepcopy
44
import itertools
5-
from collections import OrderedDict
65
from functools import reduce
76
import typing as ty
87
from . import helpers_state as hlpst
@@ -15,117 +14,6 @@
1514
OutputsType = ty.TypeVar("OutputsType")
1615

1716

18-
class StateIndex:
19-
"""The collection of state indices that identifies a single element within the list
20-
of tasks generated from a node
21-
22-
Parameters
23-
----------
24-
indices : dict[str, int]
25-
a dictionary of indices for each input field
26-
"""
27-
28-
indices: OrderedDict[str, int]
29-
30-
def __init__(
31-
self, indices: dict[str, int] | ty.Sequence[tuple[str, int]] | None = None
32-
):
33-
# We used ordered dict here to ensure the keys are always in the same order
34-
# while OrderedDict is not strictly necessary for CPython 3.7+, we use it to
35-
# signal that the order of the keys is important
36-
if indices is None:
37-
self.indices = OrderedDict()
38-
else:
39-
if isinstance(indices, dict):
40-
indices = indices.items()
41-
self.indices = OrderedDict(sorted(indices))
42-
43-
def __len__(self) -> int:
44-
return len(self.indices)
45-
46-
def __iter__(self) -> ty.Generator[str, None, None]:
47-
return iter(self.indices)
48-
49-
def __getitem__(self, key: str) -> int:
50-
return self.indices[key]
51-
52-
def __lt__(self, other: "StateIndex") -> bool:
53-
if list(self.indices) != list(other.indices):
54-
raise ValueError(
55-
f"StateIndex {self} does not contain the same indices in the same order "
56-
f"as {other}: {list(self.indices)} != {list(other.indices)}"
57-
)
58-
return tuple(self.indices.items()) < tuple(other.indices.items())
59-
60-
def __repr__(self) -> str:
61-
return (
62-
"StateIndex(" + ", ".join(f"{n}={v}" for n, v in self.indices.items()) + ")"
63-
)
64-
65-
def __hash__(self):
66-
return hash(tuple(self.indices.items()))
67-
68-
def __eq__(self, other) -> bool:
69-
return self.indices == other.indices
70-
71-
def __str__(self) -> str:
72-
return "__".join(f"{n}-{i}" for n, i in self.indices.items())
73-
74-
def __bool__(self) -> bool:
75-
return bool(self.indices)
76-
77-
def subset(self, state_names: ty.Iterable[str]) -> ty.Self:
78-
"""Create a new StateIndex with only the specified fields
79-
80-
Parameters
81-
----------
82-
fields : list[str]
83-
the fields to keep in the new StateIndex
84-
85-
Returns
86-
-------
87-
StateIndex
88-
a new StateIndex with only the specified fields
89-
"""
90-
return type(self)({k: v for k, v in self.indices.items() if k in state_names})
91-
92-
def missing(self, state_names: ty.Iterable[str]) -> ty.List[str]:
93-
"""Return the fields that are missing from the StateIndex
94-
95-
Parameters
96-
----------
97-
fields : list[str]
98-
the fields to check for
99-
100-
Returns
101-
-------
102-
list[str]
103-
the fields that are missing from the StateIndex
104-
"""
105-
return [f for f in state_names if f not in self.indices]
106-
107-
def matches(self, other: "StateIndex") -> bool:
108-
"""Check if the indices that are present in the other StateIndex match
109-
110-
Parameters
111-
----------
112-
other : StateIndex
113-
the other StateIndex to compare against
114-
115-
Returns
116-
-------
117-
bool
118-
True if all the indices in the other StateIndex match
119-
"""
120-
if isinstance(other, dict):
121-
other = StateIndex(other)
122-
if not set(self.indices).issuperset(other.indices):
123-
raise ValueError(
124-
f"StateIndex {self} does not contain all the indices in {other}"
125-
)
126-
return all(self.indices[k] == v for k, v in other.indices.items())
127-
128-
12917
class State:
13018
"""
13119
A class that specifies a State of all tasks.
@@ -1314,7 +1202,7 @@ def _single_op_splits(self, op_single):
13141202
keys = [op_single]
13151203
return val, keys
13161204

1317-
def _get_element(self, value: ty.Any, field_name: str, ind: int):
1205+
def _get_element(self, value: ty.Any, field_name: str, ind: int) -> ty.Any:
13181206
"""
13191207
Extracting element of the inputs taking into account
13201208
container dimension of the specific element that can be set in self.state.cont_dim.
@@ -1329,6 +1217,11 @@ def _get_element(self, value: ty.Any, field_name: str, ind: int):
13291217
name of the input field
13301218
ind : int
13311219
index of the element
1220+
1221+
Returns
1222+
-------
1223+
Any
1224+
specific element of the input field
13321225
"""
13331226
if f"{self.name}.{field_name}" in self.cont_dim:
13341227
return list(

0 commit comments

Comments
 (0)