Skip to content

Commit 03c7438

Browse files
committed
touching up typing of tasks to include TaskDef template
1 parent e5c2556 commit 03c7438

File tree

7 files changed

+77
-177
lines changed

7 files changed

+77
-177
lines changed

pydra/engine/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ async def run_async(self, rerun: bool = False):
409409
self.audit.start_audit(odir=self.output_dir)
410410
try:
411411
self.audit.monitor()
412-
await self.definition._run(self)
412+
await self.definition._run_async(self)
413413
result.outputs = self.definition.Outputs._from_task(self)
414414
except Exception:
415415
etype, eval, etr = sys.exc_info()

pydra/engine/environments.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
if ty.TYPE_CHECKING:
66
from pydra.engine.core import Task
7+
from pydra.engine.specs import ShellDef
78

89

910
class Environment:
@@ -17,7 +18,7 @@ class Environment:
1718
def setup(self):
1819
pass
1920

20-
def execute(self, task: "Task") -> dict[str, ty.Any]:
21+
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
2122
"""
2223
Execute the task in the environment.
2324
@@ -42,7 +43,7 @@ class Native(Environment):
4243
Native environment, i.e. the tasks are executed in the current python environment.
4344
"""
4445

45-
def execute(self, task: "Task") -> dict[str, ty.Any]:
46+
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
4647
keys = ["return_code", "stdout", "stderr"]
4748
values = execute(task.definition._command_args())
4849
output = dict(zip(keys, values))
@@ -90,7 +91,7 @@ def bind(self, loc, mode="ro"):
9091
class Docker(Container):
9192
"""Docker environment."""
9293

93-
def execute(self, task: "Task") -> dict[str, ty.Any]:
94+
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
9495
docker_img = f"{self.image}:{self.tag}"
9596
# mounting all input locations
9697
mounts = task.definition._get_bindings(root=self.root)
@@ -125,7 +126,7 @@ def execute(self, task: "Task") -> dict[str, ty.Any]:
125126
class Singularity(Container):
126127
"""Singularity environment."""
127128

128-
def execute(self, task: "Task") -> dict[str, ty.Any]:
129+
def execute(self, task: "Task[ShellDef]") -> dict[str, ty.Any]:
129130
singularity_img = f"{self.image}:{self.tag}"
130131
# mounting all input locations
131132
mounts = task.definition._get_bindings(root=self.root)

pydra/engine/helpers.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
PYDRA_ATTR_METADATA = "__PYDRA_METADATA__"
2727

28+
DefType = ty.TypeVar("DefType", bound="TaskDef")
29+
2830

2931
def attrs_fields(definition, exclude_names=()) -> list[attrs.Attribute]:
3032
"""Get the fields of a definition, excluding some names."""
@@ -132,7 +134,7 @@ def load_result(checksum, cache_locations):
132134
def save(
133135
task_path: Path,
134136
result: "Result | None" = None,
135-
task: "Task | None" = None,
137+
task: "Task[DefType] | None" = None,
136138
name_prefix: str = None,
137139
) -> None:
138140
"""
@@ -449,7 +451,7 @@ def load_and_run(task_pkl: Path, rerun: bool = False) -> Path:
449451
from .specs import Result
450452

451453
try:
452-
task: Task = load_task(task_pkl=task_pkl)
454+
task: Task[DefType] = load_task(task_pkl=task_pkl)
453455
except Exception:
454456
if task_pkl.parent.exists():
455457
etype, eval, etr = sys.exc_info()

pydra/engine/lazy.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from .graph import DiGraph
1111
from .submitter import NodeExecution
1212
from .core import Task, Workflow
13+
from .specs import TaskDef
1314

1415

1516
T = ty.TypeVar("T")
17+
DefType = ty.TypeVar("DefType", bound="TaskDef")
1618

1719
TypeOrAny = ty.Union[type, ty.Any]
1820

@@ -150,7 +152,7 @@ def get_value(
150152
task = graph.node(self.node.name).task(state_index)
151153
_, split_depth = TypeParser.strip_splits(self.type)
152154

153-
def get_nested(task: "Task", depth: int):
155+
def get_nested(task: "Task[DefType]", depth: int):
154156
if isinstance(task, StateArray):
155157
val = [get_nested(task=t, depth=depth - 1) for t in task]
156158
if depth:

pydra/engine/specs.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@
4040
from pydra.engine.graph import DiGraph
4141
from pydra.engine.submitter import NodeExecution
4242
from pydra.engine.lazy import LazyOutField
43-
from pydra.engine.task import ShellTask
4443
from pydra.engine.core import Workflow
4544
from pydra.engine.environments import Environment
4645
from pydra.engine.workers import Worker
4746

4847

48+
DefType = ty.TypeVar("DefType", bound="TaskDef")
49+
50+
4951
def is_set(value: ty.Any) -> bool:
5052
"""Check if a value has been set."""
5153
return value not in (attrs.NOTHING, EMPTY)
@@ -372,7 +374,7 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
372374
}
373375
return hash_function(sorted(field_hashes.items())), field_hashes
374376

375-
def _retrieve_values(self, wf, state_index=None):
377+
def _resolve_lazy_fields(self, wf, state_index=None):
376378
"""Parse output results."""
377379
temp_values = {}
378380
for field in attrs_fields(self):
@@ -482,7 +484,7 @@ class Runtime:
482484
class Result(ty.Generic[OutputsType]):
483485
"""Metadata regarding the outputs of processing."""
484486

485-
task: "Task"
487+
task: "Task[DefType]"
486488
outputs: OutputsType | None = None
487489
runtime: Runtime | None = None
488490
errored: bool = False
@@ -548,13 +550,13 @@ class RuntimeSpec:
548550
class PythonOutputs(TaskOutputs):
549551

550552
@classmethod
551-
def _from_task(cls, task: "Task") -> Self:
553+
def _from_task(cls, task: "Task[PythonDef]") -> Self:
552554
"""Collect the outputs of a task from a combination of the provided inputs,
553555
the objects in the output directory, and the stdout and stderr of the process.
554556
555557
Parameters
556558
----------
557-
task : Task
559+
task : Task[PythonDef]
558560
The task whose outputs are being collected.
559561
outputs_dict : dict[str, ty.Any]
560562
The outputs of the task, as a dictionary
@@ -575,7 +577,7 @@ def _from_task(cls, task: "Task") -> Self:
575577

