Skip to content

Commit af81d07

Browse files
committed
REFACT(SOL) Callback from Solution (not marshaled)
1 parent 92c6029 commit af81d07

File tree

4 files changed

+79
-60
lines changed

4 files changed

+79
-60
lines changed

docs/source/arch.rst

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -708,22 +708,24 @@ Architecture
708708
value. All of them default to ``None`` (false).
709709

710710
callbacks
711-
x2 user-defined optional callables called before & after each function,
712-
installed on :meth:`.Pipeline.compute()`, that they must have this signature::
711+
x2 optional callables called before/after each `operation` :meth:`.Pipeline.compute()`.
712+
Attention, any errors will abort the pipeline execution.
713713

714-
callbacks(op_cb) -> None
714+
pre-op-callback
715+
Called from solution code before :term:`marshalling`.
716+
A use case would be to validate `solution`, or
717+
:ref:`trigger a breakpoint by some condition <break_with_pre_callback>`.
718+
719+
post-op-callback:
720+
Called after solution have been populated with `operation` results.
721+
A use case would be to validate operation `outputs` and/or
722+
solution after results have been populated.
715723

716-
... where ``op_cb`` is an instance of the :class:`.OpTask` :func:`~collections.namedtuple`.
724+
Callbacks must have this signature::
717725

718-
pre-callback
719-
Called after :term:`marshalling` and before `matching inputs`.
720-
A use case would be to :ref:`a breakpoint triggered by some condition
721-
<break_with_pre_callback>`.
726+
callbacks(op_cb) -> None
722727

723-
post-callback:
724-
Called after :term:`zipping outputs` and before :term:`marshalling` results
725-
back into `solution`.
726-
A use case would be to validate the solution and scream immediately on errors.
728+
... where ``op_cb`` is an instance of the :class:`.OpTask`.
727729

728730
jetsam
729731
When a pipeline or an operation fails, the original exception gets annotated

graphtik/execution.py

Lines changed: 35 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,19 @@ def __init__(
118118
is_layered=None,
119119
):
120120
super().__init__(input_values)
121+
## Make callbacks a 2-tuple with possible None callables.
122+
#
123+
if callable(callbacks):
124+
callbacks = (callbacks,)
125+
elif not callbacks:
126+
callbacks = ()
127+
else:
128+
callbacks = tuple(callbacks)
129+
n_cbs = len(callbacks)
130+
if n_cbs < 2:
131+
callbacks = callbacks + ((None,) * (2 - n_cbs))
121132
self.callbacks = callbacks
133+
122134
is_layered = first_solid(is_layered_solution(), is_layered)
123135

