diff --git a/pytensor/link/numba/cache.py b/pytensor/link/numba/cache.py index a93ad09cd2..09d3c05830 100644 --- a/pytensor/link/numba/cache.py +++ b/pytensor/link/numba/cache.py @@ -1,6 +1,5 @@ from collections.abc import Callable from hashlib import sha256 -from pathlib import Path from pickle import dump from tempfile import NamedTemporaryFile from typing import Any @@ -64,8 +63,8 @@ def get_disambiguator(self): @classmethod def from_function(cls, py_func, py_file): """Create a locator instance for functions stored in CACHED_SRC_FUNCTIONS.""" - if config.numba__cache and py_func in CACHED_SRC_FUNCTIONS: - return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func]) + if py_func in CACHED_SRC_FUNCTIONS and config.numba__cache: + return cls(py_func, py_file, CACHED_SRC_FUNCTIONS[py_func]) # Register our locator at the front of Numba's locator list diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 35f0af32d5..4ba359f782 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -5,12 +5,13 @@ import numba import numpy as np +from numba import NumbaPerformanceWarning, NumbaWarning from numba import njit as _njit from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from pytensor import config from pytensor.graph.basic import Apply, Constant -from pytensor.graph.fg import FunctionGraph +from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.type import Type from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType @@ -23,6 +24,35 @@ from pytensor.tensor.utils import hash_from_ndarray +def _filter_numba_warnings(): + # Suppress large global arrays cache warning for internal functions + # We have to add an ansi escape code for optional bold text by numba + # TODO: We could avoid inlining large constants and pass them at runtime + warnings.filterwarnings( + "ignore", + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + 'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals' + ), + category=NumbaWarning, + ) + + # Disable loud / incorrect warnings from Numba + # https://github.com/numba/numba/issues/10086 + # TODO: Would be much better if we could disable only for our functions + warnings.filterwarnings( + "ignore", + message=( + "(\x1b\\[1m)*" # ansi escape code for bold text + r"np\.dot\(\) is faster on contiguous arrays" + ), + category=NumbaPerformanceWarning, + ) + + +_filter_numba_warnings() + + def numba_njit( *args, fastmath=None, final_function: bool = False, **kwargs ) -> Callable: @@ -501,28 +531,44 @@ def numba_funcify_FunctionGraph( cache_keys = [] toposort = fgraph.toposort() clients = fgraph.clients - toposort_indices = {node: i for i, node in enumerate(toposort)} - # Add dummy output clients which are not included of the toposort + toposort_indices: dict[Apply | None, int] = { + node: i for i, node in enumerate(toposort) + } + # Use -1 for root inputs / constants whose owner is None + toposort_indices[None] = -1 + # Add dummy output nodes which are not included of the toposort toposort_indices |= { - clients[out][0][0]: i - for i, out in enumerate(fgraph.outputs, start=len(toposort)) + out_node: i + len(toposort) + for i, out in enumerate(fgraph.outputs) + for out_node, _ in clients[out] + if isinstance(out_node.op, Output) and out_node.op.idx == i } - def op_conversion_and_key_collection(*args, **kwargs): + 
def op_conversion_and_key_collection(op, *args, node, **kwargs): # Convert an Op to a funcified function and store the cache_key # We also Cache each Op so Numba can do less work next time it sees it - func, key = numba_funcify_ensure_cache(*args, **kwargs) - cache_keys.append(key) + func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs) + if key is None: + cache_keys.append(key) + else: + # Add graph coordinate information (input edges and node location) + cache_keys.append( + ( + toposort_indices[node], + tuple(toposort_indices[inp.owner] for inp in node.inputs), + key, + ) + ) return func def type_conversion_and_key_collection(value, variable, **kwargs): # Convert a constant type to a numba compatible one and compute a cache key for it - # We need to know where in the graph the constants are used - # Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same + # Add graph coordinate information (client edges) # FIXME: It doesn't make sense to call type_conversion on non-constants, - # but that's what fgraph_to_python currently does. We appease it, but don't consider for caching + # but that's what fgraph_to_python currently does. + # We appease it, but don't consider for caching if isinstance(variable, Constant): client_indices = tuple( (toposort_indices[node], inp_idx) for node, inp_idx in clients[variable] @@ -541,8 +587,24 @@ def type_conversion_and_key_collection(value, variable, **kwargs): # If a single element couldn't be cached, we can't cache the whole FunctionGraph either fgraph_key = None else: + # Add graph coordinate information for fgraph inputs (client edges) and fgraph outputs (input edges) + # Constant edges are handled by `type_conversion_and_key_collection` called by `fgraph_to_python` + fgraph_input_clients = tuple( + tuple( + (toposort_indices[node], inp_idx) + # Disconnect inputs don't have clients + for node, inp_idx in clients.get(inp, ()) + ) + for inp in fgraph.inputs + ) + fgraph_output_ancestors = tuple( + tuple(toposort_indices[inp.owner] for inp in out.owner.inputs) + for out in fgraph.outputs + if out.owner is not None # constant outputs + ) + # Compose individual cache_keys into a global key for the FunctionGraph fgraph_key = sha256( - f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode() + f"({type(fgraph)}, {tuple(cache_keys)}, {fgraph_input_clients}, {fgraph_output_ancestors})".encode() ).hexdigest() return numba_njit(py_func), fgraph_key diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 4a4d9b319d..50af695a2e 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -14,7 +14,6 @@ from pytensor.link.numba.dispatch.cython_support import wrap_cython_function from pytensor.link.utils import ( get_name_for_object, - unique_name_generator, ) from pytensor.scalar.basic import ( Add, @@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs): scalar_func_numba = generate_fallback_impl(op, node, **kwargs) scalar_op_fn_name = get_name_for_object(scalar_func_numba) - + prefix = "x" if scalar_func_name != "x" else "y" + input_names = [f"{prefix}{i}" for i in range(len(node.inputs))] + input_signature = ", ".join(input_names) global_env = {"scalar_func_numba": scalar_func_numba} if input_inner_dtypes is None and output_inner_dtype is None: - unique_names = unique_name_generator( - [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_" - ) - input_names = ", ".join(unique_names(v, 
force_unique=True) for v in node.inputs) if not has_pyx_skip_dispatch: scalar_op_src = f""" -def {scalar_op_fn_name}({input_names}): - return scalar_func_numba({input_names}) +def {scalar_op_fn_name}({input_signature}): + return scalar_func_numba({input_signature}) """ else: scalar_op_src = f""" -def {scalar_op_fn_name}({input_names}): - return scalar_func_numba({input_names}, np.intc(1)) +def {scalar_op_fn_name}({input_signature}): + return scalar_func_numba({input_signature}, np.intc(1)) """ else: @@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}): for i, i_dtype in enumerate(input_inner_dtypes) } global_env.update(input_tmp_dtype_names) - - unique_names = unique_name_generator( - [scalar_op_fn_name, "scalar_func_numba", *global_env.keys()], - suffix_sep="_", - ) - - input_names = [unique_names(v, force_unique=True) for v in node.inputs] converted_call_args = ", ".join( f"direct_cast({i_name}, {i_tmp_dtype_name})" for i_name, i_tmp_dtype_name in zip( @@ -123,19 +113,19 @@ def {scalar_op_fn_name}({input_names}): ) if not has_pyx_skip_dispatch: scalar_op_src = f""" -def {scalar_op_fn_name}({", ".join(input_names)}): +def {scalar_op_fn_name}({input_signature}): return direct_cast(scalar_func_numba({converted_call_args}), output_dtype) """ else: scalar_op_src = f""" -def {scalar_op_fn_name}({", ".join(input_names)}): +def {scalar_op_fn_name}({input_signature}): return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype) """ scalar_op_fn = compile_numba_function_src( scalar_op_src, scalar_op_fn_name, - {**globals(), **global_env}, + globals() | global_env, ) # Functions that call a function pointer can't be cached @@ -157,8 +147,8 @@ def switch(condition, x, y): def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str): """Create a Numba-compatible N-ary function from a binary function.""" - unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_") - input_names = [unique_names(v, force_unique=True) for v in inputs] + var_prefix = "x" if binary_op_name != "x" else "y" + input_names = [f"{var_prefix}{i}" for i in range(len(inputs))] input_signature = ", ".join(input_names) output_expr = binary_op.join(input_names) diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 1e8078e477..c7cc4cfd8e 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -18,7 +18,6 @@ register_funcify_and_cache_key, register_funcify_default_op_cache_key, ) -from pytensor.link.utils import unique_name_generator from pytensor.tensor import TensorType from pytensor.tensor.rewriting.subtensor import is_full_slice from pytensor.tensor.subtensor import ( @@ -143,19 +142,14 @@ def subtensor_op_cache_key(op, **extra_fields): def numba_funcify_default_subtensor(op, node, **kwargs): """Create a Python function that assembles and uses an index on an array.""" - unique_names = unique_name_generator( - ["subtensor", "incsubtensor", "z"], suffix_sep="_" - ) - - def convert_indices(indices, entry): - if indices and isinstance(entry, Type): - rval = indices.pop(0) - return unique_names(rval) + def convert_indices(indice_names, entry): + if indice_names and isinstance(entry, Type): + return next(indice_names) elif isinstance(entry, slice): return ( - f"slice({convert_indices(indices, entry.start)}, " - f"{convert_indices(indices, entry.stop)}, " - f"{convert_indices(indices, entry.step)})" + f"slice({convert_indices(indice_names, entry.start)}, " + 
f"{convert_indices(indice_names, entry.stop)}, " + f"{convert_indices(indice_names, entry.step)})" ) elif isinstance(entry, type(None)): return "None" @@ -166,13 +160,15 @@ def convert_indices(indices, entry): op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor ) index_start_idx = 1 + int(set_or_inc) - - input_names = [unique_names(v, force_unique=True) for v in node.inputs] op_indices = list(node.inputs[index_start_idx:]) idx_list = getattr(op, "idx_list", None) + idx_names = [f"idx_{i}" for i in range(len(op_indices))] + input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names] + + idx_names_iterator = iter(idx_names) indices_creation_src = ( - tuple(convert_indices(op_indices, idx) for idx in idx_list) + tuple(convert_indices(idx_names_iterator, idx) for idx in idx_list) if idx_list else tuple(input_names[index_start_idx:]) ) @@ -220,7 +216,9 @@ def {function_name}({", ".join(input_names)}): function_name=function_name, global_env=globals() | {"np": np}, ) - cache_key = subtensor_op_cache_key(op, func="numba_funcify_default_subtensor") + cache_key = subtensor_op_cache_key( + op, func="numba_funcify_default_subtensor", version=1 + ) return numba_basic.numba_njit(func, boundscheck=True), cache_key @@ -409,6 +407,8 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs): op, func="multiple_integer_vector_indexing", y_is_broadcasted=y_is_broadcasted, + first_axis=first_axis, + last_axis=last_axis, ) return ret_func, cache_key diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 9553273f99..3c7bda9a15 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -10,7 +10,6 @@ register_funcify_and_cache_key, register_funcify_default_op_cache_key, ) -from pytensor.link.utils import unique_name_generator from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -28,15 +27,7 @@ @register_funcify_default_op_cache_key(AllocEmpty) def numba_funcify_AllocEmpty(op, node, **kwargs): - global_env = { - "np": np, - "dtype": np.dtype(op.dtype), - } - - unique_names = unique_name_generator( - ["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" - ) - shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs] + shape_var_names = [f"sh{i}" for i in range(len(node.inputs))] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( @@ -56,7 +47,7 @@ def allocempty({", ".join(shape_var_names)}): """ alloc_fn = compile_numba_function_src( - alloc_def_src, "allocempty", {**globals(), **global_env} + alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)} ) return numba_basic.numba_njit(alloc_fn) @@ -64,13 +55,7 @@ def allocempty({", ".join(shape_var_names)}): @register_funcify_and_cache_key(Alloc) def numba_funcify_Alloc(op, node, **kwargs): - global_env = {"np": np} - - unique_names = unique_name_generator( - ["np", "alloc", "val_np", "val", "scalar_shape", "res"], - suffix_sep="_", - ) - shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]] + shape_var_names = [f"sh{i}" for i in range(len(node.inputs) - 1)] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( @@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}): alloc_fn = compile_numba_function_src( alloc_def_src, "alloc", - {**globals(), **global_env}, + globals() | {"np": np}, ) cache_key = sha256( @@ -207,14 +192,7 @@ def 
eye(N, M, k): @register_funcify_default_op_cache_key(MakeVector) def numba_funcify_MakeVector(op, node, **kwargs): dtype = np.dtype(op.dtype) - - global_env = {"np": np, "dtype": dtype} - - unique_names = unique_name_generator( - ["np"], - suffix_sep="_", - ) - input_names = [unique_names(v, force_unique=True) for v in node.inputs] + input_names = [f"x{i}" for i in range(len(node.inputs))] def create_list_string(x): args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else [])) @@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}): makevector_fn = compile_numba_function_src( makevector_def_src, "makevector", - {**globals(), **global_env}, + globals() | {"np": np, "dtype": dtype}, ) return numba_basic.numba_njit(makevector_fn) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 43e4885840..1314a145b3 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -7,26 +7,30 @@ import pytest import scipy -from pytensor.compile import SymbolicInput -from pytensor.tensor.utils import hash_from_ndarray - numba = pytest.importorskip("numba") import pytensor.scalar as ps import pytensor.tensor as pt from pytensor import config, shared +from pytensor.compile import SymbolicInput from pytensor.compile.function import function from pytensor.compile.mode import Mode from pytensor.graph.basic import Apply, Variable +from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.type import Type from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import cache_key_for_constant +from pytensor.link.numba.dispatch.basic import ( + _filter_numba_warnings, + cache_key_for_constant, + numba_funcify_and_cache_key, +) from pytensor.link.numba.linker import NumbaLinker from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.tensor.elemwise import Elemwise +from pytensor.tensor.utils import hash_from_ndarray if TYPE_CHECKING: @@ -450,14 +454,46 @@ def test_scalar_return_value_conversion(): assert isinstance(x_fn(1.0), np.ndarray) -@pytest.mark.filterwarnings("error") -def test_cache_warning_suppressed(): - x = pt.vector("x", shape=(5,), dtype="float64") - out = pt.psi(x) * 2 - fn = function([x], out, mode="NUMBA") +class TestNumbaWarnings: + def setup_method(self, method): + # Pytest messes up with the package filters, reenable here for testing + _filter_numba_warnings() + + @pytest.mark.filterwarnings("error") + def test_cache_pointer_func_warning_suppressed(self): + x = pt.vector("x", shape=(5,), dtype="float64") + out = pt.psi(x) * 2 + fn = function([x], out, mode="NUMBA") + + x_test = np.random.uniform(size=5) + np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) - x_test = np.random.uniform(size=5) - np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) + @pytest.mark.filterwarnings("error") + def test_cache_large_global_array_warning_suppressed(self): + rng = np.random.default_rng(458) + large_constant = rng.normal(size=(100000, 5)) + + x = pt.vector("x", shape=(5,), dtype="float64") + out = x * large_constant + fn = function([x], out, mode="NUMBA") + + x_test = rng.uniform(size=5) + np.testing.assert_allclose(fn(x_test), x_test * large_constant) + + @pytest.mark.filterwarnings("error") + def test_contiguous_array_dot_warning_suppressed(self): + A = pt.matrix("A") + b = pt.vector("b") + out = pt.dot(A, b[:, None]) + # Cached functions won't reemit 
the warning, so we have to disable it
+        with config.change_flags(numba__cache=False):
+            fn = function([A, b], out, mode="NUMBA")
+
+        A_test = np.ones((5, 5))
+        # Numba actually warns even on contiguous arrays: https://github.com/numba/numba/issues/10086
+        # But either way we don't want this warning for users as they have little control over strides
+        b_test = np.ones((10,))[::2]
+        np.testing.assert_allclose(fn(A_test, b_test), np.dot(A_test, b_test[:, None]))
 
 
 @pytest.mark.parametrize("mode", ("default", "trust_input", "direct"))
@@ -652,3 +688,49 @@ def impl(x):
         outs[2].owner.op, outs[2].owner
     )
     assert numba.njit(lambda x: fn2_def_cached(x))(test_x) == 2
+
+
+def test_fgraph_cache_key():
+    x = pt.scalar("x")
+    log_x = pt.log(x)
+    graphs = [
+        pt.exp(x) / log_x,
+        log_x / pt.exp(x),
+        pt.exp(log_x) / x,
+        x / pt.exp(log_x),
+        pt.exp(log_x) / log_x,
+        log_x / pt.exp(log_x),
+    ]
+
+    def generate_and_validate_key(fg):
+        _, key = numba_funcify_and_cache_key(fg)
+        assert key is not None
+        _, key_again = numba_funcify_and_cache_key(fg)
+        assert key == key_again  # Check it's stable
+        return key
+
+    keys = []
+    for graph in graphs:
+        fg = FunctionGraph([x], [graph], clone=False)
+        keys.append(generate_and_validate_key(fg))
+    # Check keys are unique
+    assert len(set(keys)) == len(graphs)
+
+    # Extra unused input should alter the key, because it changes the function signature
+    y = pt.scalar("y")
+    for inputs in [[x, y], [y, x]]:
+        fg = FunctionGraph(inputs, [graphs[0]], clone=False)
+        keys.append(generate_and_validate_key(fg))
+    assert len(set(keys)) == len(graphs) + 2
+
+    # Adding an input as an output should also change the key
+    for outputs in [
+        [graphs[0], x],
+        [x, graphs[0]],
+        [x, x, graphs[0]],
+        [x, graphs[0], x],
+        [graphs[0], x, x],
+    ]:
+        fg = FunctionGraph([x], outputs, clone=False)
+        keys.append(generate_and_validate_key(fg))
+    assert len(set(keys)) == len(graphs) + 2 + 5
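
Note on the caching scheme exercised by the patch above (illustration only, not part of the diff): the FunctionGraph-level key is built by tagging every per-Op cache key with its "graph coordinates" (the node's toposort index plus the toposort indices of the producers of its inputs, with -1 standing in for root inputs), and then hashing those tuples together with the input-client and output-ancestor edges. The sketch below is a minimal, self-contained rendition of that idea; `combine_fgraph_key` and the hand-written coordinate tuples are hypothetical and do not exist in PyTensor, whose real logic lives in `numba_funcify_FunctionGraph`.

```python
from hashlib import sha256


def combine_fgraph_key(node_keys, input_clients, output_ancestors):
    """Compose per-node cache keys into one graph-level key (illustrative only).

    node_keys        -- one entry per node in toposort order:
                        (toposort_index, producer_indices_of_inputs, op_key),
                        or None if that node could not be given a key
    input_clients    -- for each graph input, its (node_index, arg_position) edges
    output_ancestors -- for each graph output, the producer indices of the inputs
                        of the node that computes it
    """
    if any(key is None for key in node_keys):
        # A single uncacheable node makes the whole graph uncacheable
        return None
    payload = f"({tuple(node_keys)}, {tuple(input_clients)}, {tuple(output_ancestors)})"
    return sha256(payload.encode()).hexdigest()


# Same Ops, different wiring -> different keys.
# Graph A: exp(x) / log(x)        Graph B: log(x) / exp(x)
key_a = combine_fgraph_key(
    node_keys=[(0, (-1,), "Exp"), (1, (-1,), "Log"), (2, (0, 1), "Div")],
    input_clients=[((0, 0), (1, 0))],  # x is argument 0 of node 0 and of node 1
    output_ancestors=[(0, 1)],         # the output node's inputs come from nodes 0 and 1
)
key_b = combine_fgraph_key(
    node_keys=[(0, (-1,), "Exp"), (1, (-1,), "Log"), (2, (1, 0), "Div")],
    input_clients=[((0, 0), (1, 0))],
    output_ancestors=[(1, 0)],
)
assert key_a is not None and key_a != key_b
```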