Skip to content

Commit b0db082

Browse files
committed
debugging shelltask
1 parent da92150 commit b0db082

File tree

8 files changed

+103
-120
lines changed

8 files changed

+103
-120
lines changed

pydra/design/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
350350
if not atr.help:
351351
atr.help = helps.get(atr_name, "")
352352
elif atr_name in type_hints:
353+
if atr_name.startswith("_"):
354+
continue
353355
if atr_name in fields_dict:
354356
fields_dict[atr_name].type = type_hints[atr_name]
355357
elif auto_attribs:
@@ -361,6 +363,8 @@ def get_fields(klass, field_type, auto_attribs, helps) -> dict[str, Field]:
361363
)
362364
if auto_attribs:
363365
for atr_name, type_ in type_hints.items():
366+
if atr_name.startswith("_"):
367+
continue
364368
if atr_name not in list(fields_dict) + ["Task", "Outputs"]:
365369
fields_dict[atr_name] = field_type(
366370
name=atr_name, type=type_, help=helps.get(atr_name, "")

pydra/design/shell.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
MultiInputObj,
2929
TypeParser,
3030
is_optional,
31-
optional_type,
3231
)
3332

3433
if ty.TYPE_CHECKING:
@@ -85,7 +84,7 @@ class arg(Arg):
8584
If nothing is provided the field will be inserted between all fields with
8685
nonnegative positions and fields with negative positions.
8786
sep: str, optional
88-
A separator if a list is provided as a value.
87+
A separator if a sequence type is provided as a value, by default " ".
8988
container_path: bool, optional
9089
If True a path will be consider as a path inside the container (and not as a
9190
local path, by default it is False
@@ -99,45 +98,11 @@ class arg(Arg):
9998

10099
argstr: str | None = ""
101100
position: int | None = None
102-
sep: str | None = attrs.field()
101+
sep: str = " "
103102
allowed_values: list | None = None
104103
container_path: bool = False # IS THIS STILL USED??
105104
formatter: ty.Callable | None = None
106105

107-
@sep.default
108-
def _sep_default(self):
109-
return " " if self.type is tuple or ty.get_origin(self.type) is tuple else None
110-
111-
@sep.validator
112-
def _validate_sep(self, _, sep):
113-
if self.type is MultiInputObj:
114-
tp = ty.Any
115-
elif ty.get_origin(self.type) is MultiInputObj:
116-
tp = ty.get_args(self.type)[0]
117-
else:
118-
tp = self.type
119-
if is_optional(tp):
120-
tp = optional_type(tp)
121-
if tp is ty.Any:
122-
return
123-
origin = ty.get_origin(tp) or tp
124-
125-
if (
126-
inspect.isclass(origin)
127-
and issubclass(origin, ty.Sequence)
128-
and tp is not str
129-
):
130-
if sep is None and not self.readonly:
131-
raise ValueError(
132-
f"A value to 'sep' must be provided when type is iterable {tp} "
133-
f"for field {self.name!r}"
134-
)
135-
elif sep is not None:
136-
raise ValueError(
137-
f"sep ({sep!r}) can only be provided when type is iterable {tp} "
138-
f"for field {self.name!r}"
139-
)
140-
141106

142107
@attrs.define(kw_only=True)
143108
class out(Out):

pydra/engine/lazy.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def _get_value(
152152
the resolved value of the lazy-field
153153
"""
154154

155-
jobs = graph.node(self._node.name).matching_jobs(state_index)
155+
jobs = graph.node(self._node.name).get_jobs(state_index)
156156

157157
def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
158158
if job.errored:
@@ -184,7 +184,9 @@ def retrieve_from_job(job: "Task[DefType]") -> ty.Any:
184184
val = self._apply_cast(val)
185185
return val
186186

187-
if not self._node.state or not self._node.state.depth(before_combine=True):
187+
if not isinstance(jobs, StateArray):
188+
return retrieve_from_job(jobs)
189+
elif not self._node.state or not self._node.state.depth(before_combine=True):
188190
assert len(jobs) == 1
189191
return retrieve_from_job(jobs[0])
190192
elif not self._node.state.keys_final: # all states are combined over

pydra/engine/specs.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,7 @@ def _command_args(self, values: dict[str, ty.Any]) -> list[str]:
10891089
self._check_resolved()
10901090
self._check_rules()
10911091
# Drop none/empty values and optional path fields that are set to false
1092+
values = copy(values) # Create a copy so we can drop items from the dictionary
10921093
for field in list_fields(self):
10931094
fld_value = values[field.name]
10941095
if fld_value is None or (is_multi_input(field.type) and fld_value == []):
@@ -1185,7 +1186,7 @@ def _command_pos_args(
11851186
call_args_val = {}
11861187
for argnm in call_args.args:
11871188
if argnm == "field":
1188-
call_args_val[argnm] = value
1189+
call_args_val[argnm] = field
11891190
elif argnm == "inputs":
11901191
call_args_val[argnm] = values
11911192
else:
@@ -1223,31 +1224,31 @@ def _format_arg(self, field: shell.arg, values: dict[str, ty.Any]) -> list[str]:
12231224
and isinstance(value, ty.Iterable)
12241225
and not isinstance(value, (str, bytes))
12251226
):
1226-
field.argstr = field.argstr.replace("...", "")
1227+
argstr = field.argstr.replace("...", "")
12271228
# if argstr has a more complex form, with "{input_field}"
1228-
if "{" in field.argstr and "}" in field.argstr:
1229+
if "{" in argstr and "}" in argstr:
12291230
argstr_formatted_l = []
12301231
for val in value:
12311232
split_values = copy(values)
12321233
split_values[field.name] = val
1233-
argstr_f = argstr_formatting(field.argstr, split_values)
1234+
argstr_f = argstr_formatting(argstr, split_values)
12341235
argstr_formatted_l.append(f" {argstr_f}")
1235-
cmd_el_str = field.sep.join(argstr_formatted_l)
1236+
cmd_el_str = " ".join(argstr_formatted_l)
12361237
else: # argstr has a simple form, e.g. "-f", or "--f"
1237-
cmd_el_str = field.sep.join([f" {field.argstr} {val}" for val in value])
1238+
cmd_el_str = " ".join([f" {argstr} {val}" for val in value])
12381239
else:
12391240
# in case there are ... when input is not a list
1240-
field.argstr = field.argstr.replace("...", "")
1241+
argstr = field.argstr.replace("...", "")
12411242
if isinstance(value, ty.Iterable) and not isinstance(value, (str, bytes)):
12421243
cmd_el_str = field.sep.join([str(val) for val in value])
12431244
value = cmd_el_str
12441245
# if argstr has a more complex form, with "{input_field}"
1245-
if "{" in field.argstr and "}" in field.argstr:
1246-
cmd_el_str = field.argstr.replace(f"{{{field.name}}}", str(value))
1246+
if "{" in argstr and "}" in argstr:
1247+
cmd_el_str = argstr.replace(f"{{{field.name}}}", str(value))
12471248
cmd_el_str = argstr_formatting(cmd_el_str, values)
12481249
else: # argstr has a simple form, e.g. "-f", or "--f"
12491250
if value:
1250-
cmd_el_str = f"{field.argstr} {value}"
1251+
cmd_el_str = f"{argstr} {value}"
12511252
else:
12521253
cmd_el_str = ""
12531254
return split_cmd(cmd_el_str)

pydra/engine/state.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -879,9 +879,6 @@ def prepare_states(
879879
if self.other_states:
880880
st: State
881881
for nm, (st, _) in self.other_states.items():
882-
# I think now this if is never used
883-
if not hasattr(st, "states_ind"):
884-
st.prepare_states(self.inputs, cont_dim=cont_dim)
885882
self.inputs.update(st.inputs)
886883
self.cont_dim.update(st.cont_dim)
887884

pydra/engine/submitter.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def __init__(
538538
self.unrunnable = defaultdict(list)
539539
# Prepare the state to be run
540540
if node.state:
541-
self.state = deepcopy(node.state)
541+
self.state = node.state
542542
self.state.prepare_states(self.node.state_values)
543543
self.state.prepare_inputs()
544544
else:
@@ -564,34 +564,43 @@ def _definition(self) -> "Node":
564564
return self.node._definition
565565

566566
@property
567-
def tasks(self) -> ty.Iterable["Task[DefType]"]:
567+
def tasks(self) -> ty.Generator["Task[DefType]", None, None]:
568568
if self._tasks is None:
569569
self._tasks = {t.state_index: t for t in self._generate_tasks()}
570570
return self._tasks.values()
571571

572-
def matching_jobs(self, index: int | None = None) -> "StateArray[Task]":
572+
def get_jobs(
573+
self, index: int | None = None, as_array: bool = False
574+
) -> "Task | StateArray[Task]":
573575
"""Get the jobs that match a given state index.
574576
575577
Parameters
576578
----------
577579
index : int, optional
578580
The index of the state of the task to get, by default None
581+
as_array : bool, optional
582+
Whether to return the tasks in a state-array object, by default if the index
583+
matches
579584
580585
Returns
581586
-------
582-
matching : StateArray[Task]
583-
The tasks that match the given index
587+
matching : Task | StateArray[Task]
588+
The task or tasks that match the given index
584589
"""
585590
matching = StateArray()
586591
if self.tasks:
587592
try:
588-
matching.append(self._tasks[index])
593+
task = self._tasks[index]
589594
except KeyError:
595+
if index is None:
596+
return StateArray(self._tasks.values())
590597
# Select matching tasks and return them in nested state-array objects
591598
for ind, task in self._tasks.items():
592-
if ind.matches(index):
593-
matching.append(task)
594-
599+
matching.append(task)
600+
else:
601+
if not as_array:
602+
return task
603+
matching.append(task)
595604
return matching
596605

597606
@property
@@ -657,7 +666,7 @@ def _generate_tasks(self) -> ty.Iterable["Task[DefType]"]:
657666
name=self.node.name,
658667
)
659668
else:
660-
for index, split_defn in self._split_definition().items():
669+
for index, split_defn in enumerate(self._split_definition()):
661670
yield Task(
662671
definition=split_defn,
663672
submitter=self.submitter,
@@ -707,7 +716,7 @@ def _split_definition(self) -> dict[int, "TaskDef[OutputType]"]:
707716
if not self.node.state:
708717
return {None: self.node._definition}
709718
split_defs = []
710-
for i, input_ind in enumerate(self.node.state.inputs_ind):
719+
for input_ind in self.node.state.inputs_ind:
711720
resolved = {}
712721
for inpt_name in set(self.node.input_names):
713722
value = getattr(self._definition, inpt_name)
@@ -716,13 +725,13 @@ def _split_definition(self) -> dict[int, "TaskDef[OutputType]"]:
716725
resolved[inpt_name] = value._get_value(
717726
workflow=self.workflow,
718727
graph=self.graph,
719-
state_index=i,
728+
state_index=input_ind[state_key],
720729
)
721730
elif state_key in input_ind:
722731
resolved[inpt_name] = self.node.state._get_element(
723732
value=value,
724733
field_name=inpt_name,
725-
ind=i,
734+
ind=input_ind[state_key],
726735
)
727736
split_defs.append(attrs.evolve(self.node._definition, **resolved))
728737
return split_defs
@@ -754,7 +763,8 @@ def get_runnable_tasks(self, graph: DiGraph) -> list["Task[DefType]"]:
754763
pred: NodeExecution
755764
is_runnable = True
756765
for pred in graph.predecessors[self.node.name]:
757-
pred_inds = [j.state_index for j in pred.matching_jobs(index)]
766+
pred_jobs: StateArray[Task] = pred.get_jobs(index, as_array=True)
767+
pred_inds = [j.state_index for j in pred_jobs]
758768
if not all(i in pred.successful for i in pred_inds):
759769
is_runnable = False
760770
blocked = True

0 commit comments

Comments
 (0)