@@ -346,16 +346,19 @@ def find_producer(self, tensor_name):
346346 return x
347347 return None
348348
349- def find_upstream (self , tensor_name , finder_fxn ):
349+ def find_upstream (self , tensor_name , finder_fxn , keep_if_not_found = False ):
350350 """Follow the producer chain upstream, calling finder_fxn on each upstream
351351 node until it returns True or there are no nodes left. Returns the list
352- of nodes visited, or None if finder_fxn did not return True."""
352+ of nodes visited, or None if finder_fxn did not return True. If
353+ keep_if_not_found is specified, returns the list of nodes visited, even
354+ if finder_fxn never returned True, i.e., if the search terminated at an
355+ input or initializer."""
353356 visit_list = []
354357 current_tensor = tensor_name
355358 while True :
356359 current_producer = self .find_producer (current_tensor )
357360 if current_producer is None :
358- return []
361+ return visit_list if keep_if_not_found else []
359362 else :
360363 found = finder_fxn (current_producer )
361364 visit_list .append (current_producer )
@@ -364,7 +367,7 @@ def find_upstream(self, tensor_name, finder_fxn):
364367 elif len (current_producer .input ) > 0 :
365368 current_tensor = current_producer .input [0 ]
366369 else :
367- return None
370+ return visit_list if keep_if_not_found else None
368371
369372 def find_consumer (self , tensor_name ):
370373 """Finds and returns the node that consumes the tensor with given name.
0 commit comments