124136
##: By default, disable layers if network contains :term:`jsonp` dependencies.
@@ -466,37 +478,24 @@ class OpTask:
466478
Mimic :class:`concurrent.futures.Future` for :term:`sequential` execution.
467479
468480
This intermediate class is needed to solve pickling issue with process executor.
469-
470-
Use also as argument passed in a :term:`callbacks` callable
471-
before executing each operation, and contains fields to identify
472-
the operation call and results:
473481
"""
474482

475-
__slots__ = ("op", "sol", "solid", "callbacks", "result")
483+
__slots__ = ("op", "sol", "solid", "result")
476484
logname = __name__
477485

478-
def __init__(self, op, sol, solid, callbacks=None, result=UNSET):
486+
def __init__(self, op, sol, solid, result=UNSET):
479487
#: the operation about to be computed.
480488
self.op = op
481489
#: the solution (might be just a plain dict if it has been marshalled).
482490
self.sol = sol
483491
#: the operation identity, needed if `sol` is a plain dict.
484492
self.solid = solid
485-
486-
## Make callbacks a 2-tuple with possible None callables.
487-
#
488-
if callbacks is None:
489-
callbacks = (None, None)
490-
elif callable(callbacks):
491-
callbacks = (callbacks, None)
492-
else:
493-
callbacks = tuple(callbacks)
494-
if len(callbacks) < 2:
495-
callbacks = (*callbacks, None)
496-
self.callbacks = callbacks
497-
#: Initially would :data:`.UNSET`, will be set after execution.
493+
#: Initially would :data:`.UNSET`, will be set after execution
494+
#: with operation's outputs or exception.
498495
self.result = result
499496

497+
# if
498+
500499
def marshalled(self):
501500
import dill
502501

@@ -508,13 +507,8 @@ def __call__(self):
508507
log = logging.getLogger(self.logname)
509508
log.debug("+++ (%s) Executing %s...", self.solid, self)
510509
token = task_context.set(self)
511-
callbacks = self.callbacks
512510
try:
513-
if callbacks[0]:
514-
callbacks[0](self)
515511
self.result = self.op.compute(self.sol)
516-
if callbacks[1]:
517-
callbacks[1](self)
518512
finally:
519513
task_context.reset(token)
520514

@@ -716,7 +710,7 @@ def prep_task(op):
716710
# Mark start time here, to include also marshalling overhead.
717711
solution.elapsed_ms[op] = time.time()
718712

719-
task = OpTask(op, input_values, solution.solid, solution.callbacks)
713+
task = OpTask(op, input_values, solution.solid)
720714
if first_solid(global_marshal, getattr(op, "marshalled", None)):
721715
task = task.marshalled()
722716

@@ -743,7 +737,7 @@ def prep_task(op):
743737

744738
return [prep_task(op) for op in operations]
745739

746-
def _handle_task(self, future, op, solution) -> None:
740+
def _handle_task(self, future: Union[OpTask, "AsyncResult"], op, solution) -> None:
747741
"""Un-dill parallel task results (if marshalled), and update solution / handle failure."""
748742

749743
def elapsed_ms(op):
@@ -752,14 +746,18 @@ def elapsed_ms(op):
752746

753747
return elapsed
754748

749+
result = UNSET
755750
try:
756751
## Reset start time for Sequential tasks
757752
# (bummer, they will miss marshalling overhead).
758753
#
759754
if isinstance(future, OpTask):
760755
solution.elapsed_ms[op] = time.time()
761756

762-
outputs = future.get()
757+
if solution.callbacks[0]:
758+
solution.callbacks[0](future)
759+
760+
outputs = result = future.get()
763761
if isinstance(outputs, bytes):
764762
import dill
765763

@@ -772,6 +770,7 @@ def elapsed_ms(op):
772770
"... (%s) op(%s) completed in %sms.", solution.solid, op.name, elapsed
773771
)
774772
except Exception as ex:
773+
result = ex
775774
is_endured = first_solid(
776775
solution.is_endurance, getattr(op, "endured", None)
777776
)
@@ -801,6 +800,9 @@ def elapsed_ms(op):
801800
# add it again, in case compile()/execute() is called separately.
802801
save_jetsam(ex, locals(), "solution", task="future", plan="self")
803802
raise
803+
finally:
804+
if isinstance(future, OpTask) and solution.callbacks[1]:
805+
solution.callbacks[1](future)
804806

805807
def _execute_thread_pool_barrier_method(self, solution: Solution):
806808
"""
@@ -904,7 +906,7 @@ def _execute_sequential_method(self, solution: Solution):
904906
if step in solution.canceled:
905907
continue
906908

907-
task = OpTask(step, solution, solution.solid, solution.callbacks)
909+
task = OpTask(step, solution, solution.solid)
908910
self._handle_task(task, step, solution)
909911

