Skip to content

Commit 378ec7e

Browse files
committed
REFACT(NET): publicise PRUNE() to solve dag & reqs ...
+ ENH: accept null inputs & outputs. + refact: simplify collect_requirements() not to accepts needs/provides. + refact: reorder empty-dag/impossible-out checks. + docs & typings.
1 parent 02adced commit 378ec7e

File tree

2 files changed

+108
-108
lines changed

2 files changed

+108
-108
lines changed

graphtik/network.py

Lines changed: 100 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
They are layed out and connected by repeated calls of
2323
:meth:`~Network.add_OP`.
2424
25-
The computation starts with :meth:`~Network._prune_graph()` extracting
25+
The computation starts with :meth:`~Network.prune()` extracting
2626
a *DAG subgraph* by *pruning* its nodes based on given inputs and
2727
requested outputs in :meth:`~Network.compute()`.
2828
2929
:attr:`ExecutionPlan.dag`
3030
An directed-acyclic-graph containing the *pruned* nodes as build by
31-
:meth:`~Network._prune_graph()`. This pruned subgraph is used to decide
31+
:meth:`~Network.prune()`. This pruned subgraph is used to decide
3232
the :attr:`ExecutionPlan.steps` (below).
3333
The containing :class:`ExecutionPlan.steps` instance is cached
3434
in :attr:`_cached_plans` across runs with inputs/outputs as key.
@@ -92,7 +92,7 @@
9292

9393

9494
class AbortedException(Exception):
95-
pass
95+
"""Raised from the Network code when :func:`abort_run()` is called."""
9696

9797

9898
def abort_run():
@@ -165,6 +165,10 @@ def __repr__(self):
165165
# overwrites = None
166166

167167

168+
def _yield_datanodes(graph):
169+
return (n for n in graph if isinstance(n, _DataNode))
170+
171+
168172
def _optionalized(graph, data):
169173
"""Retain optionality of a `data` node based on all `needs` edges."""
170174
all_optionals = all(e[2] for e in graph.out_edges(data, "optional", False))
@@ -180,48 +184,16 @@ def _optionalized(graph, data):
180184
)
181185

182186

183-
def collect_requirements(
184-
graph, inputs: Optional[Collection] = None, outputs: Optional[Collection] = None
185-
) -> Tuple[iset, iset]:
186-
"""
187-
Collect needs/provides from all `graph` ops, and validate them against inputs/outputs.
188-
189-
- If both `needs` & `provides` are `None`, collected needs/provides
190-
are returned as is (no validation).
191-
- Any collected `inputs` that are optional-needs for all ops,
192-
are returned as such.
193-
194-
:return:
195-
a 2-tuple with the optionalized (`needs`, `provides`), both ordered
196-
197-
:raises ValueError:
198-
If `outputs` asked cannot be produced by the `graph`, with msg:
199-
200-
*Impossible outputs...*
201-
"""
187+
def collect_requirements(graph) -> Tuple[iset, iset]:
188+
"""Collect and split datanodes all `graph` ops in needs/provides."""
202189
operations = [op for op in graph if isinstance(op, Operation)]
203-
all_provides = iset(p for op in operations for p in op.provides)
204-
all_needs = iset(n for op in operations for n in op.needs) - all_provides
205-
206-
if inputs is None:
207-
inputs = all_needs
208-
else:
209-
inputs = astuple(inputs, "inputs", allowed_types=abc.Collection)
210-
unknown = iset(inputs) - all_needs
211-
if unknown:
212-
log.warning("Unused inputs%s for %s!", list(unknown), graph.nodes)
213-
214-
if outputs is None:
215-
outputs = all_provides
216-
else:
217-
outputs = astuple(outputs, "provides", allowed_types=abc.Collection)
218-
unknown = iset(outputs) - all_provides - all_needs
219-
if unknown:
220-
raise ValueError(f"Impossible outputs{list(unknown)} for {graph.nodes}!")
221-
222-
inputs = [_optionalized(graph, n) for n in inputs]
223-
224-
return inputs, outputs
190+
provides = iset(p for op in operations for p in op.provides)
191+
needs = (
192+
iset(_optionalized(graph, n) for op in operations for n in op.needs) - provides
193+
)
194+
# TODO: Unify _DataNode + modifiers to avoid ugly hack `net.collect_requirements()`.
195+
provides = iset(str(n) if not isinstance(n, sideffect) else n for n in provides)
196+
return needs, provides
225197

