Skip to content

Commit 3c5523f

Browse files
committed
FEAT(#1,net,netop): PRUNE by node-PROPS
e.g. assign "colors" to nodes, and solve a subset each time.
1 parent 9dc00e7 commit 3c5523f

File tree

3 files changed

+104
-43
lines changed

3 files changed

+104
-43
lines changed

graphtik/netop.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
import re
77
from collections import abc
8-
from typing import Collection
8+
from typing import Any, Callable, Collection, Mapping
99

1010
import networkx as nx
1111
from boltons.setutils import IndexedSet as iset
@@ -44,6 +44,7 @@ def __init__(
4444
*,
4545
inputs=None,
4646
outputs=None,
47+
predicate: Callable[[Any, Mapping], bool] = None,
4748
method=None,
4849
overwrites_collector=None,
4950
):
@@ -52,6 +53,8 @@ def __init__(
5253
see :meth:`narrow()`
5354
:param outputs:
5455
see :meth:`narrow()`
56+
:param predicate:
57+
a 2-argument callable(op, node-data) that should return true for nodes to include
5558
:param method:
5659
either `parallel` or None (default);
5760
if ``"parallel"``, launches multi-threading.
@@ -64,19 +67,19 @@ def __init__(
6467
:raises ValueError:
6568
see :meth:`narrow()`
6669
"""
70+
## Set data asap, for debugging, although `pruned()` will reset them.
6771
self.name = name
6872
self.inputs = inputs
6973
self.provides = outputs
70-
# Prune network
71-
self.net = net.pruned(inputs, outputs)
72-
## Set data asap, for debugging, although `prune()` will reset them.
7374
self.set_execution_method(method)
7475
self.set_overwrites_collector(overwrites_collector)
7576

7677
# TODO: Is it really necessary to store IO on netop?
7778
self.inputs = inputs
7879
self.outputs = outputs
7980

81+
# Prune network
82+
self.net = net.pruned(inputs, outputs, predicate)
8083
self.name, self.needs, self.provides = reparse_operation_data(
8184
self.name, self.net.needs, self.net.provides
8285
)
@@ -94,7 +97,11 @@ def __repr__(self):
9497
)
9598

9699
def narrow(
97-
self, inputs: Collection = None, outputs: Collection = None, name=None
100+
self,
101+
inputs: Collection = None,
102+
outputs: Collection = None,
103+
name=None,
104+
predicate: Callable[[Any, Mapping], bool] = None,
98105
) -> "NetworkOperation":
99106
"""
100107
Return a copy with a network pruned for the given `needs` & `provides`.
@@ -118,6 +125,8 @@ def narrow(
118125
<old-name>-<uid>
119126
120127
- otherwise, the given `name` is applied.
128+
:param predicate:
129+
a 2-argument callable(op, node-data) that should return true for nodes to include
121130
122131
:return:
123132
A narrowed netop clone, which **MIGHT be empty!**
@@ -146,6 +155,7 @@ def narrow(
146155
name,
147156
inputs=inputs,
148157
outputs=outputs,
158+
predicate=predicate,
149159
method=self.method,
150160
overwrites_collector=self.overwrites_collector,
151161
)
@@ -316,7 +326,7 @@ def proc_op(op, parent=None):
316326
## Convey any node-props specified in the netop here
317327
# to all sub-operations.
318328
#
319-
if node_props or parent:
329+
if node_props or (not merge and parent):
320330
kw = {}
321331
if node_props:
322332
op_node_props = op.node_props.copy()

graphtik/network.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,17 @@
7272
from collections import abc, defaultdict, namedtuple
7373
from contextvars import ContextVar
7474
from multiprocessing.dummy import Pool
75-
from typing import Collection, Iterable, List, Optional, Tuple, Union
75+
from typing import (
76+
Any,
77+
Callable,
78+
Collection,
79+
Iterable,
80+
List,
81+
Mapping,
82+
Optional,
83+
Tuple,
84+
Union,
85+
)
7686

7787
import networkx as nx
7888
from boltons.setutils import IndexedSet as iset
@@ -650,8 +660,23 @@ def _unsatisfied_operations(self, dag, inputs: Collection):
650660

651661
return unsatisfied
652662

663+
def _apply_graph_predicate(self, graph, predicate):
664+
to_del = []
665+
for node, data in graph.nodes.items():
666+
try:
667+
if isinstance(node, Operation) and not predicate(node, data):
668+
to_del.append(node)
669+
except Exception as ex:
670+
raise ValueError(
671+
f"Node-predicate({predicate}) failed due to: {ex}\n node: {node}, {self}"
672+
) from ex
673+
graph.remove_nodes_from(to_del)
674+
653675
def _prune_graph(
654-
self, inputs: Optional[Collection], outputs: Optional[Collection]
676+
self,
677+
inputs: Optional[Collection],
678+
outputs: Optional[Collection],
679+
predicate: Callable[[Any, Mapping], bool] = None,
655680
) -> Tuple[nx.DiGraph, Collection, Collection, Collection]:
656681
"""
657682
Determines what graph steps need to run to get to the requested
@@ -667,6 +692,8 @@ def _prune_graph(
667692
The desired output names. This can also be ``None``, in which
668693
case the necessary steps are all graph nodes that are reachable
669694
from the provided inputs.
695+
:param predicate:
696+
a 2-argument callable(op, node-data) that should return true for nodes to include
670697
671698
:return:
672699
a 4-tuple with the *pruned_dag*, the out-edges of the inputs,
@@ -719,6 +746,9 @@ def _prune_graph(
719746
broken_dag = dag.copy() # preserve net's graph
720747
broken_edges = set() # unordered, not iterated
721748

749+
if predicate:
750+
self._apply_graph_predicate(broken_dag, predicate)
751+
722752
# Break the incoming edges to all given inputs.
723753
#
724754
# Nodes producing any given intermediate inputs are unnecessary
@@ -762,7 +792,10 @@ def _prune_graph(
762792
return pruned_dag, broken_edges, tuple(inputs), tuple(outputs)
763793

764794
def pruned(
765-
self, inputs: Collection = None, outputs: Collection = None
795+
self,
796+
inputs: Collection = None,
797+
outputs: Collection = None,
798+
predicate: Callable[[Any, Mapping], bool] = None,
766799
) -> "Network":
767800
"""
768801
Return a pruned network supporting just the given `inputs` & `outputs`.
@@ -771,19 +804,23 @@ def pruned(
771804
all possible inputs names
772805
:param outputs:
773806
all possible output names
807+
:param predicate:
808+
a 2-argument callable(op, node-data) that should return true for nodes to include
774809
775810
:return:
776811
the pruned clone, or this, if `inputs`, `outputs` & `predicate` were all `None`
777812
"""
778-
if inputs is None and outputs is None:
813+
if (inputs, outputs, predicate) == (None, None, None):
779814
return self
780815

781816
if inputs is not None:
782817
inputs = astuple(inputs, "outputs", allowed_types=(list, tuple))
783818
if outputs is not None:
784819
outputs = astuple(outputs, "outputs", allowed_types=(list, tuple))
785820

786-
pruned_dag, _br_edges, _needs, _provides = self._prune_graph(inputs, outputs)
821+
pruned_dag, _br_edges, _needs, _provides = self._prune_graph(
822+
inputs, outputs, predicate
823+
)
787824
return Network(graph=pruned_dag)
788825

789826
def _build_execution_steps(

test/test_graphtik.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def filtdict(d, *keys):
5050
return type(d)(i for i in d.items() if i[0] in keys)
5151

5252

53-
def addall(*a):
53+
def addall(*a, **kw):
5454
"Same as a + b + ...."
55-
return sum(a)
55+
return sum(a) + sum(kw.values())
5656

5757

5858
def abspow(a, p):
@@ -321,13 +321,48 @@ def test_network_merge_in_doctests():
321321
assert merged_graph.provides
322322

323323
assert (
324-
repr(merged_graph) ==
325-
"NetworkOperation('merged_graph', "
324+
repr(merged_graph) == "NetworkOperation('merged_graph', "
326325
"needs=['a', 'b', 'ab', 'a_minus_ab', 'c'], "
327326
"provides=['ab', 'a_minus_ab', 'abs_a_minus_ab_cubed', 'cab'], x4ops)"
328327
)
329328

330329

330+
@pytest.fixture
331+
def samplenet():
332+
# Set up a network such that we don't need to provide a, b or d if we only
333+
# request sum3 as output and if we provide sum2.
334+
sum_op1 = operation(name="sum_op1", needs=["a", "b"], provides="sum1")(add)
335+
sum_op2 = operation(name="sum_op2", needs=["c", "d"], provides="sum2")(add)
336+
sum_op3 = operation(name="sum_op3", needs=["c", "sum2"], provides="sum3")(add)
337+
return compose("test_net", sum_op1, sum_op2, sum_op3)
338+
339+
340+
def test_node_props_based_prune():
341+
netop = compose(
342+
"N",
343+
operation(name="A", needs=["a"], provides=["aa"], node_props={"color": "red"})(
344+
identity
345+
),
346+
operation(
347+
name="B", needs=["b"], provides=["bb"], node_props={"color": "green"}
348+
)(identity),
349+
operation(name="C", needs=["c"], provides=["cc"])(identity),
350+
operation(
351+
name="SUM",
352+
needs=[optional(i) for i in ("aa", "bb", "cc")],
353+
provides=["sum"],
354+
)(addall),
355+
)
356+
inp = {"a": 1, "b": 2, "c": 3}
357+
# assert netop(**inp)["sum"] == 6
358+
359+
pred = lambda n, d: d.get("color", None) != "red"
360+
assert netop.narrow(predicate=pred)(**inp)["sum"] == 5
361+
362+
pred = lambda n, d: "color" not in d
363+
assert netop.narrow(predicate=pred)(**inp)["sum"] == 3
364+
365+
331366
def test_input_based_pruning():
332367
# Tests to make sure we don't need to pass graph inputs if we're provided
333368
# with data further downstream in the graph as an input.
@@ -349,72 +384,51 @@ def test_input_based_pruning():
349384
assert results["sum3"] == add(sum1, sum2)
350385

351386

352-
def test_output_based_pruning():
387+
def test_output_based_pruning(samplenet):
353388
# Tests to make sure we don't need to pass graph inputs if they're not
354389
# needed to compute the requested outputs.
355390

356391
c = 2
357392
d = 3
358393

359-
# Set up a network such that we don't need to provide a or b if we only
360-
# request sum3 as output.
361-
sum_op1 = operation(name="sum_op1", needs=["a", "b"], provides="sum1")(add)
362-
sum_op2 = operation(name="sum_op2", needs=["c", "d"], provides="sum2")(add)
363-
sum_op3 = operation(name="sum_op3", needs=["c", "sum2"], provides="sum3")(add)
364-
net = compose("test_net", sum_op1, sum_op2, sum_op3)
365-
366-
results = net.compute({"a": 0, "b": 0, "c": c, "d": d}, ["sum3"])
394+
results = samplenet.compute({"a": 0, "b": 0, "c": c, "d": d}, ["sum3"])
367395

368396
# Make sure we got expected result without having to pass a or b.
369397
assert "sum3" in results
370398
assert results["sum3"] == add(c, add(c, d))
371399

372400

373-
def test_deps_pruning_vs_narrowing():
401+
def test_deps_pruning_vs_narrowing(samplenet):
374402
# Tests to make sure we don't need to pass graph inputs if they're not
375403
# needed to compute the requested outputs or if we're provided with
376404
# inputs that are further downstream in the graph.
377405

378406
c = 2
379407
sum2 = 5
380408

381-
# Set up a network such that we don't need to provide a or b d if we only
382-
# request sum3 as output and if we provide sum2.
383-
sum_op1 = operation(name="sum_op1", needs=["a", "b"], provides="sum1")(add)
384-
sum_op2 = operation(name="sum_op2", needs=["c", "d"], provides="sum2")(add)
385-
sum_op3 = operation(name="sum_op3", needs=["c", "sum2"], provides="sum3")(add)
386-
net = compose("test_net", sum_op1, sum_op2, sum_op3)
387-
388-
results = net.compute({"c": c, "sum2": sum2}, ["sum3"])
409+
results = samplenet.compute({"c": c, "sum2": sum2}, ["sum3"])
389410

390411
# Make sure we got expected result without having to pass a, b, or d.
391412
assert "sum3" in results
392413
assert results["sum3"] == add(c, sum2)
393414

394415
# Compare with both `narrow()`.
395-
net = net.narrow(inputs=["c", "sum2"], outputs=["sum3"])
416+
net = samplenet.narrow(inputs=["c", "sum2"], outputs=["sum3"])
396417
results = net(c=c, sum2=sum2)
397418

398419
# Make sure we got expected result without having to pass a, b, or d.
399420
assert "sum3" in results
400421
assert results["sum3"] == add(c, sum2)
401422

402423

403-
def test_pruning_raises_for_bad_output():
424+
def test_pruning_raises_for_bad_output(samplenet):
404425
# Make sure we get a ValueError during the pruning step if we request an
405426
# output that doesn't exist.
406427

407-
# Set up a network that doesn't have the output sum4, which we'll request
408-
# later.
409-
sum_op1 = operation(name="sum_op1", needs=["a", "b"], provides="sum1")(add)
410-
sum_op2 = operation(name="sum_op2", needs=["c", "d"], provides="sum2")(add)
411-
sum_op3 = operation(name="sum_op3", needs=["c", "sum2"], provides="sum3")(add)
412-
net = compose("test_net", sum_op1, sum_op2, sum_op3)
413-
414428
# Request two outputs we can compute and one we can't compute. Assert
415429
# that this raises a ValueError.
416430
with pytest.raises(ValueError) as exinfo:
417-
net.compute({"a": 1, "b": 2, "c": 3, "d": 4}, ["sum1", "sum3", "sum4"])
431+
samplenet.compute({"a": 1, "b": 2, "c": 3, "d": 4}, ["sum1", "sum3", "sum4"])
418432
assert exinfo.match("sum4")
419433

420434

0 commit comments

Comments
 (0)