576578
class PythonDef(TaskDef[PythonOutputsType]):
577579

578-
def _run(self, task: "Task") -> None:
580+
def _run(self, task: "Task[PythonDef]") -> None:
579581
# Prepare the inputs to the function
580582
inputs = attrs_values(self)
581583
del inputs["function"]
@@ -602,12 +604,12 @@ def _run(self, task: "Task") -> None:
602604
class WorkflowOutputs(TaskOutputs):
603605

604606
@classmethod
605-
def _from_task(cls, task: "Task") -> Self:
607+
def _from_task(cls, task: "Task[WorkflowDef]") -> Self:
606608
"""Collect the outputs of a workflow task from the outputs of the nodes in the
607609
608610
Parameters
609611
----------
610-
task : Task
612+
task : Task[WorfklowDef]
611613
The task whose outputs are being collected.
612614
613615
Returns
@@ -659,12 +661,13 @@ class WorkflowDef(TaskDef[WorkflowOutputsType]):
659661

660662
_constructed = attrs.field(default=None, init=False)
661663

662-
def _run(self, task: "Task") -> None:
664+
def _run(self, task: "Task[WorkflowDef]") -> None:
663665
"""Run the workflow."""
664-
if task.submitter.worker.is_async:
665-
task.submitter.expand_workflow_async(task)
666-
else:
667-
task.submitter.expand_workflow(task)
666+
task.submitter.expand_workflow(task)
667+
668+
async def _run_async(self, task: "Task[WorkflowDef]") -> None:
669+
"""Run the workflow asynchronously."""
670+
await task.submitter.expand_workflow_async(task)
668671

669672
def construct(self) -> "Workflow":
670673
from pydra.engine.core import Workflow
@@ -688,7 +691,7 @@ class ShellOutputs(TaskOutputs):
688691
stderr: str = shell.out(help=STDERR_HELP)
689692

690693
@classmethod
691-
def _from_task(cls, task: "ShellTask") -> Self:
694+
def _from_task(cls, task: "Task[ShellDef]") -> Self:
692695
"""Collect the outputs of a shell process from a combination of the provided inputs,
693696
the objects in the output directory, and the stdout and stderr of the process.
694697
@@ -784,7 +787,7 @@ def _required_fields_satisfied(cls, fld: shell.out, inputs: "ShellDef") -> bool:
784787
def _resolve_value(
785788
cls,
786789
fld: "shell.out",
787-
task: "Task",
790+
task: "Task[DefType]",
788791
) -> ty.Any:
789792
"""Collect output file if metadata specified."""
790793
from pydra.design import shell
@@ -842,7 +845,7 @@ class ShellDef(TaskDef[ShellOutputsType]):
842845

843846
RESERVED_FIELD_NAMES = TaskDef.RESERVED_FIELD_NAMES + ("cmdline",)
844847

845-
def _run(self, task: "Task") -> None:
848+
def _run(self, task: "Task[ShellDef]") -> None:
846849
"""Run the shell command."""
847850
task.return_values = task.environment.execute(task)
848851

0 commit comments

Comments
 (0)