226198

227199
class ExecutionPlan(
@@ -277,8 +249,8 @@ def _build_pydot(self, **kws):
277249
def __repr__(self):
278250
steps = ["\n +--%s" % s for s in self.steps]
279251
return "ExecutionPlan(needs=%s, provides=%s, steps:%s)" % (
280-
aslist(self.needs, "needs", allowed_types=(list, tuple)),
281-
aslist(self.provides, "provides", allowed_types=(list, tuple)),
252+
aslist(self.needs, "needs"),
253+
aslist(self.provides, "provides"),
282254
"".join(steps),
283255
)
284256

@@ -448,7 +420,7 @@ def execute(self, named_inputs, overwrites=None, method=None):
448420
missing = iset(self.needs) - iset(named_inputs)
449421
if missing:
450422
raise ValueError(
451-
f"Plan needs more inputs{list(missing)}!"
423+
f"Plan needs more inputs: {list(missing)}"
452424
f"\n given inputs: {list(named_inputs)}\n {self}"
453425
)
454426

@@ -489,6 +461,10 @@ class Network(Plotter):
489461
"""
490462
Assemble operations & data into a directed-acyclic-graph (DAG) to run them.
491463
464+
:ivar needs:
465+
the "base", all data-nodes that are not produced by some operation
466+
:ivar provides:
467+
the "base", all data-nodes produced by some operation
492468
"""
493469

494470
def __init__(self, *operations):
@@ -499,7 +475,7 @@ def __init__(self, *operations):
499475
dupes = list(operations)
500476
for i in uniques:
501477
dupes.remove(i)
502-
raise ValueError(f"Operations may only be added once, dupes: {iset(dupes)}")
478+
raise ValueError(f"Operations may only be added once, dupes: {list(dupes)}")
503479

504480
# directed graph of layer instances and data-names defining the net.
505481
graph = self.graph = DiGraph()
@@ -567,7 +543,7 @@ def _collect_unsatisfied_operations(self, dag, inputs: Collection):
567543
all its needs have been accounted, so we can get its satisfaction.
568544
569545
- Their provided outputs are not linked to any data in the dag.
570-
An operation might not have any output link when :meth:`_prune_graph()`
546+
An operation might not have any output link when :meth:`prune()`
571547
has broken them, due to given intermediate inputs.
572548
573549
:param dag:
@@ -612,12 +588,15 @@ def _collect_unsatisfied_operations(self, dag, inputs: Collection):
612588

613589
return unsatisfied
614590

615-
def _prune_graph(self, inputs: Collection, outputs: Collection):
591+
def prune(
592+
self, inputs: Optional[Collection], outputs: Optional[Collection]
593+
) -> Tuple[nx.DiGraph, Collection, Collection, Collection]:
616594
"""
617595
Determines what graph steps need to run to get to the requested
618-
outputs from the provided inputs. :
619-
- Eliminate steps that are not on a path arriving to requested outputs.
620-
- Eliminate unsatisfied operations: partial inputs or no outputs needed.
596+
outputs from the provided inputs:
597+
- Eliminate steps that are not on a path arriving to requested outputs;
598+
- Eliminate unsatisfied operations: partial inputs or no outputs needed;
599+
- consolidate the list of needs & provides.
621600
622601
:param inputs:
623602
The names of all given inputs.
@@ -628,38 +607,57 @@ def _prune_graph(self, inputs: Collection, outputs: Collection):
628607
from the provided inputs.
629608
630609
:return:
631-
the *pruned_dag*
610+
a 4-tuple with the *pruned_dag*, the out-edges of the inputs,
611+
and needs/provides resolved.
632612
633613
:raises ValueError:
634-
If `outputs` asked do not exist in network, with msg:
614+
- if `outputs` asked do not exist in network, with msg:
635615
636616
*Unknown output nodes: ...*
617+
618+
- if `outputs` asked cannot be produced by the `graph`, with msg:
619+
620+
*Impossible outputs...*
637621
"""
622+
# TODO: break cycles here.
638623
dag = self.graph
639624

640-
# Ignore input names that aren't in the graph.
641-
inputs_in_graph = set(dag.nodes) & set(inputs) # unordered, iterated, but ok
625+
if inputs is None and outputs is None:
626+
inputs, outputs = self.needs, self.provides
627+
else:
628+
if inputs is None: # means outputs are non-null ...
629+
# Consider "preliminary" `inputs` any non-output node.
630+
inputs = iset(_yield_datanodes(dag)) - outputs
631+
else:
632+
# Ignore `inputs` not in the graph.
633+
inputs = dag.nodes & inputs
642634

643-
# Scream if some requested outputs aren't in the graph.
644-
unknown_outputs = iset(outputs) - dag.nodes
645-
if unknown_outputs:
646-
raise ValueError(f"Unknown output nodes: {list(unknown_outputs)}")
635+
## Scream on unknown `outputs`.
636+
#
637+
if outputs:
638+
unknown_outputs = iset(outputs) - dag.nodes
639+
if unknown_outputs:
640+
raise ValueError(f"Unknown output nodes: {list(unknown_outputs)}")
641+
642+
assert inputs is not None and not isinstance(inputs, str)
643+
# but outputs may still be null.
647644

648645
broken_dag = dag.copy() # preserve net's graph
646+
broken_edges = set() # unordered, not iterated
649647

650648
# Break the incoming edges to all given inputs.
651649
#
652650
# Nodes producing any given intermediate inputs are unecessary
653651
# (unless they are also used elsewhere).
654652
# To discover which ones to prune, we break their incoming edges
655653
# and they will drop out while collecting ancestors from the outputs.
656-
broken_edges = set() # unordered, not iterated
657-
for given in inputs_in_graph:
654+
#
655+
for given in inputs:
658656
broken_edges.update(broken_dag.in_edges(given))
659657
broken_dag.remove_edges_from(broken_edges)
660658

661659
# Drop stray input values and operations (if any).
662-
if outputs:
660+
if outputs is not None:
663661
# If caller requested specific outputs, we can prune any
664662
# unrelated nodes further up the dag.
665663
ending_in_outputs = set()
@@ -669,20 +667,33 @@ def _prune_graph(self, inputs: Collection, outputs: Collection):
669667
broken_dag = broken_dag.subgraph(ending_in_outputs)
670668

671669
# Prune unsatisfied operations (those with partial inputs or no outputs).
672-
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs_in_graph)
673-
# Clone it so that it is picklable.
670+
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
671+
# Clone it, to modify it.
674672
pruned_dag = dag.subgraph(broken_dag.nodes - unsatisfied).copy()
675673

