
Commit 6add5af

Fix numba FunctionGraph cache key
It's necessary to encode the edge information, not only the nodes and their ordering
1 parent: c1b2011, commit: 6add5af
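
Why edge information matters: below is a minimal sketch of the kind of collision this commit rules out, built from the same expressions as the new test further down (the variable names and the claim about the old key are illustrative). Both graphs are made of the same Ops, so a cache key derived only from the per-node keys and their toposort ordering cannot tell them apart; only the edges feeding the division differ.

    import pytensor.tensor as pt
    from pytensor.graph.fg import FunctionGraph

    x = pt.scalar("x")
    log_x = pt.log(x)
    # Same node set (Log, Exp, TrueDiv), different wiring of the division:
    fg_a = FunctionGraph([x], [pt.exp(x) / log_x], clone=False)
    fg_b = FunctionGraph([x], [log_x / pt.exp(x)], clone=False)
    # A key built only from the nodes and their ordering could give fg_a and fg_b
    # the same hash; encoding each node's input edges makes the keys distinct.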


2 files changed (+97 / -16 lines)


pytensor/link/numba/dispatch/basic.py

Lines changed: 44 additions & 12 deletions
@@ -10,7 +10,7 @@
 
 from pytensor import config
 from pytensor.graph.basic import Apply, Constant
-from pytensor.graph.fg import FunctionGraph
+from pytensor.graph.fg import FunctionGraph, Output
 from pytensor.graph.type import Type
 from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
 from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
@@ -501,28 +501,44 @@ def numba_funcify_FunctionGraph(
     cache_keys = []
     toposort = fgraph.toposort()
     clients = fgraph.clients
-    toposort_indices = {node: i for i, node in enumerate(toposort)}
-    # Add dummy output clients which are not included of the toposort
+    toposort_indices: dict[Apply | None, int] = {
+        node: i for i, node in enumerate(toposort)
+    }
+    # Use -1 for root inputs / constants whose owner is None
+    toposort_indices[None] = -1
+    # Add dummy output nodes which are not included of the toposort
     toposort_indices |= {
-        clients[out][0][0]: i
-        for i, out in enumerate(fgraph.outputs, start=len(toposort))
+        out_node: i + len(toposort)
+        for i, out in enumerate(fgraph.outputs)
+        for out_node, _ in clients[out]
+        if isinstance(out_node.op, Output) and out_node.op.idx == i
     }
 
-    def op_conversion_and_key_collection(*args, **kwargs):
+    def op_conversion_and_key_collection(op, *args, node, **kwargs):
         # Convert an Op to a funcified function and store the cache_key
 
         # We also Cache each Op so Numba can do less work next time it sees it
-        func, key = numba_funcify_ensure_cache(*args, **kwargs)
-        cache_keys.append(key)
+        func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs)
+        if key is None:
+            cache_keys.append(key)
+        else:
+            # Add graph coordinate information (input edges and node location)
+            cache_keys.append(
+                (
+                    toposort_indices[node],
+                    tuple(toposort_indices[inp.owner] for inp in node.inputs),
+                    key,
+                )
+            )
         return func
 
     def type_conversion_and_key_collection(value, variable, **kwargs):
         # Convert a constant type to a numba compatible one and compute a cache key for it
 
-        # We need to know where in the graph the constants are used
-        # Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
+        # Add graph coordinate information (client edges)
         # FIXME: It doesn't make sense to call type_conversion on non-constants,
-        # but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
+        # but that's what fgraph_to_python currently does.
+        # We appease it, but don't consider for caching
         if isinstance(variable, Constant):
             client_indices = tuple(
                 (toposort_indices[node], inp_idx) for node, inp_idx in clients[variable]
@@ -541,8 +557,24 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
         # If a single element couldn't be cached, we can't cache the whole FunctionGraph either
         fgraph_key = None
     else:
+        # Add graph coordinate information for fgraph inputs (client edges) and fgraph outputs (input edges)
+        # Constant edges are handled by `type_conversion_and_key_collection` called by `fgraph_to_python`
+        fgraph_input_clients = tuple(
+            tuple(
+                (toposort_indices[node], inp_idx)
+                # Disconnect inputs don't have clients
+                for node, inp_idx in clients.get(inp, ())
+            )
+            for inp in fgraph.inputs
+        )
+        fgraph_output_ancestors = tuple(
+            tuple(toposort_indices[inp.owner] for inp in out.owner.inputs)
+            for out in fgraph.outputs
+            if out.owner is not None  # constant outputs
+        )
+
         # Compose individual cache_keys into a global key for the FunctionGraph
         fgraph_key = sha256(
-            f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode()
+            f"({type(fgraph)}, {tuple(cache_keys)}, {fgraph_input_clients}, {fgraph_output_ancestors})".encode()
         ).hexdigest()
     return numba_njit(py_func), fgraph_key
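
To make the "graph coordinate" entries concrete: a rough sketch of what cache_keys could hold for exp(x) / log(x), assuming the toposort [Exp, Log, TrueDiv] (indices 0, 1, 2); exp_key, log_key and div_key are placeholders for the per-Op keys, and -1 marks a root input whose owner is None.

    exp_key, log_key, div_key = "key(Exp)", "key(Log)", "key(TrueDiv)"  # placeholder per-Op keys
    cache_keys = [
        (0, (-1,), exp_key),   # Exp at toposort index 0, fed by the root input x (owner None -> -1)
        (1, (-1,), log_key),   # Log, likewise fed directly by x
        (2, (0, 1), div_key),  # TrueDiv, fed by the outputs of Exp (index 0) and Log (index 1)
    ]
    # Swapping the operands of the division yields (2, (1, 0), div_key) instead, so the
    # composed FunctionGraph key changes even though the node set is identical.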

tests/link/numba/test_basic.py

Lines changed: 53 additions & 4 deletions
@@ -7,26 +7,29 @@
 import pytest
 import scipy
 
-from pytensor.compile import SymbolicInput
-from pytensor.tensor.utils import hash_from_ndarray
-
 
 numba = pytest.importorskip("numba")
 
 import pytensor.scalar as ps
 import pytensor.tensor as pt
 from pytensor import config, shared
+from pytensor.compile import SymbolicInput
 from pytensor.compile.function import function
 from pytensor.compile.mode import Mode
 from pytensor.graph.basic import Apply, Variable
+from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.graph.rewriting.db import RewriteDatabaseQuery
 from pytensor.graph.type import Type
 from pytensor.link.numba.dispatch import basic as numba_basic
-from pytensor.link.numba.dispatch.basic import cache_key_for_constant
+from pytensor.link.numba.dispatch.basic import (
+    cache_key_for_constant,
+    numba_funcify_and_cache_key,
+)
 from pytensor.link.numba.linker import NumbaLinker
 from pytensor.scalar.basic import ScalarOp, as_scalar
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.utils import hash_from_ndarray
 
 
 if TYPE_CHECKING:
@@ -652,3 +655,49 @@ def impl(x):
         outs[2].owner.op, outs[2].owner
     )
     assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
+
+
+def test_fgraph_cache_key():
+    x = pt.scalar("x")
+    log_x = pt.log(x)
+    graphs = [
+        pt.exp(x) / log_x,
+        log_x / pt.exp(x),
+        pt.exp(log_x) / x,
+        x / pt.exp(log_x),
+        pt.exp(log_x) / log_x,
+        log_x / pt.exp(log_x),
+    ]
+
+    def generate_and_validate_key(fg):
+        _, key = numba_funcify_and_cache_key(fg)
+        assert key is not None
+        _, key_again = numba_funcify_and_cache_key(fg)
+        assert key == key_again  # Check its stable
+        return key
+
+    keys = []
+    for graph in graphs:
+        fg = FunctionGraph([x], [graph], clone=False)
+        keys.append(generate_and_validate_key(fg))
+    # Check keys are unique
+    assert len(set(keys)) == len(graphs)
+
+    # Extra unused input should alter the key, because it changes the function signature
+    y = pt.scalar("y")
+    for inputs in [[x, y], [y, x]]:
+        fg = FunctionGraph(inputs, [graphs[0]], clone=False)
+        keys.append(generate_and_validate_key(fg))
+    assert len(set(keys)) == len(graphs) + 2
+
+    # Adding an input as an output should also change the key
+    for outputs in [
+        [graphs[0], x],
+        [x, graphs[0]],
+        [x, x, graphs[0]],
+        [x, graphs[0], x],
+        [graphs[0], x, x],
+    ]:
+        fg = FunctionGraph([x], outputs, clone=False)
+        keys.append(generate_and_validate_key(fg))
+    assert len(set(keys)) == len(graphs) + 2 + 5
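
A matching sketch for the fgraph-level coordinates hashed into the final key (illustrative values; the actual client and toposort order may differ), again for FunctionGraph([x], [pt.exp(x) / pt.log(x)]) with toposort [Exp, Log, TrueDiv] and the dummy Output node appended after it:

    # x feeds input 0 of Exp (toposort index 0) and input 0 of Log (toposort index 1)
    fgraph_input_clients = (((0, 0), (1, 0)),)
    # The single non-constant output comes from TrueDiv, whose inputs are owned by Exp and Log
    fgraph_output_ancestors = ((0, 1),)
    # Listing x as an extra output gives x an additional Output client at index
    # len(toposort) + output_position, which is why the output permutations in the
    # test above all produce distinct keys.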

0 commit comments
