 
 from pytensor import config
 from pytensor.graph.basic import Apply, Constant
-from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.type import Type
 from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
@@ -502,27 +502,41 @@ def numba_funcify_FunctionGraph(
     toposort = fgraph.toposort()
     clients = fgraph.clients
     toposort_indices = {node: i for i, node in enumerate(toposort)}
-    # Add dummy output clients which are not included of the toposort
+    # Use -1 for root inputs / constants whose owner is None
+    toposort_indices[None] = -1
+    # Add dummy output nodes which are not included in the toposort
     toposort_indices |= {
-        clients[out][0][0]: i
-        for i, out in enumerate(fgraph.outputs, start=len(toposort))
+        out_node: i + len(toposort)
+        for i, out in enumerate(fgraph.outputs)
+        for out_node, _ in clients[out]
+        if isinstance(out_node.op, Output) and out_node.op.idx == i
     }
 
-    def op_conversion_and_key_collection(*args, **kwargs):
+    def op_conversion_and_key_collection(op, *args, node, **kwargs):
         # Convert an Op to a funcified function and store the cache_key
 
         # We also Cache each Op so Numba can do less work next time it sees it
-        func, key = numba_funcify_ensure_cache(*args, **kwargs)
-        cache_keys.append(key)
+        func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs)
+        if key is None:
+            cache_keys.append(key)
+        else:
+            # Add graph coordinate information (input edges and node location)
+            cache_keys.append(
+                (
+                    toposort_indices[node],
+                    tuple(toposort_indices[inp.owner] for inp in node.inputs),
+                    key,
+                )
+            )
         return func
 
     def type_conversion_and_key_collection(value, variable, **kwargs):
         # Convert a constant type to a numba compatible one and compute a cache key for it
 
-        # We need to know where in the graph the constants are used
-        # Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
+        # Add graph coordinate information (client edges)
         # FIXME: It doesn't make sense to call type_conversion on non-constants,
-        # but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
+        # but that's what fgraph_to_python currently does.
+        # We appease it, but don't consider for caching
         if isinstance(variable, Constant):
             client_indices = tuple(
                 (toposort_indices[node], inp_idx) for node, inp_idx in clients[variable]
@@ -541,8 +555,20 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
         # If a single element couldn't be cached, we can't cache the whole FunctionGraph either
         fgraph_key = None
     else:
+        # Add graph coordinate information for fgraph inputs (client edges) and fgraph outputs (input edges)
+        # Constant edges are handled by `type_conversion_and_key_collection` called by `fgraph_to_python`
+        fgraph_input_clients = tuple(
+            tuple((toposort_indices[node], inp_idx) for node, inp_idx in clients[inp])
+            for inp in fgraph.inputs
+        )
+        fgraph_output_ancestors = tuple(
+            tuple(toposort_indices[inp.owner] for inp in out.owner.inputs)
+            for out in fgraph.outputs
+            if out.owner is not None  # constant outputs
+        )
+
         # Compose individual cache_keys into a global key for the FunctionGraph
         fgraph_key = sha256(
-            f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode()
+            f"({type(fgraph)}, {tuple(cache_keys)}, {fgraph_input_clients}, {fgraph_output_ancestors})".encode()
         ).hexdigest()
     return numba_njit(py_func), fgraph_key
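Why the graph coordinates matter: two graphs built from the same Ops and constants, but wired differently, should not reuse each other's composed cache key. Below is a minimal standalone sketch of the idea, not the pytensor implementation; `constant_key` and `compose_fgraph_key` are illustrative names, and the tuples mimic the (toposort index, input-edge indices, Op key) coordinates collected in the diff above.

from hashlib import sha256

# Hypothetical sketch: tag each per-Op key with its toposort position and the
# positions of the nodes feeding its inputs, and tag each constant's key with
# the (node index, input index) edges that consume it.  With these coordinates,
# stack(x, 5.0, 7.0) and stack(5.0, x, 7.0) hash differently, because 5.0 sits
# on a different input edge in the two graphs.

def constant_key(value, client_edges):
    # client_edges: tuple of (node_index, input_index) pairs that use the constant
    return (repr(value), client_edges)

def compose_fgraph_key(op_keys, constant_keys):
    # One uncacheable Op (key is None) makes the whole graph uncacheable
    if any(key is None for key in op_keys):
        return None
    return sha256(str((tuple(op_keys), tuple(constant_keys))).encode()).hexdigest()

# stack(x, 5.0, 7.0): node 0 is the stack; 5.0 feeds input 1, 7.0 feeds input 2
key_a = compose_fgraph_key(
    op_keys=[(0, (-1, -1, -1), "stack")],  # (node position, input origins, Op key)
    constant_keys=[constant_key(5.0, ((0, 1),)), constant_key(7.0, ((0, 2),))],
)
# stack(5.0, x, 7.0): same Op key, but 5.0 now feeds input 0
key_b = compose_fgraph_key(
    op_keys=[(0, (-1, -1, -1), "stack")],
    constant_keys=[constant_key(5.0, ((0, 0),)), constant_key(7.0, ((0, 2),))],
)
assert key_a != key_b  # different wiring -> different FunctionGraph cache keys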