676674
pruned_dag.remove_nodes_from(list(nx.isolates(pruned_dag)))
677675

678-
assert all(
679-
isinstance(n, (Operation, _DataNode)) for n in pruned_dag
680-
), pruned_dag
676+
inputs = iset(_optionalized(pruned_dag, n) for n in inputs if n in pruned_dag)
677+
if outputs is None:
678+
outputs = iset(
679+
n for n in self.provides if n not in inputs and n in pruned_dag
680+
)
681+
else:
682+
unknown = iset(outputs) - pruned_dag
683+
if unknown:
684+
raise ValueError(
685+
f"Impossible outputs: {list(unknown)}\n for graph: {pruned_dag.nodes}"
686+
f"\n {self}"
687+
)
681688

682-
return pruned_dag, broken_edges
689+
assert all(_yield_datanodes(pruned_dag)), pruned_dag
690+
assert inputs is not None and not isinstance(inputs, str)
691+
assert outputs is not None and not isinstance(outputs, str)
692+
693+
return pruned_dag, broken_edges, tuple(inputs), tuple(outputs)
683694

684695
def _build_execution_steps(
685-
self, pruned_dag, inputs: Collection, outputs: Optional[Collection]
696+
self, pruned_dag, inputs: Optional[Collection], outputs: Optional[Collection]
686697
) -> List:
687698
"""
688699
Create the list of operation-nodes & *instructions* evaluating all
@@ -786,20 +797,20 @@ def add_step_once(step):
786797
return steps
787798

788799
def compile(
789-
self, inputs: Collection = (), outputs: Optional[Collection] = ()
800+
self, inputs: Optional[Collection] = None, outputs: Optional[Collection] = None
790801
) -> ExecutionPlan:
791802
"""
792803
Create or get from cache an execution-plan for the given inputs/outputs.
793804
794-
See :meth:`_prune_graph()` and :meth:`_build_execution_steps()`
805+
See :meth:`prune()` and :meth:`_build_execution_steps()`
795806
for detailed description.
796807
797808
:param inputs:
798-
An iterable with the names of all the given inputs.
809+
A collection with the names of all the given inputs.
810+
If `None``, all inputs that lead to given `outputs` are assumed.
799811
:param outputs:
800-
An iterable or the name of the output name(s).
801-
If missing, requested outputs assumed all graph reachable nodes
802-
from one of the given inputs.
812+
A collection or the name of the output name(s).
813+
If `None``, all reachable nodes from the given `inputs` are assumed.
803814
804815
:return:
805816
the cached or fresh new execution-plan
@@ -817,34 +828,31 @@ def compile(
817828
818829
*Unsolvable graph: ...*
819830
"""
820-
821-
# TODO: smarter null inputs, what if outputs are not None?
822-
if inputs is None:
823-
inputs = self.needs
824-
825831
## Make a stable cache-key.
826832
#
827-
inputs_list = astuple(inputs, "inputs", allowed_types=abc.Collection)
828-
outputs_list = astuple(outputs, "outputs", allowed_types=abc.Collection)
829-
cache_key = (
830-
None if inputs is None else tuple(sorted(inputs_list)),
831-
None if outputs is None else tuple(sorted(outputs_list)),
832-
)
833+
if inputs is not None:
834+
inputs = tuple(
835+
sorted(astuple(inputs, "inputs", allowed_types=abc.Collection))
836+
)
837+
if outputs is not None:
838+
outputs = tuple(
839+
sorted(astuple(outputs, "outputs", allowed_types=abc.Collection))
840+
)
841+
cache_key = (inputs, outputs)
833842

834843
## Build (or retrieve from cache) execution plan
835844
# for the given inputs & outputs.
836845
#
837846
if cache_key in self._cached_plans:
838847
plan = self._cached_plans[cache_key]
839848
else:
840-
pruned_dag, broken_edges = self._prune_graph(inputs_list, outputs_list)
849+
pruned_dag, broken_edges, needs, provides = self.prune(inputs, outputs)
841850
if not pruned_dag:
842851
raise ValueError(
843-
f"Unsolvable graph:\n needs{list(inputs_list)}\n provides{outputs_list}\n {self}"
852+
f"Unsolvable graph:\n needs: {inputs}\n provides: {outputs}\n {self}"
844853
)
845854

846-
steps = self._build_execution_steps(pruned_dag, inputs_list, outputs_list)
847-
needs, provides = collect_requirements(pruned_dag, inputs, outputs)
855+
steps = self._build_execution_steps(pruned_dag, needs, outputs or ())
848856
plan = ExecutionPlan(
849857
self,
850858
needs,

0 commit comments

Comments
 (0)