1717 Mapping ,
1818 NamedTuple ,
1919 Optional ,
20+ Sequence ,
2021 Tuple ,
2122 Union ,
2223)
@@ -239,18 +240,23 @@ def _topo_sort_nodes(dag) -> iset:
239240 raise nx .NetworkXUnfeasible (msg ).with_traceback (tb )
240241
241242
242- def recompute_inputs (
243- graph , inputs , recompute_from , recompute_till
243+ def inputs_for_recompute (
244+ graph ,
245+ inputs : Sequence [str ],
246+ recompute_from : Sequence [str ],
247+ recompute_till : Sequence [str ] = None ,
244248) -> Tuple [iset , iset ]:
245249 """
246250 Clears the inputs between `recompute_from >--<= recompute_till`.
247251
252+ :param graph:
253+ MODIFIED, at most 2 helper nodes inserted
248254 :param inputs:
249- None or a sequence
255+ a sequence
250256 :param recompute_from:
251- None or a sequence
257+ None or a sequence, including any out-of-graph deps (ignored and logged)
252258 :param recompute_till:
253- (UNSTABLE) None or a sequence
259+ (optional) a sequence, only in-graph deps.
254260
255261 :return:
256262 a 2-tuple with the reduced `inputs` by the dependencies that must
@@ -260,51 +266,40 @@ def recompute_inputs(
260266
261267 strict-descendants(recompute_from) & ancestors(recompute_till)
262268
263- FIXME: Should recompute() while travesing unsatisfied? Is `till` relevant??
269+ FIXME: merge recompute() with traversing unsatisfied (see ``test_recompute_NEEDS_FIX``)
270+ because it clears inputs of unsatisfied ops (cannot be replaced later)
264271 """
265- start , stop = "_TMP.RECOMPUTE_FROM" , "_TMP.RECOMPUTE_TILL"
266- graph = graph .copy ()
267-
268- datanodes = iset (yield_datanodes (graph .nodes ))
269- downstreams_strict = datanodes
270- if recompute_from is not None :
271- recompute_from = iset (recompute_from ) # traversed in logs
272- bad = recompute_from - datanodes
273- if bad :
274- log .info (
275- "... ignoring unknown `recompute_from` dependencies: %s" , list (bad )
276- )
277- recompute_from = recompute_from & datanodes
278- if recompute_from :
279- graph .add_edges_from ((start , i ) for i in recompute_from )
272+ START , STOP = "_TMP.RECOMPUTE_FROM" , "_TMP.RECOMPUTE_TILL"
280273
281- downstreams_strict = (
282- iset (yield_datanodes (nx .descendants (graph , start ))) - recompute_from
283- )
274+ deps = set (yield_datanodes (graph .nodes ))
275+ recompute_from = iset (recompute_from ) # traversed in logs
276+ inputs = iset (inputs ) # returned
277+ bad = recompute_from - deps
278+ if bad :
279+ log .info ("... ignoring unknown `recompute_from` dependencies: %s" , list (bad ))
280+ recompute_from = recompute_from & deps # avoid side-effect on `recompute_from`
281+ assert recompute_from , f"Given unknown-only `recompute_from` { locals ()} "
284282
285- if recompute_till is None :
286- upstreams = datanodes
287- else :
288- recompute_till = iset (recompute_till ) # traversed in logs
289- bad = recompute_till - datanodes
290- if bad :
291- log .info (
292- "... ignoring unknown `recompute_till` dependencies: %s" , list (bad )
293- )
294- recompute_till = recompute_till & datanodes
295- graph .add_edges_from ((i , stop ) for i in recompute_till ) # edge reversed!
283+ graph .add_edges_from ((START , i ) for i in recompute_from )
284+
285+ # strictly-downstreams from START
286+ between_deps = iset (nx .descendants (graph , START )) & deps - recompute_from
287+
288+ if recompute_till :
289+ graph .add_edges_from ((i , STOP ) for i in recompute_till ) # edge reversed!
296290
297- upstreams = iset (yield_datanodes (nx .ancestors (graph , stop )))
291+ # upstreams from STOP
292+ upstreams = set (nx .ancestors (graph , STOP )) & deps
293+ between_deps &= upstreams
298294
299- between = downstreams_strict & upstreams
300- recomputes = [i for i in inputs if i in between ]
301- new_inputs = iset (inputs ) - between
295+ recomputes = between_deps & inputs
296+ new_inputs = iset (inputs ) - recomputes
302297
303298 if log .isEnabledFor (logging .DEBUG ):
304299 log .debug (
305300 "... recompute x%i data%s means deleting x%i inputs%s, to arrive from x%i %s -> x%i %s." ,
306- len (between ),
307- list (between ),
301+ len (between_deps ),
302+ list (between_deps ),
308303 len (recomputes ),
309304 list (recomputes ),
310305 len (inputs ),
@@ -866,7 +861,6 @@ def compile(
866861 inputs : Items = None ,
867862 outputs : Items = None ,
868863 recompute_from = None ,
869- recompute_till = None ,
870864 * ,
871865 predicate = None ,
872866 ) -> "ExecutionPlan" :
@@ -886,8 +880,6 @@ def compile(
886880 If string, it is converted to a single-element collection.
887881 :param recompute_from:
888882 Described in :meth:`.Pipeline.compute()`.
889- :param recompute_till:
890- (UNSTABLE) Described in :meth:`.Pipeline.compute()`.
891883 :param predicate:
892884 the :term:`node predicate` is a 2-argument callable(op, node-data)
893885 that should return true for nodes to include; if None, all nodes included.
@@ -909,15 +901,15 @@ def compile(
909901
910902 ok = False
911903 try :
912- ## Make a stable cache-key.
904+ ## Make a stable cache-key,
905+ # ignoring out-of-graph nodes (2nd results).
913906 #
914907 inputs , k1 = self ._deps_tuplized (inputs , "inputs" )
915908 outputs , k2 = self ._deps_tuplized (outputs , "outputs" )
916909 recompute_from , k3 = self ._deps_tuplized (recompute_from , "recompute_from" )
917- recompute_till , k4 = self ._deps_tuplized (recompute_till , "recompute_till" )
918910 if not predicate :
919911 predicate = None
920- cache_key = (k1 , k2 , k3 , k4 , predicate , is_skip_evictions ())
912+ cache_key = (k1 , k2 , k3 , predicate , is_skip_evictions ())
921913
922914 ## Build (or retrieve from cache) execution plan
923915 # for the given dep-lists (excluding any unknown node-names).
@@ -926,9 +918,9 @@ def compile(
926918 log .debug ("... compile cache-hit key: %s" , cache_key )
927919 plan = self ._cached_plans [cache_key ]
928920 else :
929- if recompute_from is not None or recompute_till is not None :
930- inputs , recomputes = recompute_inputs (
931- self .graph , inputs , recompute_from , recompute_till
921+ if recompute_from :
922+ inputs , recomputes = inputs_for_recompute (
923+ self .graph . copy () , inputs , recompute_from , k2
932924 )
933925
934926 _prune_results = self ._prune_graph (inputs , outputs , predicate )
0 commit comments