910912
elif isinstance(step, str):
@@ -927,7 +929,7 @@ def execute(
927929
outputs=None,
928930
*,
929931
name="",
930-
callbacks: Callable[[OpTask], None] = None,
932+
callbacks: Tuple[Callable[[OpTask], None], ...] = None,
931933
solution_class=None,
932934
layered_solution=None,
933935
) -> Solution:
@@ -943,9 +945,10 @@ def execute(
943945
:param name:
944946
name of the pipeline used for logging
945947
:param callbacks:
946-
If given, a 2-tuple with (optional) x2 :term:`callbacks` to call before & after
947-
each operation, with :class:`.OpTask` as argument containing the op & solution.
948-
Less or no elements accepted.
948+
If given, a 2-tuple with (optional) :term:`callbacks` to call
949+
before/after computing operation, with :class:`.OpTask` as argument
950+
containing the op & solution.
951+
Can be one (scalar), less than 2, or nothing/no elements accepted.
949952
:param solution_class:
950953
a custom solution factory to use
951954
:param layered_solution:

graphtik/pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,10 +412,10 @@ def compute(
412412
filter-out nodes before compiling
413413
If not given, those set by a previous call to :meth:`withset()` or cstor are used.
414414
:param callbacks:
415-
If given, a 4-tuple with (optional) x2 :term:`callbacks`,
416-
2 to call before & after each operation, and another 2 before/after batch,
417-
with :class:`.OpTask` as argument containing the op & solution.
418-
One (scalar), less than 4, or no elements accepted.
415+
If given, a 2-tuple with (optional) :term:`callbacks` to call
416+
before/after computing operation, with :class:`.OpTask` as argument
417+
containing the op & solution.
418+
Can be one (scalar), less than 2, or nothing/no elements accepted.
419419
:param solution_class:
420420
a custom solution factory to use
421421
:param layered_solution:

test/test_graphtik.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -945,26 +945,40 @@ def test_rescheduling_NO_RESULT(exemethod):
945945
assert sol.scream_if_incomplete()
946946

947947

948+
@pytest.mark.xfail(
949+
sys.version_info < (3, 8), reason="unittest.nock.call.args is trange in PY37-"
950+
)
948951
def test_pre_callback(quarantine_pipeline, exemethod):
949-
pipeline = compose("covid19", quarantine_pipeline, parallel=exemethod)
950-
called_ops = []
951-
952-
def op_called(op_cb):
953-
assert op_cb.sol["quarantine"]
954-
called_ops.append(op_cb.op.name)
952+
# Cannot import top-level `unittest.mock.call`, due to
953+
# https://bugs.python.org/issue35753
954+
from unittest.mock import MagicMock, call
955+
956+
pipeline = compose("covid19", quarantine_pipeline)
957+
958+
callbacks = [MagicMock() for _ in range(2)]
959+
sol = pipeline.compute({"quarantine": True}, callbacks=callbacks)
960+
961+
cbs_count = [cb.call_count for cb in callbacks]
962+
assert cbs_count == [2, 2]
963+
ops_called = [[call.args[0].op for call in cb.call_args_list] for cb in callbacks]
964+
assert ops_called == [
965+
["get_out_or_stay_home", "read_book"],
966+
["get_out_or_stay_home", "read_book"],
967+
]
968+
results_called = [
969+
[call.args[0].result for call in cb.call_args_list] for cb in callbacks
970+
]
971+
assert results_called == [
972+
[{"time": "1h"}, {"fun": "relaxed", "brain": "popular physics"}],
973+
[{"time": "1h"}, {"fun": "relaxed", "brain": "popular physics"}],
974+
]
955975

956-
sol = pipeline.compute({"quarantine": True}, callbacks=op_called)
957976
assert sol == {
958977
"quarantine": True,
959978
"time": "1h",
960979
"fun": "relaxed",
961980
"brain": "popular physics",
962981
}
963-
if exe_params.marshal == 1:
964-
# Marshaled `called_ops` is not itself :-)
965-
assert called_ops == []
966-
else:
967-
assert called_ops == ["get_out_or_stay_home", "read_book"]
968982

969983

970984
##########

0 commit comments

Comments
 (0)