|
18 | 18 | attrs_values, |
19 | 19 | ) |
20 | 20 | from pydra.utils.hash import PersistentCache |
21 | | -from pydra.utils.typing import StateArray |
22 | 21 | from pydra.engine.lazy import LazyField |
23 | 22 | from .audit import Audit |
24 | 23 | from .core import Task |
@@ -536,7 +535,6 @@ def __init__( |
536 | 535 | self.queued = {} |
537 | 536 | self.running = {} # Not used in logic, but may be useful for progress tracking |
538 | 537 | self.unrunnable = defaultdict(list) |
539 | | - self.state_names = self.node.state.names if self.node.state else [] |
540 | 538 | self.workflow = workflow |
541 | 539 | self.graph = None |
542 | 540 |
|
@@ -566,47 +564,19 @@ def tasks(self) -> ty.Generator["Task[DefType]", None, None]: |
566 | 564 | raise RuntimeError("Tasks have not been generated") |
567 | 565 | return self._tasks.values() |
568 | 566 |
|
569 | | - def get_jobs(self, final_index: int | None = None) -> "Task | StateArray[Task]": |
570 | | - """Get the jobs that match a given state index. |
571 | | -
|
572 | | - Parameters |
573 | | - ---------- |
574 | | - final_index : int, optional |
575 | | - The index of the output state array (i.e. after any combinations) of the |
576 | | - job to get, by default None |
577 | | -
|
578 | | - Returns |
579 | | - ------- |
580 | | - matching : Task | StateArray[Task] |
581 | | - The task or tasks that match the given index |
582 | | - """ |
583 | | - if not self.tasks: # No jobs, return empty state array |
584 | | - return StateArray() |
585 | | - if not self.node.state: # Return the singular job |
586 | | - return self._tasks[None] |
587 | | - if final_index is None: # return all jobs in a state array |
588 | | - return StateArray(self._tasks.values()) |
589 | | - if not self.node.state.combiner: # Select the job that matches the index |
590 | | - return self._tasks[final_index] |
591 | | - # Get a slice of the tasks that match the given index of the state array of the |
592 | | - # combined values |
593 | | - final_index = set(self.node.state.states_ind_final[final_index].items()) |
594 | | - return StateArray( |
595 | | - self._tasks[i] |
596 | | - for i, ind in enumerate(self.node.state.states_ind) |
597 | | - if set(ind.items()).issuperset(final_index) |
598 | | - ) |
599 | | - |
600 | 567 | def start(self) -> None: |
601 | 568 | """Prepare the execution node so that it can be processed""" |
602 | 569 | self._tasks = {} |
603 | 570 | if self.state: |
604 | 571 | values = {} |
605 | 572 | for name, value in self.node.state_values.items(): |
606 | | - if name in self.node.state.names and isinstance(value, LazyField): |
607 | | - values[name] = value._get_value( |
608 | | - workflow=self.workflow, graph=self.graph |
609 | | - ) |
| 573 | + if name in self.node.state.names: |
| 574 | + if isinstance(value, LazyField): |
| 575 | + values[name] = value._get_value( |
| 576 | + workflow=self.workflow, graph=self.graph |
| 577 | + ) |
| 578 | + else: |
| 579 | + values[name] = value |
610 | 580 | self.state.prepare_states(values) |
611 | 581 | self.state.prepare_inputs() |
612 | 582 | # Generate the tasks |
|
0 commit comments