@@ -118,7 +118,19 @@ def __init__(
118118 is_layered = None ,
119119 ):
120120 super ().__init__ (input_values )
121+ ## Make callbacks a 2-tuple with possible None callables.
122+ #
123+ if callable (callbacks ):
124+ callbacks = (callbacks ,)
125+ elif not callbacks :
126+ callbacks = ()
127+ else :
128+ callbacks = tuple (callbacks )
129+ n_cbs = len (callbacks )
130+ if n_cbs < 2 :
131+ callbacks = callbacks + ((None ,) * (2 - n_cbs ))
121132 self .callbacks = callbacks
133+
122134 is_layered = first_solid (is_layered_solution (), is_layered )
123135
124136 ##: By default, disable layers if network contains :term:`jsonp` dependencies.
@@ -466,37 +478,24 @@ class OpTask:
466478 Mimic :class:`concurrent.futures.Future` for :term:`sequential` execution.
467479
468480 This intermediate class is needed to solve pickling issue with process executor.
469-
470- Use also as argument passed in a :term:`callbacks` callable
471- before executing each operation, and contains fields to identify
472- the operation call and results:
473481 """
474482
475- __slots__ = ("op" , "sol" , "solid" , "callbacks" , " result" )
483+ __slots__ = ("op" , "sol" , "solid" , "result" )
476484 logname = __name__
477485
478- def __init__ (self , op , sol , solid , callbacks = None , result = UNSET ):
486+ def __init__ (self , op , sol , solid , result = UNSET ):
479487 #: the operation about to be computed.
480488 self .op = op
481489 #: the solution (might be just a plain dict if it has been marshalled).
482490 self .sol = sol
483491 #: the operation identity, needed if `sol` is a plain dict.
484492 self .solid = solid
485-
486- ## Make callbacks a 2-tuple with possible None callables.
487- #
488- if callbacks is None :
489- callbacks = (None , None )
490- elif callable (callbacks ):
491- callbacks = (callbacks , None )
492- else :
493- callbacks = tuple (callbacks )
494- if len (callbacks ) < 2 :
495- callbacks = (* callbacks , None )
496- self .callbacks = callbacks
497- #: Initially would :data:`.UNSET`, will be set after execution.
493+ #: Initially would :data:`.UNSET`, will be set after execution
494+ #: with operation's outputs or exception.
498495 self .result = result
499496
497+ # if
498+
500499 def marshalled (self ):
501500 import dill
502501
@@ -508,13 +507,8 @@ def __call__(self):
508507 log = logging .getLogger (self .logname )
509508 log .debug ("+++ (%s) Executing %s..." , self .solid , self )
510509 token = task_context .set (self )
511- callbacks = self .callbacks
512510 try :
513- if callbacks [0 ]:
514- callbacks [0 ](self )
515511 self .result = self .op .compute (self .sol )
516- if callbacks [1 ]:
517- callbacks [1 ](self )
518512 finally :
519513 task_context .reset (token )
520514
@@ -716,7 +710,7 @@ def prep_task(op):
716710 # Mark start time here, to include also marshalling overhead.
717711 solution .elapsed_ms [op ] = time .time ()
718712
719- task = OpTask (op , input_values , solution .solid , solution . callbacks )
713+ task = OpTask (op , input_values , solution .solid )
720714 if first_solid (global_marshal , getattr (op , "marshalled" , None )):
721715 task = task .marshalled ()
722716
@@ -743,7 +737,7 @@ def prep_task(op):
743737
744738 return [prep_task (op ) for op in operations ]
745739
746- def _handle_task (self , future , op , solution ) -> None :
740+ def _handle_task (self , future : Union [ OpTask , "AsyncResult" ] , op , solution ) -> None :
747741 """Un-dill parallel task results (if marshalled), and update solution / handle failure."""
748742
749743 def elapsed_ms (op ):
@@ -752,14 +746,18 @@ def elapsed_ms(op):
752746
753747 return elapsed
754748
749+ result = UNSET
755750 try :
756751 ## Reset start time for Sequential tasks
757752 # (bummer, they will miss marshalling overhead).
758753 #
759754 if isinstance (future , OpTask ):
760755 solution .elapsed_ms [op ] = time .time ()
761756
762- outputs = future .get ()
757+ if solution .callbacks [0 ]:
758+ solution .callbacks [0 ](future )
759+
760+ outputs = result = future .get ()
763761 if isinstance (outputs , bytes ):
764762 import dill
765763
@@ -772,6 +770,7 @@ def elapsed_ms(op):
772770 "... (%s) op(%s) completed in %sms." , solution .solid , op .name , elapsed
773771 )
774772 except Exception as ex :
773+ result = ex
775774 is_endured = first_solid (
776775 solution .is_endurance , getattr (op , "endured" , None )
777776 )
@@ -801,6 +800,9 @@ def elapsed_ms(op):
801800 # add it again, in case compile()/execute() is called separately.
802801 save_jetsam (ex , locals (), "solution" , task = "future" , plan = "self" )
803802 raise
803+ finally :
804+ if isinstance (future , OpTask ) and solution .callbacks [1 ]:
805+ solution .callbacks [1 ](future )
804806
805807 def _execute_thread_pool_barrier_method (self , solution : Solution ):
806808 """
@@ -904,7 +906,7 @@ def _execute_sequential_method(self, solution: Solution):
904906 if step in solution .canceled :
905907 continue
906908
907- task = OpTask (step , solution , solution .solid , solution . callbacks )
909+ task = OpTask (step , solution , solution .solid )
908910 self ._handle_task (task , step , solution )
909911
910912 elif isinstance (step , str ):
@@ -927,7 +929,7 @@ def execute(
927929 outputs = None ,
928930 * ,
929931 name = "" ,
930- callbacks : Callable [[OpTask ], None ] = None ,
932+ callbacks : Tuple [ Callable [[OpTask ], None ], ... ] = None ,
931933 solution_class = None ,
932934 layered_solution = None ,
933935 ) -> Solution :
@@ -943,9 +945,10 @@ def execute(
943945 :param name:
944946 name of the pipeline used for logging
945947 :param callbacks:
946- If given, a 2-tuple with (optional) x2 :term:`callbacks` to call before & after
947- each operation, with :class:`.OpTask` as argument containing the op & solution.
948- Less or no elements accepted.
948+ If given, a 2-tuple with (optional) :term:`callbacks` to call
949+ before/after computing operation, with :class:`.OpTask` as argument
950+ containing the op & solution.
951+ Can be one (scalar), less than 2, or nothing/no elements accepted.
949952 :param solution_class:
950953 a custom solution factory to use
951954 :param layered_solution:
0 commit comments