5 changes: 2 additions & 3 deletions pytensor/link/numba/cache.py
@@ -1,6 +1,5 @@
from collections.abc import Callable
from hashlib import sha256
from pathlib import Path
from pickle import dump
from tempfile import NamedTemporaryFile
from typing import Any
@@ -64,8 +63,8 @@ def get_disambiguator(self):
@classmethod
def from_function(cls, py_func, py_file):
"""Create a locator instance for functions stored in CACHED_SRC_FUNCTIONS."""
if config.numba__cache and py_func in CACHED_SRC_FUNCTIONS:
return cls(py_func, Path(py_file).parent, CACHED_SRC_FUNCTIONS[py_func])
if py_func in CACHED_SRC_FUNCTIONS and config.numba__cache:
return cls(py_func, py_file, CACHED_SRC_FUNCTIONS[py_func])
Member Author: Numba wants a string for py_file (to emit warnings)

# Register our locator at the front of Numba's locator list
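The review comment above explains the switch from `Path(py_file).parent` to passing `py_file` straight through: Numba interpolates the file path into warning messages, so the locator should hand it a plain string. A toy sketch of that locator shape under those assumptions (the class name and registry dict are illustrative, not Numba's or PyTensor's actual API; returning `None` to decline a function mirrors what the diff does implicitly when the `if` is not taken):

```python
CACHED_SRC_FUNCTIONS: dict = {}  # maps generated functions to precomputed source hashes


class IllustrativeLocator:
    """Toy source-hash locator; not Numba's or PyTensor's actual class."""

    def __init__(self, py_func, py_file, src_hash):
        self._py_func = py_func
        # Keep the path as a plain str: it ends up interpolated into warning text
        self._py_file = str(py_file)
        self._src_hash = src_hash

    @classmethod
    def from_function(cls, py_func, py_file):
        # Cheap membership check; returning None lets other locators take over
        if py_func in CACHED_SRC_FUNCTIONS:
            return cls(py_func, py_file, CACHED_SRC_FUNCTIONS[py_func])
        return None
```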
86 changes: 74 additions & 12 deletions pytensor/link/numba/dispatch/basic.py
@@ -5,12 +5,13 @@

import numba
import numpy as np
from numba import NumbaPerformanceWarning, NumbaWarning
from numba import njit as _njit
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401

from pytensor import config
from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.type import Type
from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
@@ -23,6 +24,35 @@
from pytensor.tensor.utils import hash_from_ndarray


def _filter_numba_warnings():
# Suppress large global arrays cache warning for internal functions
# We have to match an optional ANSI escape code for bold text emitted by numba
# TODO: We could avoid inlining large constants and pass them at runtime
warnings.filterwarnings(
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
'Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals'
),
category=NumbaWarning,
)

# Disable loud / incorrect warnings from Numba
# https://github.com/numba/numba/issues/10086
# TODO: Would be much better if we could disable only for our functions
warnings.filterwarnings(
"ignore",
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
r"np\.dot\(\) is faster on contiguous arrays"
),
category=NumbaPerformanceWarning,
)


_filter_numba_warnings()


def numba_njit(
*args, fastmath=None, final_function: bool = False, **kwargs
) -> Callable:
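Worth noting for the `_filter_numba_warnings` patterns above: `warnings.filterwarnings` treats `message` as a regular expression that must match the start of the warning text, and Numba may prefix that text with an ANSI bold escape when colored output is on, hence the `(\x1b\[1m)*` group. A quick standalone check of the pattern, with made-up message tails:

```python
import re

pattern = re.compile(
    "(\x1b\\[1m)*"  # optional ANSI "bold" escape, as in the filter above
    r"np\.dot\(\) is faster on contiguous arrays"
)

# Illustrative message bodies; the real Numba warning continues after this prefix
plain = "np.dot() is faster on contiguous arrays (illustrative tail)"
bold = "\x1b[1m" + plain

# Both variants match at the start of the message, so one filter catches both
assert pattern.match(plain) is not None
assert pattern.match(bold) is not None
```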
@@ -501,28 +531,44 @@ def numba_funcify_FunctionGraph(
cache_keys = []
toposort = fgraph.toposort()
clients = fgraph.clients
toposort_indices = {node: i for i, node in enumerate(toposort)}
# Add dummy output clients which are not included of the toposort
toposort_indices: dict[Apply | None, int] = {
node: i for i, node in enumerate(toposort)
}
# Use -1 for root inputs / constants whose owner is None
toposort_indices[None] = -1
# Add dummy output nodes, which are not included in the toposort
toposort_indices |= {
clients[out][0][0]: i
for i, out in enumerate(fgraph.outputs, start=len(toposort))
out_node: i + len(toposort)
for i, out in enumerate(fgraph.outputs)
for out_node, _ in clients[out]
if isinstance(out_node.op, Output) and out_node.op.idx == i
}

def op_conversion_and_key_collection(*args, **kwargs):
def op_conversion_and_key_collection(op, *args, node, **kwargs):
# Convert an Op to a funcified function and store the cache_key

# We also Cache each Op so Numba can do less work next time it sees it
func, key = numba_funcify_ensure_cache(*args, **kwargs)
cache_keys.append(key)
func, key = numba_funcify_ensure_cache(op, node=node, *args, **kwargs)
if key is None:
cache_keys.append(key)
else:
# Add graph coordinate information (input edges and node location)
cache_keys.append(
(
toposort_indices[node],
tuple(toposort_indices[inp.owner] for inp in node.inputs),
key,
)
)
return func

def type_conversion_and_key_collection(value, variable, **kwargs):
# Convert a constant type to a numba compatible one and compute a cache key for it

# We need to know where in the graph the constants are used
# Otherwise we would hash stack(x, 5.0, 7.0), and stack(5.0, x, 7.0) the same
# Add graph coordinate information (client edges)
# FIXME: It doesn't make sense to call type_conversion on non-constants,
# but that's what fgraph_to_python currently does. We appease it, but don't consider for caching
# but that's what fgraph_to_python currently does.
# We appease it, but don't consider for caching
if isinstance(variable, Constant):
client_indices = tuple(
(toposort_indices[node], inp_idx) for node, inp_idx in clients[variable]
@@ -541,8 +587,24 @@ def type_conversion_and_key_collection(value, variable, **kwargs):
# If a single element couldn't be cached, we can't cache the whole FunctionGraph either
fgraph_key = None
else:
# Add graph coordinate information for fgraph inputs (client edges) and fgraph outputs (input edges)
# Constant edges are handled by `type_conversion_and_key_collection` called by `fgraph_to_python`
fgraph_input_clients = tuple(
tuple(
(toposort_indices[node], inp_idx)
# Disconnected inputs don't have clients
for node, inp_idx in clients.get(inp, ())
)
for inp in fgraph.inputs
)
fgraph_output_ancestors = tuple(
tuple(toposort_indices[inp.owner] for inp in out.owner.inputs)
for out in fgraph.outputs
if out.owner is not None # constant outputs
)

# Compose individual cache_keys into a global key for the FunctionGraph
fgraph_key = sha256(
f"({type(fgraph)}, {tuple(cache_keys)}, {len(fgraph.inputs)}, {len(fgraph.outputs)})".encode()
f"({type(fgraph)}, {tuple(cache_keys)}, {fgraph_input_clients}, {fgraph_output_ancestors})".encode()
).hexdigest()
return numba_njit(py_func), fgraph_key
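The composition above folds each node's op-level key together with its graph coordinates (the node's toposort index and the owner indices of its inputs, with -1 for root inputs), then hashes those per-node keys alongside the fgraph's input client edges and output ancestors. A rough standalone sketch of that keying idea, with hypothetical string op keys standing in for the real ones:

```python
from hashlib import sha256


def compose_fgraph_key(node_keys, input_clients, output_ancestors, fgraph_type="FunctionGraph"):
    """Fold per-node keys plus graph coordinates into one graph-level key.

    node_keys: per toposorted node, either None (uncacheable) or a tuple
        (node_index, input_owner_indices, op_key), where -1 stands in for
        root inputs/constants whose owner is None.
    """
    if any(key is None for key in node_keys):
        return None  # a single uncacheable node poisons the whole graph
    return sha256(
        f"({fgraph_type}, {tuple(node_keys)}, {input_clients}, {output_ancestors})".encode()
    ).hexdigest()


# Two graphs with the same Ops but different wiring hash differently:
#   graph A: a = add(x, y); out = mul(a, y)
#   graph B: a = add(x, y); out = mul(y, a)
graph_a = [(0, (-1, -1), "Add"), (1, (0, -1), "Mul")]
graph_b = [(0, (-1, -1), "Add"), (1, (-1, 0), "Mul")]
key_a = compose_fgraph_key(graph_a, (((0, 0),), ((0, 1), (1, 1))), ((0, -1),))
key_b = compose_fgraph_key(graph_b, (((0, 0),), ((0, 1), (1, 0))), ((-1, 0),))
assert key_a != key_b
```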
34 changes: 12 additions & 22 deletions pytensor/link/numba/dispatch/scalar.py
@@ -14,7 +14,6 @@
from pytensor.link.numba.dispatch.cython_support import wrap_cython_function
from pytensor.link.utils import (
get_name_for_object,
unique_name_generator,
)
from pytensor.scalar.basic import (
Add,
@@ -81,23 +80,21 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
scalar_func_numba = generate_fallback_impl(op, node, **kwargs)

scalar_op_fn_name = get_name_for_object(scalar_func_numba)

prefix = "x" if scalar_func_name != "x" else "y"
input_names = [f"{prefix}{i}" for i in range(len(node.inputs))]
input_signature = ", ".join(input_names)
global_env = {"scalar_func_numba": scalar_func_numba}

if input_inner_dtypes is None and output_inner_dtype is None:
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
)
input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func_numba({input_names})
def {scalar_op_fn_name}({input_signature}):
return scalar_func_numba({input_signature})
"""
else:
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
return scalar_func_numba({input_names}, np.intc(1))
def {scalar_op_fn_name}({input_signature}):
return scalar_func_numba({input_signature}, np.intc(1))
"""

else:
Expand All @@ -108,13 +105,6 @@ def {scalar_op_fn_name}({input_names}):
for i, i_dtype in enumerate(input_inner_dtypes)
}
global_env.update(input_tmp_dtype_names)

unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba", *global_env.keys()],
suffix_sep="_",
)

input_names = [unique_names(v, force_unique=True) for v in node.inputs]
converted_call_args = ", ".join(
f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip(
@@ -123,19 +113,19 @@
)
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
def {scalar_op_fn_name}({", ".join(input_names)}):
def {scalar_op_fn_name}({input_signature}):
return direct_cast(scalar_func_numba({converted_call_args}), output_dtype)
"""
else:
scalar_op_src = f"""
def {scalar_op_fn_name}({", ".join(input_names)}):
def {scalar_op_fn_name}({input_signature}):
return direct_cast(scalar_func_numba({converted_call_args}, np.intc(1)), output_dtype)
"""

scalar_op_fn = compile_numba_function_src(
scalar_op_src,
scalar_op_fn_name,
{**globals(), **global_env},
globals() | global_env,
)

# Functions that call a function pointer can't be cached
Expand All @@ -157,8 +147,8 @@ def switch(condition, x, y):

def binary_to_nary_func(inputs: list[Variable], binary_op_name: str, binary_op: str):
"""Create a Numba-compatible N-ary function from a binary function."""
unique_names = unique_name_generator(["binary_op_name"], suffix_sep="_")
input_names = [unique_names(v, force_unique=True) for v in inputs]
var_prefix = "x" if binary_op_name != "x" else "y"
input_names = [f"{var_prefix}{i}" for i in range(len(inputs))]
input_signature = ", ".join(input_names)
output_expr = binary_op.join(input_names)

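`binary_to_nary_func` now derives positional parameter names instead of calling `unique_name_generator`, and joins them with the operator string to form the function body. A self-contained sketch of that source-generation pattern, using plain `exec` where PyTensor uses `compile_numba_function_src` (the helper name below is made up):

```python
def binary_to_nary_source(binary_op_name: str, n_inputs: int, binary_op: str) -> str:
    # Positional names; swap the prefix only if it would collide with the function name
    prefix = "x" if binary_op_name != "x" else "y"
    input_names = [f"{prefix}{i}" for i in range(n_inputs)]
    signature = ", ".join(input_names)
    output_expr = binary_op.join(input_names)
    return f"def {binary_op_name}({signature}):\n    return {output_expr}\n"


src = binary_to_nary_source("add_n", 3, " + ")
namespace: dict = {}
exec(src, namespace)  # stand-in for compile_numba_function_src
add_n = namespace["add_n"]
assert add_n(1, 2, 3) == 6
```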
32 changes: 16 additions & 16 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -18,7 +18,6 @@
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import unique_name_generator
from pytensor.tensor import TensorType
from pytensor.tensor.rewriting.subtensor import is_full_slice
from pytensor.tensor.subtensor import (
@@ -143,19 +142,14 @@ def subtensor_op_cache_key(op, **extra_fields):
def numba_funcify_default_subtensor(op, node, **kwargs):
"""Create a Python function that assembles and uses an index on an array."""

unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)

def convert_indices(indices, entry):
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return unique_names(rval)
def convert_indices(indice_names, entry):
if indice_names and isinstance(entry, Type):
return next(indice_names)
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indices, entry.start)}, "
f"{convert_indices(indices, entry.stop)}, "
f"{convert_indices(indices, entry.step)})"
f"slice({convert_indices(indice_names, entry.start)}, "
f"{convert_indices(indice_names, entry.stop)}, "
f"{convert_indices(indice_names, entry.step)})"
)
elif isinstance(entry, type(None)):
return "None"
@@ -166,13 +160,15 @@ def convert_indices(indices, entry):
op, IncSubtensor | AdvancedIncSubtensor1 | AdvancedIncSubtensor
)
index_start_idx = 1 + int(set_or_inc)

input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(op, "idx_list", None)
idx_names = [f"idx_{i}" for i in range(len(op_indices))]

input_names = ["x", "y", *idx_names] if set_or_inc else ["x", *idx_names]

idx_names_iterator = iter(idx_names)
indices_creation_src = (
tuple(convert_indices(op_indices, idx) for idx in idx_list)
tuple(convert_indices(idx_names_iterator, idx) for idx in idx_list)
if idx_list
else tuple(input_names[index_start_idx:])
)
@@ -220,7 +216,9 @@ def {function_name}({", ".join(input_names)}):
function_name=function_name,
global_env=globals() | {"np": np},
)
cache_key = subtensor_op_cache_key(op, func="numba_funcify_default_subtensor")
cache_key = subtensor_op_cache_key(
op, func="numba_funcify_default_subtensor", version=1
)
return numba_basic.numba_njit(func, boundscheck=True), cache_key


@@ -409,6 +407,8 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
op,
func="multiple_integer_vector_indexing",
y_is_broadcasted=y_is_broadcasted,
first_axis=first_axis,
last_axis=last_axis,
)
return ret_func, cache_key

Expand Down
34 changes: 6 additions & 28 deletions pytensor/link/numba/dispatch/tensor_basic.py
@@ -10,7 +10,6 @@
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.utils import unique_name_generator
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
@@ -28,15 +27,7 @@

@register_funcify_default_op_cache_key(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {
"np": np,
"dtype": np.dtype(op.dtype),
}

unique_names = unique_name_generator(
["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_names = [f"sh{i}" for i in range(len(node.inputs))]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
@@ -56,21 +47,15 @@ def allocempty({", ".join(shape_var_names)}):
"""

alloc_fn = compile_numba_function_src(
alloc_def_src, "allocempty", {**globals(), **global_env}
alloc_def_src, "allocempty", globals() | {"np": np, "dtype": np.dtype(op.dtype)}
)

return numba_basic.numba_njit(alloc_fn)


@register_funcify_and_cache_key(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np}

unique_names = unique_name_generator(
["np", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_",
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
shape_var_names = [f"sh{i}" for i in range(len(node.inputs) - 1)]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
@@ -102,7 +87,7 @@ def alloc(val, {", ".join(shape_var_names)}):
alloc_fn = compile_numba_function_src(
alloc_def_src,
"alloc",
{**globals(), **global_env},
globals() | {"np": np},
)

cache_key = sha256(
@@ -207,14 +192,7 @@ def eye(N, M, k):
@register_funcify_default_op_cache_key(MakeVector)
def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype)

global_env = {"np": np, "dtype": dtype}

unique_names = unique_name_generator(
["np"],
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
input_names = [f"x{i}" for i in range(len(node.inputs))]

def create_list_string(x):
args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else []))
@@ -228,7 +206,7 @@ def makevector({", ".join(input_names)}):
makevector_fn = compile_numba_function_src(
makevector_def_src,
"makevector",
{**globals(), **global_env},
globals() | {"np": np, "dtype": dtype},
)

return numba_basic.numba_njit(makevector_fn)
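The tensor_basic changes all follow the same recipe: positional names like `x0, x1, ...`, a generated source string, and compilation against `globals() | {"np": np, "dtype": dtype}`. A rough MakeVector-style sketch under those assumptions, again with plain `exec` standing in for `compile_numba_function_src` and no `@njit` wrapper:

```python
import numpy as np


def make_makevector(n_inputs: int, dtype: np.dtype):
    input_names = [f"x{i}" for i in range(n_inputs)]
    # Single-input case gets a trailing comma, mirroring create_list_string above
    items = ", ".join(
        [f"{name}.item()" for name in input_names] + ([""] if n_inputs == 1 else [])
    )
    src = (
        f"def makevector({', '.join(input_names)}):\n"
        f"    return np.array([{items}], dtype=dtype)\n"
    )
    namespace = globals() | {"np": np, "dtype": dtype}
    exec(src, namespace)  # stand-in for compile_numba_function_src
    return namespace["makevector"]


mv = make_makevector(3, np.dtype("float64"))
print(mv(np.float32(1), np.int8(2), np.float64(3)))  # [1. 2. 3.]
```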