Skip to content

Commit a8a61aa

Browse files
committed
FEAT(net,.TC): ABORT flag
1 parent 311f025 commit a8a61aa

File tree

3 files changed

+68
-8
lines changed

3 files changed

+68
-8
lines changed

graphtik/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,6 @@
1010
__summary__ = __doc__.splitlines()[0]
1111
__uri__ = "https://github.com/pygraphkit/graphtik"
1212

13-
from .nodes import operation, compose
1413
from .modifiers import * # noqa, on purpose to include any new modifiers
14+
from .network import abort_run, AbortedException
15+
from .nodes import compose, operation

graphtik/network.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,27 @@
8080

8181
log = logging.getLogger(__name__)
8282

83-
thread_pool: ContextVar[Pool] = ContextVar("thread_pool", default=Pool(7))
83+
84+
execution_configs: ContextVar[dict] = ContextVar(
85+
"execution_configs",
86+
default={"thread_pool": Pool(7), "abort": False},
87+
)
88+
89+
90+
class AbortedException(Exception):
91+
pass
92+
93+
94+
def abort_run():
95+
execution_configs.get()["abort"] = True
96+
97+
98+
def _reset_abort():
99+
execution_configs.get()["abort"] = False
100+
101+
102+
def is_abort():
103+
return execution_configs.get()["abort"]
84104

85105

86106
class _DataNode(str):
@@ -195,6 +215,12 @@ def _pin_data_in_solution(self, value_name, solution, inputs, overwrites):
195215
overwrites[value_name] = solution[value_name]
196216
solution[value_name] = inputs[value_name]
197217

218+
def _check_if_aborted(self, executed):
219+
if is_abort():
220+
# Restore `abort` flag for next run.
221+
_reset_abort()
222+
raise AbortedException({s: s in executed for s in self.steps})
223+
198224
def _call_operation(self, op, solution):
199225
# Although `plan` have added to jetsam in `compute()``,
200226
# add it again, in case compile()/execute is called separately.
@@ -217,12 +243,13 @@ def _execute_thread_pool_barrier_method(self, solution, overwrites, executed):
217243
n: solution[n] for n in self.steps if isinstance(n, _PinInstruction)
218244
}
219245

220-
pool = thread_pool.get()
246+
pool = execution_configs.get()["thread_pool"]
221247

222248
# with each loop iteration, we determine a set of operations that can be
223249
# scheduled, then schedule them onto a thread pool, then collect their
224250
# results onto a memory solution for use upon the next iteration.
225251
while True:
252+
self._check_if_aborted(executed)
226253

227254
# the upnext list contains a list of operations for scheduling
228255
# in the current round of scheduling
@@ -293,6 +320,7 @@ def _execute_sequential_method(self, solution, overwrites, executed):
293320

294321
self.times = {}
295322
for step in self.steps:
323+
self._check_if_aborted(executed)
296324

297325
if isinstance(step, Operation):
298326

test/test_graphtik.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,20 @@
1010
import pytest
1111

1212
import graphtik.network as network
13-
from graphtik import compose, operation, optional, sideffect
13+
from graphtik import (
14+
abort_run,
15+
AbortedException,
16+
compose,
17+
operation,
18+
optional,
19+
sideffect,
20+
)
1421
from graphtik.base import Operation
15-
from graphtik.network import _EvictInstruction
22+
23+
24+
@pytest.fixture(params=["sequential", "parallel"])
25+
def exemethod(request):
26+
return request.param
1627

1728

1829
def scream(*args, **kwargs):
@@ -21,8 +32,8 @@ def scream(*args, **kwargs):
2132
)
2233

2334

24-
def identity(x):
25-
return x
35+
def identity(*x):
36+
return x[0] if len(x) == 1 else x
2637

2738

2839
def filtdict(d, *keys):
@@ -856,7 +867,7 @@ def addplusplus(a, b, c=0):
856867
def test_evict_instructions_vary_with_inputs():
857868
# Check #21: _EvictInstructions positions vary when inputs change.
858869
def count_evictions(steps):
859-
return sum(isinstance(n, _EvictInstruction) for n in steps)
870+
return sum(isinstance(n, network._EvictInstruction) for n in steps)
860871

861872
pipeline = compose(name="pipeline")(
862873
operation(name="a free without b", needs=["a"], provides=["aa"])(identity),
@@ -1048,3 +1059,23 @@ def test_compose_another_network(bools):
10481059

10491060
sol = bigger_graph({"a": 2, "b": 5, "c": 5}, outputs=["a_minus_ab_minus_c"])
10501061
assert sol == {"a_minus_ab_minus_c": -13}
1062+
1063+
1064+
def test_abort(exemethod):
1065+
pipeline = compose(name="pipeline")(
1066+
operation(name="A", needs=["a"], provides=["b"])(identity),
1067+
operation(name="B", needs=["b"], provides=["c"])(lambda x: abort_run()),
1068+
operation(name="C", needs=["c"], provides=["d"])(identity),
1069+
)
1070+
pipeline.set_execution_method(exemethod)
1071+
with pytest.raises(AbortedException) as exinfo:
1072+
pipeline({"a": 1})
1073+
assert exinfo.value.jetsam["solution"] == {"a": 1, "b": 1, "c": None}
1074+
executed = {op.name: val for op, val in exinfo.value.args[0].items()}
1075+
assert executed == {"A": True, "B": True, "C": False}
1076+
1077+
pipeline = compose(name="pipeline")(
1078+
operation(name="A", needs=["a"], provides=["b"])(identity)
1079+
)
1080+
pipeline.set_execution_method(exemethod)
1081+
assert pipeline({"a": 1}) == {"a": 1, "b": 1}

0 commit comments

Comments
 (0)