1010
1111from pytensor import config
1212from pytensor .graph .basic import Apply , Constant
13- from pytensor .graph .fg import FunctionGraph
13+ from pytensor .graph .fg import FunctionGraph , Output
1414from pytensor .graph .type import Type
1515from pytensor .link .numba .cache import compile_numba_function_src , hash_from_pickle_dump
1616from pytensor .link .numba .dispatch .sparse import CSCMatrixType , CSRMatrixType
@@ -501,28 +501,44 @@ def numba_funcify_FunctionGraph(
501501 cache_keys = []
502502 toposort = fgraph .toposort ()
503503 clients = fgraph .clients
504- toposort_indices = {node : i for i , node in enumerate (toposort )}
505- # Add dummy output clients which are not included of the toposort
504+ toposort_indices : dict [Apply | None , int ] = {
505+ node : i for i , node in enumerate (toposort )
506+ }
507+ # Use -1 for root inputs / constants whose owner is None
508+ toposort_indices [None ] = - 1
509+ # Add dummy output nodes which are not included of the toposort
506510 toposort_indices |= {
507- clients [out ][0 ][0 ]: i
508- for i , out in enumerate (fgraph .outputs , start = len (toposort ))
511+ out_node : i + len (toposort )
512+ for i , out in enumerate (fgraph .outputs )
513+ for out_node , _ in clients [out ]
514+ if isinstance (out_node .op , Output ) and out_node .op .idx == i
509515 }
510516
511- def op_conversion_and_key_collection (* args , ** kwargs ):
517+ def op_conversion_and_key_collection (op , * args , node , ** kwargs ):
512518 # Convert an Op to a funcified function and store the cache_key
513519
514520 # We also Cache each Op so Numba can do less work next time it sees it
515- func , key = numba_funcify_ensure_cache (* args , ** kwargs )
516- cache_keys .append (key )
521+ func , key = numba_funcify_ensure_cache (op , node = node , * args , ** kwargs )
522+ if key is None :
523+ cache_keys .append (key )
524+ else :
525+ # Add graph coordinate information (input edges and node location)
526+ cache_keys .append (
527+ (
528+ toposort_indices [node ],
529+ tuple (toposort_indices [inp .owner ] for inp in node .inputs ),
530+ key ,
531+ )
532+ )
517533 return func
518534
519535 def type_conversion_and_key_collection (value , variable , ** kwargs ):
520536 # Convert a constant type to a numba compatible one and compute a cache key for it
521537
522- # We need to know where in the graph the constants are used
523- # Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
538+ # Add graph coordinate information (client edges)
524539 # FIXME: It doesn't make sense to call type_conversion on non-constants,
525- # but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
540+ # but that's what fgraph_to_python currently does.
541+ # We appease it, but don't consider for caching
526542 if isinstance (variable , Constant ):
527543 client_indices = tuple (
528544 (toposort_indices [node ], inp_idx ) for node , inp_idx in clients [variable ]
@@ -541,8 +557,24 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
541557 # If a single element couldn't be cached, we can't cache the whole FunctionGraph either
542558 fgraph_key = None
543559 else :
560+ # Add graph coordinate information for fgraph inputs (client edges) and fgraph outputs (input edges)
561+ # Constant edges are handled by `type_conversion_and_key_collection` called by `fgraph_to_python`
562+ fgraph_input_clients = tuple (
563+ tuple (
564+ (toposort_indices [node ], inp_idx )
565+ # Disconnect inputs don't have clients
566+ for node , inp_idx in clients .get (inp , ())
567+ )
568+ for inp in fgraph .inputs
569+ )
570+ fgraph_output_ancestors = tuple (
571+ tuple (toposort_indices [inp .owner ] for inp in out .owner .inputs )
572+ for out in fgraph .outputs
573+ if out .owner is not None # constant outputs
574+ )
575+
544576 # Compose individual cache_keys into a global key for the FunctionGraph
545577 fgraph_key = sha256 (
546- f"({ type (fgraph )} , { tuple (cache_keys )} , { len ( fgraph . inputs ) } , { len ( fgraph . outputs ) } )" .encode ()
578+ f"({ type (fgraph )} , { tuple (cache_keys )} , { fgraph_input_clients } , { fgraph_output_ancestors } )" .encode ()
547579 ).hexdigest ()
548580 return numba_njit (py_func ), fgraph_key
0 commit comments