Skip to content

Commit 89a474a

Browse files
committed
renamed cont_dim to container_ndim
1 parent c538929 commit 89a474a

File tree

8 files changed

+85
-78
lines changed

8 files changed

+85
-78
lines changed

pydra/compose/base/task.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ class Task(ty.Generic[OutputsType]):
160160
# The following fields are used to store split/combine state information
161161
_splitter = attrs.field(default=None, init=False, repr=False)
162162
_combiner = attrs.field(default=None, init=False, repr=False)
163-
_cont_dim = attrs.field(default=None, init=False, repr=False)
163+
_container_ndim = attrs.field(default=None, init=False, repr=False)
164164
_hashes = attrs.field(default=None, init=False, eq=False, repr=False)
165165

166166
RESERVED_FIELD_NAMES = ("split", "combine")
@@ -265,7 +265,7 @@ def split(
265265
splitter: ty.Union[str, ty.List[str], ty.Tuple[str, ...], None] = None,
266266
/,
267267
overwrite: bool = False,
268-
cont_dim: ty.Optional[dict] = None,
268+
container_ndim: ty.Optional[dict] = None,
269269
**inputs,
270270
) -> Self:
271271
"""
@@ -279,9 +279,9 @@ def split(
279279
then the fields to split are taken from the keyword-arg names.
280280
overwrite : bool, optional
281281
whether to overwrite an existing split on the node, by default False
282-
cont_dim : dict, optional
282+
container_ndim : dict, optional
283283
Container dimensions for specific inputs, used in the splitter.
284-
If input name is not in cont_dim, it is assumed that the input values has
284+
If input name is not in container_ndim, it is assumed that the input values has
285285
a container dimension of 1, so only the most outer dim will be used for splitting.
286286
**inputs
287287
fields to split over, will be automatically wrapped in a StateArray object
@@ -321,7 +321,7 @@ def split(
321321
else:
322322
# If no splitter is provided, use the names of the inputs as combinatorial splitter
323323
split_names = splitter = list(inputs)
324-
for field_name in cont_dim or []:
324+
for field_name in container_ndim or []:
325325
if field_name not in split_names:
326326
raise ValueError(
327327
f"Container dimension for {field_name} is provided but the field "
@@ -342,7 +342,7 @@ def split(
342342
split_inputs[name] = split_val
343343
split_def = attrs.evolve(self, **split_inputs)
344344
split_def._splitter = splitter
345-
split_def._cont_dim = cont_dim
345+
split_def._container_ndim = container_ndim
346346
return split_def
347347

348348
def combine(
@@ -613,8 +613,8 @@ def bytes_repr_task(obj: Task, cache: Cache) -> ty.Iterator[bytes]:
613613
yield hash_single(obj._splitter, cache)
614614
yield b",_combiner="
615615
yield hash_single(obj._combiner, cache)
616-
yield b",_cont_dim="
617-
yield hash_single(obj._cont_dim, cache)
616+
yield b",_container_ndim="
617+
yield hash_single(obj._container_ndim, cache)
618618
yield b",_xor="
619619
yield hash_single(obj._xor, cache)
620620
yield b")"

pydra/compose/tests/test_workflow_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1524,7 +1524,7 @@ def test_wf_ndstinner_5(worker: str, tmp_path: Path):
15241524
"""workflow with 3 tasks,
15251525
the second task has two inputs and inner splitter from one of the input,
15261526
(inner input come from the first task that has its own splitter,
1527-
there is a inner_cont_dim)
1527+
there is a inner_container_ndim)
15281528
the third task has no new splitter
15291529
"""
15301530

pydra/engine/node.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -173,25 +173,25 @@ def _check_if_outputs_have_been_used(self, msg):
173173
)
174174

175175
def _set_state(self) -> None:
176-
# Add node name to state's splitter, combiner and cont_dim loaded from the def
176+
# Add node name to state's splitter, combiner and container_ndim loaded from the def
177177
splitter = deepcopy(self._task._splitter) # these can be modified in state
178178
combiner = deepcopy(self._task._combiner) # these can be modified in state
179-
cont_dim = {}
179+
container_ndim = {}
180180
if splitter:
181181
splitter = add_name_splitter(splitter, self.name)
182182
if combiner:
183183
combiner = add_name_combiner(combiner, self.name)
184-
if self._task._cont_dim:
185-
for key, val in self._task._cont_dim.items():
186-
cont_dim[f"{self.name}.{key}"] = val
184+
if self._task._container_ndim:
185+
for key, val in self._task._container_ndim.items():
186+
container_ndim[f"{self.name}.{key}"] = val
187187
other_states = self._get_upstream_states()
188188
if splitter or combiner or other_states:
189189
self._state = State(
190190
self.name,
191191
splitter=splitter,
192192
other_states=other_states,
193193
combiner=combiner,
194-
cont_dim=cont_dim,
194+
container_ndim=container_ndim,
195195
)
196196
if combiner:
197197
if not_split := [
@@ -218,7 +218,7 @@ def _get_upstream_states(self) -> dict[str, tuple["State", list[str]]]:
218218
node: Node = val._node
219219
# variables that are part of inner splitters should be treated as a containers
220220
if node.state and f"{node.name}.{val._field}" in node.state.splitter:
221-
node.state._inner_cont_dim[f"{node.name}.{val._field}"] = 1
221+
node.state._inner_container_ndim[f"{node.name}.{val._field}"] = 1
222222
# adding task_name: (task.state, [a field from the connection]
223223
if node.name not in upstream_states:
224224
upstream_states[node.name] = (node.state, [val._field])

pydra/engine/state.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
name,
8686
splitter=None,
8787
combiner=None,
88-
cont_dim=None,
88+
container_ndim=None,
8989
other_states=None,
9090
):
9191
"""
@@ -109,8 +109,8 @@ def __init__(
109109
self.splitter = splitter
110110
# temporary combiner
111111
self.combiner = combiner
112-
self.cont_dim = cont_dim or {}
113-
self._inner_cont_dim = {}
112+
self.container_ndim = container_ndim or {}
113+
self._inner_container_ndim = {}
114114
self._inputs_ind = None
115115
# if other_states, the connections have to be updated
116116
if self.other_states:
@@ -377,12 +377,12 @@ def prev_state_splitter_rpn_compact(self):
377377
return self._prev_state_splitter_rpn_compact
378378

379379
@property
380-
def cont_dim_all(self):
381-
# adding inner_cont_dim to the general container_dimension provided by the users
382-
cont_dim_all = deepcopy(self.cont_dim)
383-
for k, v in self._inner_cont_dim.items():
384-
cont_dim_all[k] = cont_dim_all.get(k, 1) + v
385-
return cont_dim_all
380+
def container_ndim_all(self):
381+
# adding inner_container_ndim to the general container_dimension provided by the users
382+
container_ndim_all = deepcopy(self.container_ndim)
383+
for k, v in self._inner_container_ndim.items():
384+
container_ndim_all[k] = container_ndim_all.get(k, 1) + v
385+
return container_ndim_all
386386

387387
@property
388388
def combiner(self):
@@ -869,7 +869,7 @@ def combiner_validation(self):
869869
def prepare_states(
870870
self,
871871
inputs: dict[str, ty.Any],
872-
cont_dim: dict[str, int] | None = None,
872+
container_ndim: dict[str, int] | None = None,
873873
):
874874
"""
875875
Prepare a full list of state indices and state values.
@@ -885,13 +885,13 @@ def prepare_states(
885885
self.combiner_validation()
886886
self.set_input_groups()
887887
self.inputs = inputs
888-
if cont_dim is not None:
889-
self.cont_dim = cont_dim
888+
if container_ndim is not None:
889+
self.container_ndim = container_ndim
890890
if self.other_states:
891891
st: State
892892
for nm, (st, _) in self.other_states.items():
893893
self.inputs.update(st.inputs)
894-
self.cont_dim.update(st.cont_dim_all)
894+
self.container_ndim.update(st.container_ndim_all)
895895

896896
self.prepare_states_ind()
897897
self.prepare_states_val()
@@ -995,7 +995,9 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
995995
def prepare_states_val(self):
996996
"""Evaluate states values having states indices."""
997997
self.states_val = list(
998-
map_splits(self.states_ind, self.inputs, cont_dim=self.cont_dim_all)
998+
map_splits(
999+
self.states_ind, self.inputs, container_ndim=self.container_ndim_all
1000+
)
9991001
)
10001002
return self.states_val
10011003

@@ -1165,8 +1167,8 @@ def _processing_terms(self, term, previous_states_ind):
11651167
var_ind, new_keys = previous_states_ind[term]
11661168
shape = (len(var_ind),)
11671169
else:
1168-
cont_dim = self.cont_dim_all.get(term, 1)
1169-
shape = input_shape(self.inputs[term], cont_dim=cont_dim)
1170+
container_ndim = self.container_ndim_all.get(term, 1)
1171+
shape = input_shape(self.inputs[term], container_ndim=container_ndim)
11701172
var_ind = range(reduce(lambda x, y: x * y, shape))
11711173
new_keys = [term]
11721174
# checking if the term is in inner_inputs
@@ -1186,7 +1188,8 @@ def _processing_terms(self, term, previous_states_ind):
11861188
def _single_op_splits(self, op_single):
11871189
"""splits function if splitter is a singleton"""
11881190
shape = input_shape(
1189-
self.inputs[op_single], cont_dim=self.cont_dim_all.get(op_single, 1)
1191+
self.inputs[op_single],
1192+
container_ndim=self.container_ndim_all.get(op_single, 1),
11901193
)
11911194
val_ind = range(reduce(lambda x, y: x * y, shape))
11921195
if op_single in self.inner_inputs:
@@ -1211,8 +1214,8 @@ def _single_op_splits(self, op_single):
12111214
def _get_element(self, value: ty.Any, field_name: str, ind: int) -> ty.Any:
12121215
"""
12131216
Extracting element of the inputs taking into account
1214-
container dimension of the specific element that can be set in self.state.cont_dim.
1215-
If input name is not in cont_dim, it is assumed that the input values has
1217+
container dimension of the specific element that can be set in self.state.container_ndim.
1218+
If input name is not in container_ndim, it is assumed that the input values has
12161219
a container dimension of 1, so only the most outer dim will be used for splitting.
12171220
12181221
Parameters
@@ -1229,11 +1232,11 @@ def _get_element(self, value: ty.Any, field_name: str, ind: int) -> ty.Any:
12291232
Any
12301233
specific element of the input field
12311234
"""
1232-
if f"{self.name}.{field_name}" in self.cont_dim_all:
1235+
if f"{self.name}.{field_name}" in self.container_ndim_all:
12331236
return list(
12341237
flatten(
12351238
ensure_list(value),
1236-
max_depth=self.cont_dim_all[f"{self.name}.{field_name}"],
1239+
max_depth=self.container_ndim_all[f"{self.name}.{field_name}"],
12371240
)
12381241
)[ind]
12391242
else:
@@ -1600,15 +1603,15 @@ def iter_splits(iterable, keys):
16001603
yield dict(zip(keys, list(flatten(iter, max_depth=1000))))
16011604

16021605

1603-
def input_shape(inp, cont_dim=1):
1606+
def input_shape(inp, container_ndim=1):
16041607
"""Get input shape, depends on the container dimension, if not specify it is assumed to be 1"""
16051608
# TODO: have to be changed for inner splitter (sometimes different length)
1606-
cont_dim -= 1
1609+
container_ndim -= 1
16071610
shape = [len(inp)]
16081611
last_shape = None
16091612
for value in inp:
1610-
if isinstance(value, list) and cont_dim > 0:
1611-
cur_shape = input_shape(value, cont_dim)
1613+
if isinstance(value, list) and container_ndim > 0:
1614+
cur_shape = input_shape(value, container_ndim)
16121615
if last_shape is None:
16131616
last_shape = cur_shape
16141617
elif last_shape != cur_shape:
@@ -1828,13 +1831,15 @@ def combine_final_groups(combiner, groups, groups_stack, keys):
18281831
return keys_final, groups_final_map, groups_stack_final, combiner_all
18291832

18301833

1831-
def map_splits(split_iter, inputs, cont_dim=None):
1834+
def map_splits(split_iter, inputs, container_ndim=None):
18321835
"""generate a dictionary of inputs prescribed by the splitter."""
1833-
if cont_dim is None:
1834-
cont_dim = {}
1836+
if container_ndim is None:
1837+
container_ndim = {}
18351838
for split in split_iter:
18361839
yield {
1837-
k: list(flatten(ensure_list(inputs[k]), max_depth=cont_dim.get(k, None)))[v]
1840+
k: list(
1841+
flatten(ensure_list(inputs[k]), max_depth=container_ndim.get(k, None))
1842+
)[v]
18381843
for k, v in split.items()
18391844
}
18401845

pydra/engine/submitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def __call__(
215215
name="outer_split",
216216
splitter=deepcopy(task._splitter),
217217
combiner=deepcopy(task._combiner),
218-
cont_dim=deepcopy(task._cont_dim),
218+
container_ndim=deepcopy(task._container_ndim),
219219
)
220220

221221
def wrap_type(tp):

0 commit comments

Comments
 (0)