Skip to content

Commit f8bd89f

Browse files
committed
Fix numba FunctionGraph cache key
It's necessary to encode the edge information, not only the nodes and their ordering
1 parent 4e4f237 commit f8bd89f

File tree

2 files changed

+90
-15
lines changed

2 files changed

+90
-15
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from pytensor import config
1212
from pytensor.graph.basic import Apply, Constant
13-
from pytensor.graph.fg import FunctionGraph
13+
from pytensor.graph.fg import FunctionGraph, Output
1414
from pytensor.graph.type import Type
1515
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
1616
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
@@ -502,27 +502,41 @@ def numba_funcify_FunctionGraph(
502502
toposort = fgraph.toposort()
503503
clients = fgraph.clients
504504
toposort_indices = {node: i for i, node in enumerate(toposort)}
505-
# Add dummy output clients which are not included of the toposort
505+
# Use -1 for root inputs / constants whose owner is None
506+
toposort_indices[None] = -1
507+
# Add dummy output nodes, which are not included in the toposort
506508
toposort_indices |= {
507-
clients[out][0][0]: i
508-
for i, out in enumerate(fgraph.outputs, start=len(toposort))
509+
out_node: i + len(toposort)
510+
for i, out in enumerate(fgraph.outputs)
511+
for out_node, _ in clients[out]
512+
if isinstance(out_node.op, Output) and out_node.op.idx == i
509513
}
510514

511-
def op_conversion_and_key_collection(*args, **kwargs):
515+
def op_conversion_and_key_collection(op, *args, node, **kwargs):
512516
# Convert an Op to a funcified function and store the cache_key
513517

514518
# We also Cache each Op so Numba can do less work next time it sees it
515-
func, key = numba_funcify_ensure_cache(*args, **kwargs)
516-
cache_keys.append(key)
519+
func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs)
520+
if key is None:
521+
cache_keys.append(key)
522+
else:
523+
# Add graph coordinate information (input edges and node location)
524+
cache_keys.append(
525+
(
526+
toposort_indices[node],
527+
tuple(toposort_indices[inp.owner] for inp in node.inputs),
528+
key,
529+
)
530+
)
517531
return func
518532

519533
def type_conversion_and_key_collection(value, variable, **kwargs):
520534
# Convert a constant type to a numba compatible one and compute a cache key for it
521535

522-
# We need to know where in the graph the constants are used
523-
# Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
536+
# Add graph coordinate information (client edges)
524537
# FIXME: It doesn't make sense to call type_conversion on non-constants,
525-
# but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
538+
# but that's what fgraph_to_python currently does.
539+
# We appease it, but don't consider for caching
526540
if isinstance(variable, Constant):
527541
client_indices = tuple(
528542
(toposort_indices[node], inp_idx) for node, inp_idx in clients[variable]
@@ -541,8 +555,20 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
541555
# If a single element couldn't be cached, we can't cache the whole FunctionGraph either
542556
fgraph_key = None
543557
else:
558+
# Add graph coordinate information for fgraph inputs (client edges) and fgraph outputs (input edges)
559+
# Constant edges are handled by `type_conversion_and_key_collection` called by `fgraph_to_python`
560+
fgraph_input_clients = tuple(
561+
tuple((toposort_indices[node], inp_idx) for node, inp_idx in clients[inp])
562+
for inp in fgraph.inputs
563+
)
564+
fgraph_output_ancestors = tuple(
565+
tuple(toposort_indices[inp.owner] for inp in out.owner.inputs)
566+
for out in fgraph.outputs
567+
if out.owner is not None # constant outputs
568+
)
569+
544570
# Compose individual cache_keys into a global key for the FunctionGraph
545571
fgraph_key = sha256(
546-
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode()
572+
f"({type(fgraph)}, {tuple(cache_keys)}, {fgraph_input_clients}, {fgraph_output_ancestors})".encode()
547573
).hexdigest()
548574
return numba_njit(py_func), fgraph_key

tests/link/numba/test_basic.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,29 @@
77
import pytest
88
import scipy
99

10-
from pytensor.compile import SymbolicInput
11-
from pytensor.tensor.utils import hash_from_ndarray
12-
1310

1411
numba = pytest.importorskip("numba")
1512

1613
import pytensor.scalar as ps
1714
import pytensor.tensor as pt
1815
from pytensor import config, shared
16+
from pytensor.compile import SymbolicInput
1917
from pytensor.compile.function import function
2018
from pytensor.compile.mode import Mode
2119
from pytensor.graph.basic import Apply, Variable
20+
from pytensor.graph.fg import FunctionGraph
2221
from pytensor.graph.op import Op
2322
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
2423
from pytensor.graph.type import Type
2524
from pytensor.link.numba.dispatch import basic as numba_basic
26-
from pytensor.link.numba.dispatch.basic import cache_key_for_constant
25+
from pytensor.link.numba.dispatch.basic import (
26+
cache_key_for_constant,
27+
numba_funcify_and_cache_key,
28+
)
2729
from pytensor.link.numba.linker import NumbaLinker
2830
from pytensor.scalar.basic import ScalarOp, as_scalar
2931
from pytensor.tensor.elemwise import Elemwise
32+
from pytensor.tensor.utils import hash_from_ndarray
3033

3134

3235
if TYPE_CHECKING:
@@ -652,3 +655,49 @@ def impl(x):
652655
outs[2].owner.op, outs[2].owner
653656
)
654657
assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
658+
659+
660+
def test_fgraph_cache_key():
661+
x = pt.scalar("x")
662+
log_x = pt.log(x)
663+
graphs = [
664+
pt.exp(x) / log_x,
665+
log_x / pt.exp(x),
666+
pt.exp(log_x) / x,
667+
x / pt.exp(log_x),
668+
pt.exp(log_x) / log_x,
669+
log_x / pt.exp(log_x),
670+
]
671+
672+
def generate_and_validate_key(fg):
673+
_, key = numba_funcify_and_cache_key(fg)
674+
assert key is not None
675+
_, key_again = numba_funcify_and_cache_key(fg)
676+
assert key == key_again # Check it's stable
677+
return key
678+
679+
keys = []
680+
for graph in graphs:
681+
fg = FunctionGraph([x], [graph], clone=False)
682+
keys.append(generate_and_validate_key(fg))
683+
# Check keys are unique
684+
assert len(set(keys)) == len(graphs)
685+
686+
# Extra unused input should alter the key, because it changes the function signature
687+
y = pt.scalar("y")
688+
for inputs in [[x, y], [y, x]]:
689+
fg = FunctionGraph(inputs, [graphs[0]], clone=False)
690+
keys.append(generate_and_validate_key(fg))
691+
assert len(set(keys)) == len(graphs) + 2
692+
693+
# Adding an input as an output should also change the key
694+
for outputs in [
695+
[graphs[0], x],
696+
[x, graphs[0]],
697+
[x, x, graphs[0]],
698+
[x, graphs[0], x],
699+
[graphs[0], x, x],
700+
]:
701+
fg = FunctionGraph([x], outputs, clone=False)
702+
keys.append(generate_and_validate_key(fg))
703+
assert len(set(keys)) == len(graphs) + 2 + 5

0 commit comments

Comments
 (0)