4040 from pydra .engine .graph import DiGraph
4141 from pydra .engine .submitter import NodeExecution
4242 from pydra .engine .lazy import LazyOutField
43- from pydra .engine .task import ShellTask
4443 from pydra .engine .core import Workflow
4544 from pydra .engine .environments import Environment
4645 from pydra .engine .workers import Worker
4746
4847
48+ DefType = ty .TypeVar ("DefType" , bound = "TaskDef" )
49+
50+
4951def is_set (value : ty .Any ) -> bool :
5052 """Check if a value has been set."""
5153 return value not in (attrs .NOTHING , EMPTY )
@@ -372,7 +374,7 @@ def _compute_hashes(self) -> ty.Tuple[bytes, ty.Dict[str, bytes]]:
372374 }
373375 return hash_function (sorted (field_hashes .items ())), field_hashes
374376
375- def _retrieve_values (self , wf , state_index = None ):
377+ def _resolve_lazy_fields (self , wf , state_index = None ):
376378 """Parse output results."""
377379 temp_values = {}
378380 for field in attrs_fields (self ):
@@ -482,7 +484,7 @@ class Runtime:
482484class Result (ty .Generic [OutputsType ]):
483485 """Metadata regarding the outputs of processing."""
484486
485- task : "Task"
487+ task : "Task[DefType] "
486488 outputs : OutputsType | None = None
487489 runtime : Runtime | None = None
488490 errored : bool = False
@@ -548,13 +550,13 @@ class RuntimeSpec:
548550class PythonOutputs (TaskOutputs ):
549551
550552 @classmethod
551- def _from_task (cls , task : "Task" ) -> Self :
553+ def _from_task (cls , task : "Task[PythonDef] " ) -> Self :
552554 """Collect the outputs of a task from a combination of the provided inputs,
553555 the objects in the output directory, and the stdout and stderr of the process.
554556
555557 Parameters
556558 ----------
557- task : Task
559+ task : Task[PythonDef]
558560 The task whose outputs are being collected.
559561 outputs_dict : dict[str, ty.Any]
560562 The outputs of the task, as a dictionary
@@ -575,7 +577,7 @@ def _from_task(cls, task: "Task") -> Self:
575577
576578class PythonDef (TaskDef [PythonOutputsType ]):
577579
578- def _run (self , task : "Task" ) -> None :
580+ def _run (self , task : "Task[PythonDef] " ) -> None :
579581 # Prepare the inputs to the function
580582 inputs = attrs_values (self )
581583 del inputs ["function" ]
@@ -602,12 +604,12 @@ def _run(self, task: "Task") -> None:
602604class WorkflowOutputs (TaskOutputs ):
603605
604606 @classmethod
605- def _from_task (cls , task : "Task" ) -> Self :
607+ def _from_task (cls , task : "Task[WorkflowDef] " ) -> Self :
606608 """Collect the outputs of a workflow task from the outputs of the nodes in the
607609
608610 Parameters
609611 ----------
610- task : Task
612+ task : Task[WorfklowDef]
611613 The task whose outputs are being collected.
612614
613615 Returns
@@ -659,12 +661,13 @@ class WorkflowDef(TaskDef[WorkflowOutputsType]):
659661
660662 _constructed = attrs .field (default = None , init = False )
661663
662- def _run (self , task : "Task" ) -> None :
664+ def _run (self , task : "Task[WorkflowDef] " ) -> None :
663665 """Run the workflow."""
664- if task .submitter .worker .is_async :
665- task .submitter .expand_workflow_async (task )
666- else :
667- task .submitter .expand_workflow (task )
666+ task .submitter .expand_workflow (task )
667+
668+ async def _run_async (self , task : "Task[WorkflowDef]" ) -> None :
669+ """Run the workflow asynchronously."""
670+ await task .submitter .expand_workflow_async (task )
668671
669672 def construct (self ) -> "Workflow" :
670673 from pydra .engine .core import Workflow
@@ -688,7 +691,7 @@ class ShellOutputs(TaskOutputs):
688691 stderr : str = shell .out (help = STDERR_HELP )
689692
690693 @classmethod
691- def _from_task (cls , task : "ShellTask " ) -> Self :
694+ def _from_task (cls , task : "Task[ShellDef] " ) -> Self :
692695 """Collect the outputs of a shell process from a combination of the provided inputs,
693696 the objects in the output directory, and the stdout and stderr of the process.
694697
@@ -784,7 +787,7 @@ def _required_fields_satisfied(cls, fld: shell.out, inputs: "ShellDef") -> bool:
784787 def _resolve_value (
785788 cls ,
786789 fld : "shell.out" ,
787- task : "Task" ,
790+ task : "Task[DefType] " ,
788791 ) -> ty .Any :
789792 """Collect output file if metadata specified."""
790793 from pydra .design import shell
@@ -842,7 +845,7 @@ class ShellDef(TaskDef[ShellOutputsType]):
842845
843846 RESERVED_FIELD_NAMES = TaskDef .RESERVED_FIELD_NAMES + ("cmdline" ,)
844847
845- def _run (self , task : "Task" ) -> None :
848+ def _run (self , task : "Task[ShellDef] " ) -> None :
846849 """Run the shell command."""
847850 task .return_values = task .environment .execute (task )
848851
0 commit comments