Skip to content

Commit f51063e

Browse files
committed
moved Node.cont_dim into State
1 parent 32931d3 commit f51063e

File tree

2 files changed

+11
-21
lines changed

2 files changed

+11
-21
lines changed

pydra/engine/node.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -150,21 +150,6 @@ def lzout(self) -> OutputType:
150150
self._lzout = outputs
151151
return outputs
152152

153-
@property
154-
def cont_dim(self):
155-
# adding inner_cont_dim to the general container_dimension provided by the users
156-
cont_dim_all = deepcopy(self._cont_dim)
157-
for k, v in self.state._inner_cont_dim.items():
158-
cont_dim_all[k] = cont_dim_all.get(k, 1) + v
159-
return cont_dim_all
160-
161-
@cont_dim.setter
162-
def cont_dim(self, cont_dim):
163-
if cont_dim is None:
164-
self._cont_dim = {}
165-
else:
166-
self._cont_dim = cont_dim
167-
168153
@property
169154
def splitter(self):
170155
if not self._state:

pydra/engine/state.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,14 @@ def prev_state_splitter_rpn_compact(self):
377377
"""
378378
return self._prev_state_splitter_rpn_compact
379379

380+
@property
381+
def cont_dim_all(self):
382+
# adding inner_cont_dim to the general container_dimension provided by the users
383+
cont_dim_all = deepcopy(self.cont_dim)
384+
for k, v in self._inner_cont_dim.items():
385+
cont_dim_all[k] = cont_dim_all.get(k, 1) + v
386+
return cont_dim_all
387+
380388
@property
381389
def combiner(self):
382390
"""the combiner associated to the state."""
@@ -858,7 +866,6 @@ def combiner_validation(self):
858866
def prepare_states(
859867
self,
860868
inputs: dict[str, ty.Any],
861-
cont_dim: dict[str, int] | None = None,
862869
):
863870
"""
864871
Prepare a full list of state indices and state values.
@@ -874,8 +881,6 @@ def prepare_states(
874881
self.combiner_validation()
875882
self.set_input_groups()
876883
self.inputs = inputs
877-
if not self.cont_dim:
878-
self.cont_dim = cont_dim or {}
879884
if self.other_states:
880885
st: State
881886
for nm, (st, _) in self.other_states.items():
@@ -986,7 +991,7 @@ def prepare_states_combined_ind(self, elements_to_remove_comb):
986991
def prepare_states_val(self):
987992
"""Evaluate states values having states indices."""
988993
self.states_val = list(
989-
hlpst.map_splits(self.states_ind, self.inputs, cont_dim=self.cont_dim)
994+
hlpst.map_splits(self.states_ind, self.inputs, cont_dim=self.cont_dim_all)
990995
)
991996
return self.states_val
992997

@@ -1156,7 +1161,7 @@ def _processing_terms(self, term, previous_states_ind):
11561161
var_ind, new_keys = previous_states_ind[term]
11571162
shape = (len(var_ind),)
11581163
else:
1159-
cont_dim = self.cont_dim.get(term, 1)
1164+
cont_dim = self.cont_dim_all.get(term, 1)
11601165
shape = hlpst.input_shape(self.inputs[term], cont_dim=cont_dim)
11611166
var_ind = range(reduce(lambda x, y: x * y, shape))
11621167
new_keys = [term]
@@ -1177,7 +1182,7 @@ def _processing_terms(self, term, previous_states_ind):
11771182
def _single_op_splits(self, op_single):
11781183
"""splits function if splitter is a singleton"""
11791184
shape = hlpst.input_shape(
1180-
self.inputs[op_single], cont_dim=self.cont_dim.get(op_single, 1)
1185+
self.inputs[op_single], cont_dim=self.cont_dim_all.get(op_single, 1)
11811186
)
11821187
val_ind = range(reduce(lambda x, y: x * y, shape))
11831188
if op_single in self.inner_inputs:

0 commit comments

Comments
 (0)