4343)
4444from .helpers_file import copy_nested_files , template_update
4545from pydra .utils .messenger import AuditFlag
46+ from pydra .engine .environments import Environment , Native
4647
4748logger = logging .getLogger ("pydra" )
4849
@@ -85,7 +86,8 @@ class Task(ty.Generic[DefType]):
8586
8687 name : str
8788 definition : DefType
88- submitter : "Submitter"
89+ submitter : "Submitter | None"
90+ environment : "Environment | None"
8991 state_index : state .StateIndex
9092
9193 _inputs : dict [str , ty .Any ] | None = None
@@ -95,6 +97,7 @@ def __init__(
9597 definition : DefType ,
9698 submitter : "Submitter" ,
9799 name : str ,
100+ environment : "Environment | None" = None ,
98101 state_index : "state.StateIndex | None" = None ,
99102 ):
100103 """
@@ -121,14 +124,16 @@ def __init__(
121124 if state_index is None :
122125 state_index = state .StateIndex ()
123126
124- self .definition = definition
127+ # Copy the definition, so lazy fields can be resolved and replaced at runtime
128+ self .definition = copy (definition )
129+ # We save the submitter is the definition is a workflow otherwise we don't
130+ # so the task can be pickled
131+ self .submitter = submitter if is_workflow (definition ) else None
132+ self .environment = environment if environment is not None else Native ()
125133 self .name = name
126134 self .state_index = state_index
127135
128- # checking if metadata is set properly
129- self .definition ._check_resolved ()
130- self .definition ._check_rules ()
131- self ._output = {}
136+ self .return_values = {}
132137 self ._result = {}
133138 # flag that says if node finished all jobs
134139 self ._done = False
@@ -151,6 +156,11 @@ def __init__(
151156 def cache_dir (self ):
152157 return self ._cache_dir
153158
159+ @property
160+ def is_async (self ) -> bool :
161+ """Check to see if the task should be run asynchronously."""
162+ return self .submitter .worker .is_async and is_workflow (self .definition )
163+
154164 @cache_dir .setter
155165 def cache_dir (self , path : os .PathLike ):
156166 self ._cache_dir = Path (path )
@@ -315,6 +325,21 @@ def _populate_filesystem(self):
315325 self .output_dir .mkdir (parents = False , exist_ok = self .can_resume )
316326
317327 def run (self , rerun : bool = False ):
328+ """Prepare the task working directory, execute the task definition, and save the
329+ results.
330+
331+ Parameters
332+ ----------
333+ rerun : bool
334+ If True, the task will be re-run even if a result already exists. Will
335+ propagated to all tasks within workflow tasks.
336+ """
337+ # TODO: After these changes have been merged, will refactor this function and
338+ # run_async to use common helper methods for pre/post run tasks
339+
340+ # checking if the definition is fully resolved and ready to run
341+ self .definition ._check_resolved ()
342+ self .definition ._check_rules ()
318343 self .hooks .pre_run (self )
319344 logger .debug (
320345 "'%s' is attempting to acquire lock on %s" , self .name , self .lockfile
@@ -334,8 +359,8 @@ def run(self, rerun: bool = False):
334359 self .audit .audit_task (task = self )
335360 try :
336361 self .audit .monitor ()
337- run_outputs = self .definition ._run (self )
338- result .outputs = self .definition .Outputs .from_task (self , run_outputs )
362+ self .definition ._run (self )
363+ result .outputs = self .definition .Outputs ._from_task (self )
339364 except Exception :
340365 etype , eval , etr = sys .exc_info ()
341366 traceback = format_exception (etype , eval , etr )
@@ -355,6 +380,57 @@ def run(self, rerun: bool = False):
355380 self ._check_for_hash_changes ()
356381 return result
357382
383+ async def run_async (self , rerun : bool = False ):
384+ """Prepare the task working directory, execute the task definition asynchronously,
385+ and save the results. NB: only workflows are run asynchronously at the moment.
386+
387+ Parameters
388+ ----------
389+ rerun : bool
390+ If True, the task will be re-run even if a result already exists. Will
391+ propagated to all tasks within workflow tasks.
392+ """
393+ # checking if the definition is fully resolved and ready to run
394+ self .definition ._check_resolved ()
395+ self .definition ._check_rules ()
396+ self .hooks .pre_run (self )
397+ logger .debug (
398+ "'%s' is attempting to acquire lock on %s" , self .name , self .lockfile
399+ )
400+ async with PydraFileLock (self .lockfile ):
401+ if not rerun :
402+ result = self .result ()
403+ if result is not None and not result .errored :
404+ return result
405+ cwd = os .getcwd ()
406+ self ._populate_filesystem ()
407+ result = Result (outputs = None , runtime = None , errored = False , task = self )
408+ self .hooks .pre_run_task (self )
409+ self .audit .start_audit (odir = self .output_dir )
410+ try :
411+ self .audit .monitor ()
412+ await self .definition ._run (self )
413+ result .outputs = self .definition .Outputs ._from_task (self )
414+ except Exception :
415+ etype , eval , etr = sys .exc_info ()
416+ traceback = format_exception (etype , eval , etr )
417+ record_error (self .output_dir , error = traceback )
418+ result .errored = True
419+ self ._errored = True
420+ raise
421+ finally :
422+ self .hooks .post_run_task (self , result )
423+ self .audit .finalize_audit (result = result )
424+ save (self .output_dir , result = result , task = self )
425+ # removing the additional file with the checksum
426+ (self .cache_dir / f"{ self .uid } _info.json" ).unlink ()
427+ os .chdir (cwd )
428+ self .hooks .post_run (self , result )
429+ # Check for any changes to the input hashes that have occurred during the execution
430+ # of the task
431+ self ._check_for_hash_changes ()
432+ return result
433+
358434 def pickle_task (self ):
359435 """Pickling the tasks with full inputs"""
360436 pkl_files = self .cache_dir / "pkl_files"
@@ -398,7 +474,7 @@ def _combined_output(self, return_inputs=False):
398474 else :
399475 return combined_results
400476
401- def result (self , state_index = None , return_inputs = False ):
477+ def result (self , return_inputs = False ):
402478 """
403479 Retrieve the outcomes of this particular task.
404480
@@ -415,13 +491,9 @@ def result(self, state_index=None, return_inputs=False):
415491 result : Result
416492 the result of the task
417493 """
418- # TODO: check if result is available in load_result and
419- # return a future if not
420494 if self .errored :
421495 return Result (outputs = None , runtime = None , errored = True , task = self )
422496
423- if state_index is not None :
424- raise ValueError ("Task does not have a state" )
425497 checksum = self .checksum
426498 result = load_result (checksum , self .cache_locations )
427499 if result and result .errored :
@@ -483,58 +555,6 @@ def _check_for_hash_changes(self):
483555 DEFAULT_COPY_COLLATION = FileSet .CopyCollation .any
484556
485557
486- class WorkflowTask (Task ):
487-
488- def __init__ (
489- self ,
490- definition : DefType ,
491- submitter : "Submitter" ,
492- name : str ,
493- state_index : "state.StateIndex | None" = None ,
494- ):
495- super ().__init__ (definition , submitter , name , state_index )
496- self .submitter = submitter
497-
498- async def run (self , rerun : bool = False ):
499- self .hooks .pre_run (self )
500- logger .debug (
501- "'%s' is attempting to acquire lock on %s" , self .name , self .lockfile
502- )
503- async with PydraFileLock (self .lockfile ):
504- if not rerun :
505- result = self .result ()
506- if result is not None and not result .errored :
507- return result
508- cwd = os .getcwd ()
509- self ._populate_filesystem ()
510- result = Result (outputs = None , runtime = None , errored = False , task = self )
511- self .hooks .pre_run_task (self )
512- self .audit .start_audit (odir = self .output_dir )
513- try :
514- self .audit .monitor ()
515- await self .submitter .expand_workflow (self )
516- result .outputs = self .definition .Outputs .from_task (self )
517- except Exception :
518- etype , eval , etr = sys .exc_info ()
519- traceback = format_exception (etype , eval , etr )
520- record_error (self .output_dir , error = traceback )
521- result .errored = True
522- self ._errored = True
523- raise
524- finally :
525- self .hooks .post_run_task (self , result )
526- self .audit .finalize_audit (result = result )
527- save (self .output_dir , result = result , task = self )
528- # removing the additional file with the checksum
529- (self .cache_dir / f"{ self .uid } _info.json" ).unlink ()
530- os .chdir (cwd )
531- self .hooks .post_run (self , result )
532- # Check for any changes to the input hashes that have occurred during the execution
533- # of the task
534- self ._check_for_hash_changes ()
535- return result
536-
537-
538558logger = logging .getLogger ("pydra" )
539559
540560OutputsType = ty .TypeVar ("OutputType" , bound = TaskOutputs )
@@ -847,18 +867,12 @@ def create_dotfile(self, type="simple", export=None, name=None, output_dir=None)
847867 return dotfile , formatted_dot
848868
849869
850- def is_task (obj ):
851- """Check whether an object looks like a task."""
852- return hasattr (obj , "_run_task" )
853-
854-
855870def is_workflow (obj ):
856871 """Check whether an object is a :class:`Workflow` instance."""
857872 from pydra .engine .specs import WorkflowDef
858873 from pydra .engine .core import Workflow
859- from pydra .engine .core import WorkflowTask
860874
861- return isinstance (obj , (WorkflowDef , WorkflowTask , Workflow ))
875+ return isinstance (obj , (WorkflowDef , Workflow ))
862876
863877
864878def has_lazy (obj ):
0 commit comments