diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index f99b8240ca..6763509d75 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -1099,12 +1099,6 @@ def add_scan_configvars(): def add_numba_configvars(): - config.add( - "numba__vectorize_target", - ("Default target for numba.vectorize."), - EnumStr("cpu", ["parallel", "cuda"], mutable=True), - in_c_key=False, - ) config.add( "numba__fastmath", ("If True, use Numba's fastmath mode."), diff --git a/pytensor/configparser.py b/pytensor/configparser.py index c7da71426d..d33c970ba1 100644 --- a/pytensor/configparser.py +++ b/pytensor/configparser.py @@ -157,7 +157,6 @@ class PyTensorConfigParser: scan__allow_gc: bool scan__allow_output_prealloc: bool # add_numba_configvars - numba__vectorize_target: str numba__fastmath: bool numba__cache: bool # add_caching_dir_configvars diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1fefb1d06d..50e61a27ab 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -9,8 +9,10 @@ import pytensor.link.numba.dispatch.random import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.scalar +import pytensor.link.numba.dispatch.shape import pytensor.link.numba.dispatch.signal import pytensor.link.numba.dispatch.slinalg +import pytensor.link.numba.dispatch.sort import pytensor.link.numba.dispatch.sparse import pytensor.link.numba.dispatch.subtensor import pytensor.link.numba.dispatch.tensor_basic diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 87b8e380d3..0d4217a786 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -1,53 +1,27 @@ -import operator -import sys import warnings -from copy import copy from functools import singledispatch -from textwrap import dedent import numba -import numba.np.unsafe.ndarray as numba_ndarray import numpy as np -import scipy -import scipy.special -from llvmlite import ir -from numba import types -from numba.core.errors import NumbaWarning, TypingError +from numba.core.errors import NumbaWarning from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 -from numba.extending import box, overload from pytensor import In, config from pytensor.compile import NUMBA from pytensor.compile.builders import OpFromGraph from pytensor.compile.function.types import add_supervisor_to_fgraph -from pytensor.compile.ops import DeepCopyOp +from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.graph.basic import Apply from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.ifelse import IfElse from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import ( - compile_function_src, fgraph_to_python, ) from pytensor.scalar.basic import ScalarType from pytensor.sparse import SparseTensorType -from pytensor.tensor.basic import Nonzero -from pytensor.tensor.blas import BatchedDot -from pytensor.tensor.math import Dot -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -from pytensor.tensor.slinalg import Solve -from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType -from pytensor.tensor.type_other import MakeSlice, NoneConst - - -def global_numba_func(func): - """Use to return global numba functions in numba_funcify_*. - - This allows tests to remove the compilation using mock. - """ - return func def numba_njit(*args, fastmath=None, **kwargs): @@ -87,13 +61,6 @@ def numba_njit(*args, fastmath=None, **kwargs): return numba.njit(*args, fastmath=fastmath, **kwargs) -def numba_vectorize(*args, **kwargs): - if len(args) > 0 and callable(args[0]): - return numba.vectorize(*args[1:], cache=config.numba__cache, **kwargs)(args[0]) - - return numba.vectorize(*args, cache=config.numba__cache, **kwargs) - - def get_numba_type( pytensor_type: Type, layout: str = "A", @@ -167,83 +134,6 @@ def create_numba_signature( return numba.types.void(*input_types) -def slice_new(self, start, stop, step): - fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj]) - fn = self._get_function(fnty, name="PySlice_New") - return self.builder.call(fn, [start, stop, step]) - - -def enable_slice_boxing(): - """Enable boxing for Numba's native ``slice``s. - - TODO: this can be removed when https://github.com/numba/numba/pull/6939 is - merged and a release is made. - """ - - @box(types.SliceType) - def box_slice(typ, val, c): - """Implement boxing for ``slice`` objects in Numba. - - This makes it possible to return an Numba's internal representation of a - ``slice`` object as a proper ``slice`` to Python. - """ - start = c.builder.extract_value(val, 0) - stop = c.builder.extract_value(val, 1) - - none_val = ir.Constant(ir.IntType(64), sys.maxsize) - - start_is_none = c.builder.icmp_signed("==", start, none_val) - start = c.builder.select( - start_is_none, - c.pyapi.get_null_object(), - c.box(types.int64, start), - ) - - stop_is_none = c.builder.icmp_signed("==", stop, none_val) - stop = c.builder.select( - stop_is_none, - c.pyapi.get_null_object(), - c.box(types.int64, stop), - ) - - if typ.has_step: - step = c.builder.extract_value(val, 2) - step_is_none = c.builder.icmp_signed("==", step, none_val) - step = c.builder.select( - step_is_none, - c.pyapi.get_null_object(), - c.box(types.int64, step), - ) - else: - step = c.pyapi.get_null_object() - - slice_val = slice_new(c.pyapi, start, stop, step) - - return slice_val - - @numba.extending.overload(operator.contains) - def in_seq_empty_tuple(x, y): - if isinstance(x, types.Tuple) and not x.types: - return lambda x, y: False - - -enable_slice_boxing() - - -def to_scalar(x): - return np.asarray(x).item() - - -@numba.extending.overload(to_scalar) -def impl_to_scalar(x): - if isinstance(x, numba.types.Number | numba.types.Boolean): - return lambda x: x - elif isinstance(x, numba.types.Array): - return lambda x: x.item() - else: - raise TypingError(f"{x} must be a scalar compatible type.") - - def create_tuple_creator(f, n): """Construct a compile-time ``tuple``-comprehension-like loop. @@ -276,6 +166,55 @@ def create_arg_string(x): return args +@numba.extending.intrinsic +def direct_cast(typingctx, val, typ): + if isinstance(typ, numba.types.TypeRef): + casted = typ.instance_type + elif isinstance(typ, numba.types.DTypeSpec): + casted = typ.dtype + else: + casted = typ + + sig = casted(casted, typ) + + def codegen(context, builder, signature, args): + val, _ = args + context.nrt.incref(builder, signature.return_type, val) + return val + + return sig, codegen + + +def int_to_float_fn(inputs, out_dtype): + """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" + + if ( + all(inp.type.dtype == out_dtype for inp in inputs) + and np.dtype(out_dtype).kind == "f" + ): + + @numba_njit(inline="always") + def inputs_cast(x): + return x + + elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): + args_dtype = np.dtype(f"f{out_dtype.itemsize}") + + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) + + else: + args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) + args_dtype = np.dtype(f"f{args_dtype_sz}") + + @numba_njit(inline="always") + def inputs_cast(x): + return x.astype(args_dtype) + + return inputs_cast + + @singledispatch def numba_typify(data, dtype=None, **kwargs): return data @@ -341,6 +280,22 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): return generate_fallback_impl(op, node, storage_map, **kwargs) +@numba_funcify.register(FunctionGraph) +def numba_funcify_FunctionGraph( + fgraph, + node=None, + fgraph_name="numba_funcified_fgraph", + **kwargs, +): + return fgraph_to_python( + fgraph, + numba_funcify, + type_conversion_fn=numba_typify, + fgraph_name=fgraph_name, + **kwargs, + ) + + @numba_funcify.register(OpFromGraph) def numba_funcify_OpFromGraph(op, node=None, **kwargs): _ = kwargs.pop("storage_map", None) @@ -373,335 +328,30 @@ def opfromgraph(*inputs): return opfromgraph -@numba_funcify.register(FunctionGraph) -def numba_funcify_FunctionGraph( - fgraph, - node=None, - fgraph_name="numba_funcified_fgraph", - **kwargs, -): - return fgraph_to_python( - fgraph, - numba_funcify, - type_conversion_fn=numba_typify, - fgraph_name=fgraph_name, - **kwargs, - ) - - -def deepcopyop(x): - return copy(x) - - -@overload(deepcopyop) -def dispatch_deepcopyop(x): - if isinstance(x, types.Array): - return lambda x: np.copy(x) +@numba_funcify.register(TypeCastingOp) +def numba_funcify_type_casting(op, **kwargs): + @numba_njit + def identity(x): + return x - return lambda x: x + return identity @numba_funcify.register(DeepCopyOp) def numba_funcify_DeepCopyOp(op, node, **kwargs): - return deepcopyop - - -@numba_funcify.register(MakeSlice) -def numba_funcify_MakeSlice(op, **kwargs): - @numba_njit - def makeslice(*x): - return slice(*x) - - return makeslice - - -@numba_funcify.register(Shape) -def numba_funcify_Shape(op, **kwargs): - @numba_njit - def shape(x): - return np.asarray(np.shape(x)) - - return shape - - -@numba_funcify.register(Shape_i) -def numba_funcify_Shape_i(op, **kwargs): - i = op.i - - @numba_njit - def shape_i(x): - return np.asarray(np.shape(x)[i]) - - return shape_i - - -@numba_funcify.register(SortOp) -def numba_funcify_SortOp(op, node, **kwargs): - @numba_njit - def sort_f(a, axis): - axis = axis.item() - - a_swapped = np.swapaxes(a, axis, -1) - a_sorted = np.sort(a_swapped) - a_sorted_swapped = np.swapaxes(a_sorted, -1, axis) - - return a_sorted_swapped - - if op.kind != "quicksort": - warnings.warn( - ( - f'Numba function sort doesn\'t support kind="{op.kind}"' - " switching to `quicksort`." - ), - UserWarning, - ) - - return sort_f - - -@numba_funcify.register(ArgSortOp) -def numba_funcify_ArgSortOp(op, node, **kwargs): - def argsort_f_kind(kind): - @numba_njit - def argort_vec(X, axis): - axis = axis.item() - - Y = np.swapaxes(X, axis, 0) - result = np.empty_like(Y, dtype="int64") - - indices = list(np.ndindex(Y.shape[1:])) - - for idx in indices: - result[(slice(None), *idx)] = np.argsort( - Y[(slice(None), *idx)], kind=kind - ) - - result = np.swapaxes(result, 0, axis) - - return result - - return argort_vec - - kind = op.kind - - if kind not in ["quicksort", "mergesort"]: - kind = "quicksort" - warnings.warn( - ( - f'Numba function argsort doesn\'t support kind="{op.kind}"' - " switching to `quicksort`." - ), - UserWarning, - ) - - return argsort_f_kind(kind) - - -@numba.extending.intrinsic -def direct_cast(typingctx, val, typ): - if isinstance(typ, numba.types.TypeRef): - casted = typ.instance_type - elif isinstance(typ, numba.types.DTypeSpec): - casted = typ.dtype - else: - casted = typ - - sig = casted(casted, typ) - - def codegen(context, builder, signature, args): - val, _ = args - context.nrt.incref(builder, signature.return_type, val) - return val - - return sig, codegen - - -@numba_funcify.register(Reshape) -def numba_funcify_Reshape(op, **kwargs): - ndim = op.ndim - - if ndim == 0: + if isinstance(node.inputs[0].type, TensorType): @numba_njit - def reshape(x, shape): - return np.asarray(x.item()) + def deepcopy(x): + return np.copy(x) else: @numba_njit - def reshape(x, shape): - # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. - return np.reshape( - np.ascontiguousarray(np.asarray(x)), - numba_ndarray.to_fixed_tuple(shape, ndim), - ) - - return reshape - - -@numba_funcify.register(SpecifyShape) -def numba_funcify_SpecifyShape(op, node, **kwargs): - shape_inputs = node.inputs[1:] - shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] - - func_conditions = [ - f"assert x.shape[{i}] == {shape_input_names}" - for i, (shape_input, shape_input_names) in enumerate( - zip(shape_inputs, shape_input_names, strict=True) - ) - if shape_input is not NoneConst - ] - - func = dedent( - f""" - def specify_shape(x, {create_arg_string(shape_input_names)}): - {"; ".join(func_conditions)} + def deepcopy(x): return x - """ - ) - - specify_shape = compile_function_src(func, "specify_shape", globals()) - return numba_njit(specify_shape) - - -def int_to_float_fn(inputs, out_dtype): - """Create a Numba function that converts integer and boolean ``ndarray``s to floats.""" - if ( - all(inp.type.dtype == out_dtype for inp in inputs) - and np.dtype(out_dtype).kind == "f" - ): - - @numba_njit(inline="always") - def inputs_cast(x): - return x - - elif any(i.type.numpy_dtype.kind in "uib" for i in inputs): - args_dtype = np.dtype(f"f{out_dtype.itemsize}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - else: - args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs) - args_dtype = np.dtype(f"f{args_dtype_sz}") - - @numba_njit(inline="always") - def inputs_cast(x): - return x.astype(args_dtype) - - return inputs_cast - - -@numba_funcify.register(Dot) -def numba_funcify_Dot(op, node, **kwargs): - # Numba's `np.dot` does not support integer dtypes, so we need to cast to float. - x, y = node.inputs - [out] = node.outputs - - x_dtype = x.type.dtype - y_dtype = y.type.dtype - dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" - out_dtype = out.type.dtype - - if x_dtype == dot_dtype and y_dtype == dot_dtype: - - @numba_njit - def dot(x, y): - return np.asarray(np.dot(x, y)) - - elif x_dtype == dot_dtype and y_dtype != dot_dtype: - - @numba_njit - def dot(x, y): - return np.asarray(np.dot(x, y.astype(dot_dtype))) - - elif x_dtype != dot_dtype and y_dtype == dot_dtype: - - @numba_njit - def dot(x, y): - return np.asarray(np.dot(x.astype(dot_dtype), y)) - - else: - - @numba_njit() - def dot(x, y): - return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) - - if out_dtype == dot_dtype: - return dot - - else: - - @numba_njit - def dot_with_cast(x, y): - return dot(x, y).astype(out_dtype) - - return dot_with_cast - - -@numba_funcify.register(Solve) -def numba_funcify_Solve(op, node, **kwargs): - assume_a = op.assume_a - # check_finite = op.check_finite - - if assume_a != "gen": - lower = op.lower - - warnings.warn( - ( - "Numba will use object mode to allow the " - "`compute_uv` argument to `numpy.linalg.svd`." - ), - UserWarning, - ) - - ret_sig = get_numba_type(node.outputs[0].type) - - @numba_njit - def solve(a, b): - with numba.objmode(ret=ret_sig): - ret = scipy.linalg.solve_triangular( - a, - b, - lower=lower, - # check_finite=check_finite - ) - return ret - - else: - out_dtype = node.outputs[0].type.numpy_dtype - inputs_cast = int_to_float_fn(node.inputs, out_dtype) - - @numba_njit - def solve(a, b): - return np.linalg.solve( - inputs_cast(a), - inputs_cast(b), - # assume_a=assume_a, - # check_finite=check_finite, - ).astype(out_dtype) - - return solve - - -@numba_funcify.register(BatchedDot) -def numba_funcify_BatchedDot(op, node, **kwargs): - dtype = node.outputs[0].type.numpy_dtype - - @numba_njit - def batched_dot(x, y): - # Numba does not support 3D matmul - # https://github.com/numba/numba/issues/3804 - shape = x.shape[:-1] + y.shape[2:] - z0 = np.empty(shape, dtype=dtype) - for i in range(z0.shape[0]): - z0[i] = np.dot(x[i], y[i]) - - return z0 - - return batched_dot + return deepcopy @numba_funcify.register(IfElse) @@ -731,15 +381,3 @@ def ifelse(cond, *args): return res[0] return ifelse - - -@numba_funcify.register(Nonzero) -def numba_funcify_Nonzero(op, node, **kwargs): - @numba_njit - def nonzero(a): - result_tuple = np.nonzero(a) - if a.ndim == 1: - return result_tuple[0] - return list(result_tuple) - - return nonzero diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 5ee056f43f..807a60a6d3 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -35,8 +35,9 @@ scalar_maximum, ) from pytensor.scalar.basic import add as add_as +from pytensor.tensor.blas import BatchedDot from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise -from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum +from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -599,3 +600,68 @@ def argmax(x): return max_idx_res return argmax + + +@numba_funcify.register(Dot) +def numba_funcify_Dot(op, node, **kwargs): + # Numba's `np.dot` does not support integer dtypes, so we need to cast to float. + x, y = node.inputs + [out] = node.outputs + + x_dtype = x.type.dtype + y_dtype = y.type.dtype + dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" + out_dtype = out.type.dtype + + if x_dtype == dot_dtype and y_dtype == dot_dtype: + + @numba_njit + def dot(x, y): + return np.asarray(np.dot(x, y)) + + elif x_dtype == dot_dtype and y_dtype != dot_dtype: + + @numba_njit + def dot(x, y): + return np.asarray(np.dot(x, y.astype(dot_dtype))) + + elif x_dtype != dot_dtype and y_dtype == dot_dtype: + + @numba_njit + def dot(x, y): + return np.asarray(np.dot(x.astype(dot_dtype), y)) + + else: + + @numba_njit() + def dot(x, y): + return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) + + if out_dtype == dot_dtype: + return dot + + else: + + @numba_njit + def dot_with_cast(x, y): + return dot(x, y).astype(out_dtype) + + return dot_with_cast + + +@numba_funcify.register(BatchedDot) +def numba_funcify_BatchedDot(op, node, **kwargs): + dtype = node.outputs[0].type.numpy_dtype + + @numba_njit + def batched_dot(x, y): + # Numba does not support 3D matmul + # https://github.com/numba/numba/issues/3804 + shape = x.shape[:-1] + y.shape[2:] + z0 = np.empty(shape, dtype=dtype) + for i in range(z0.shape[0]): + z0[i] = np.dot(x[i], y[i]) + + return z0 + + return batched_dot diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index f7700acf47..5f8495b804 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -26,7 +26,7 @@ def numba_funcify_Bartlett(op, **kwargs): @numba_basic.numba_njit(inline="always") def bartlett(x): - return np.bartlett(numba_basic.to_scalar(x)) + return np.bartlett(x.item()) return bartlett @@ -112,12 +112,12 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs): @numba_basic.numba_njit def filldiagonaloffset(a, val, offset): height, width = a.shape - + offset_item = offset.item() if offset >= 0: - start = numba_basic.to_scalar(offset) + start = offset_item num_of_step = min(min(width, height), width - offset) else: - start = -numba_basic.to_scalar(offset) * a.shape[1] + start = -offset_item * a.shape[1] num_of_step = min(min(width, height), height + offset) step = a.shape[1] + 1 diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 4e0019b74b..e26c9371ed 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -2,7 +2,6 @@ import numpy as np -from pytensor.compile.ops import TypeCastingOp from pytensor.graph.basic import Variable from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -197,7 +196,6 @@ def cast(x): @numba_funcify.register(Identity) -@numba_funcify.register(TypeCastingOp) def numba_funcify_type_casting(op, **kwargs): @numba_basic.numba_njit def identity(x): @@ -210,14 +208,10 @@ def identity(x): def numba_funcify_Clip(op, **kwargs): @numba_basic.numba_njit def clip(x, min_val, max_val): - x = numba_basic.to_scalar(x) - min_scalar = numba_basic.to_scalar(min_val) - max_scalar = numba_basic.to_scalar(max_val) - - if x < min_scalar: - return min_scalar - elif x > max_scalar: - return max_scalar + if x < min_val: + return min_val + elif x > max_val: + return max_val else: return x diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index c75a4cf890..694f341ed4 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -365,7 +365,7 @@ def add_output_storage_post_proc_stmt( storage_alloc_stmts.append( dedent( f""" - {storage_size_name} = to_numba_scalar({outer_in_name}) + {storage_size_name} = ({outer_in_name}).item() {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) """ ).strip() @@ -435,10 +435,9 @@ def scan({", ".join(outer_in_names)}): """ global_env = { + "np": np, "scan_inner_func": scan_inner_func, - "to_numba_scalar": numba_basic.to_scalar, } - global_env["np"] = np scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) diff --git a/pytensor/link/numba/dispatch/shape.py b/pytensor/link/numba/dispatch/shape.py new file mode 100644 index 0000000000..f7f2c0890d --- /dev/null +++ b/pytensor/link/numba/dispatch/shape.py @@ -0,0 +1,78 @@ +from textwrap import dedent + +import numpy as np +from numba.np.unsafe import ndarray as numba_ndarray + +from pytensor.link.numba.dispatch import numba_funcify +from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit +from pytensor.link.utils import compile_function_src +from pytensor.tensor import NoneConst +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape + + +@numba_funcify.register(Shape) +def numba_funcify_Shape(op, **kwargs): + @numba_njit + def shape(x): + return np.asarray(np.shape(x)) + + return shape + + +@numba_funcify.register(Shape_i) +def numba_funcify_Shape_i(op, **kwargs): + i = op.i + + @numba_njit + def shape_i(x): + return np.asarray(np.shape(x)[i]) + + return shape_i + + +@numba_funcify.register(SpecifyShape) +def numba_funcify_SpecifyShape(op, node, **kwargs): + shape_inputs = node.inputs[1:] + shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))] + + func_conditions = [ + f"assert x.shape[{i}] == {eval_dim_name}, f'SpecifyShape: dim {{{i}}} of input has shape {{x.shape[{i}]}}, expected {{{eval_dim_name}.item()}}.'" + for i, (node_dim_input, eval_dim_name) in enumerate( + zip(shape_inputs, shape_input_names, strict=True) + ) + if node_dim_input is not NoneConst + ] + + func = dedent( + f""" + def specify_shape(x, {create_arg_string(shape_input_names)}): + {"; ".join(func_conditions)} + return x + """ + ) + + specify_shape = compile_function_src(func, "specify_shape", globals()) + return numba_njit(specify_shape) + + +@numba_funcify.register(Reshape) +def numba_funcify_Reshape(op, **kwargs): + ndim = op.ndim + + if ndim == 0: + + @numba_njit + def reshape(x, shape): + return np.asarray(x.item()) + + else: + + @numba_njit + def reshape(x, shape): + # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed. + return np.reshape( + np.ascontiguousarray(np.asarray(x)), + numba_ndarray.to_fixed_tuple(shape, ndim), + ) + + return reshape diff --git a/pytensor/link/numba/dispatch/sort.py b/pytensor/link/numba/dispatch/sort.py new file mode 100644 index 0000000000..bb91d4fc97 --- /dev/null +++ b/pytensor/link/numba/dispatch/sort.py @@ -0,0 +1,63 @@ +import warnings + +import numpy as np + +from pytensor.link.numba.dispatch import numba_funcify +from pytensor.link.numba.dispatch.basic import numba_njit +from pytensor.tensor.sort import ArgSortOp, SortOp + + +@numba_funcify.register(SortOp) +def numba_funcify_SortOp(op, node, **kwargs): + if op.kind != "quicksort": + warnings.warn( + ( + f'Numba function sort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) + + @numba_njit + def sort_f(a, axis): + axis = axis.item() + + a_swapped = np.swapaxes(a, axis, -1) + a_sorted = np.sort(a_swapped) + a_sorted_swapped = np.swapaxes(a_sorted, -1, axis) + + return a_sorted_swapped + + return sort_f + + +@numba_funcify.register(ArgSortOp) +def numba_funcify_ArgSortOp(op, node, **kwargs): + kind = op.kind + + if kind not in ["quicksort", "mergesort"]: + kind = "quicksort" + warnings.warn( + ( + f'Numba function argsort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) + + @numba_njit + def argort_f(X, axis): + axis = axis.item() + + Y = np.swapaxes(X, axis, 0) + result = np.empty_like(Y, dtype="int64") + + indices = list(np.ndindex(Y.shape[1:])) + + for idx in indices: + result[(slice(None), *idx)] = np.argsort(Y[(slice(None), *idx)], kind=kind) + + result = np.swapaxes(result, 0, axis) + return result + + return argort_f diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index e877241977..5aade827cb 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -1,4 +1,11 @@ +import operator +import sys + +import numba import numpy as np +from llvmlite import ir +from numba import types +from numba.core.pythonapi import box from pytensor.graph import Type from pytensor.link.numba.dispatch import numba_funcify @@ -14,7 +21,89 @@ IncSubtensor, Subtensor, ) -from pytensor.tensor.type_other import NoneTypeT, SliceType +from pytensor.tensor.type_other import MakeSlice, NoneTypeT, SliceType + + +def slice_new(self, start, stop, step): + fnty = ir.FunctionType(self.pyobj, [self.pyobj, self.pyobj, self.pyobj]) + fn = self._get_function(fnty, name="PySlice_New") + return self.builder.call(fn, [start, stop, step]) + + +def enable_slice_boxing(): + """Enable boxing for Numba's native ``slice``s. + + TODO: this can be removed when https://github.com/numba/numba/pull/6939 is + merged and a release is made. + """ + + @box(types.SliceType) + def box_slice(typ, val, c): + """Implement boxing for ``slice`` objects in Numba. + + This makes it possible to return an Numba's internal representation of a + ``slice`` object as a proper ``slice`` to Python. + """ + start = c.builder.extract_value(val, 0) + stop = c.builder.extract_value(val, 1) + step = c.builder.extract_value(val, 2) if typ.has_step else None + + # Numba uses sys.maxsize and -sys.maxsize-1 to represent None + # We want to use None in the Python representation + none_val = ir.Constant(ir.IntType(64), sys.maxsize) + neg_none_val = ir.Constant(ir.IntType(64), -sys.maxsize - 1) + none_obj = c.pyapi.get_null_object() + + start = c.builder.select( + c.builder.icmp_signed("==", start, none_val), + none_obj, + c.box(types.int64, start), + ) + + # None stop is represented as neg_none_val when step is negative + if step is not None: + stop_none_val = c.builder.select( + c.builder.icmp_signed(">", step, ir.Constant(ir.IntType(64), 0)), + none_val, + neg_none_val, + ) + else: + stop_none_val = none_val + stop = c.builder.select( + c.builder.icmp_signed("==", stop, stop_none_val), + none_obj, + c.box(types.int64, stop), + ) + + if step is not None: + step = c.builder.select( + c.builder.icmp_signed("==", step, none_val), + none_obj, + c.box(types.int64, step), + ) + else: + step = none_obj + + slice_val = slice_new(c.pyapi, start, stop, step) + + return slice_val + + @numba.extending.overload(operator.contains) + def in_seq_empty_tuple(x, y): + if isinstance(x, types.Tuple) and not x.types: + return lambda x, y: False + + +enable_slice_boxing() + + +@numba_funcify.register(MakeSlice) +def numba_funcify_MakeSlice(op, **kwargs): + @numba_njit + def makeslice(*x): + return slice(*x) + + return makeslice @numba_funcify.register(Subtensor) diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py index 3a9d8767b9..c82926364e 100644 --- a/pytensor/link/numba/dispatch/tensor_basic.py +++ b/pytensor/link/numba/dispatch/tensor_basic.py @@ -3,7 +3,11 @@ import numpy as np from pytensor.link.numba.dispatch import basic as numba_basic -from pytensor.link.numba.dispatch.basic import create_tuple_string, numba_funcify +from pytensor.link.numba.dispatch.basic import ( + create_tuple_string, + numba_funcify, + numba_njit, +) from pytensor.link.utils import compile_function_src, unique_name_generator from pytensor.tensor.basic import ( Alloc, @@ -13,6 +17,7 @@ Eye, Join, MakeVector, + Nonzero, ScalarFromTensor, Split, TensorFromScalar, @@ -23,18 +28,17 @@ def numba_funcify_AllocEmpty(op, node, **kwargs): global_env = { "np": np, - "to_scalar": numba_basic.to_scalar, "dtype": np.dtype(op.dtype), } unique_names = unique_name_generator( - ["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" + ["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" ) shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs] shape_var_item_names = [f"{name}_item" for name in shape_var_names] shapes_to_items_src = indent( "\n".join( - f"{item_name} = to_scalar({shape_name})" + f"{item_name} = {shape_name}.item()" for item_name, shape_name in zip( shape_var_item_names, shape_var_names, strict=True ) @@ -58,10 +62,10 @@ def allocempty({", ".join(shape_var_names)}): @numba_funcify.register(Alloc) def numba_funcify_Alloc(op, node, **kwargs): - global_env = {"np": np, "to_scalar": numba_basic.to_scalar} + global_env = {"np": np} unique_names = unique_name_generator( - ["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"], + ["np", "alloc", "val_np", "val", "scalar_shape", "res"], suffix_sep="_", ) shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]] @@ -105,9 +109,9 @@ def numba_funcify_ARange(op, **kwargs): @numba_basic.numba_njit(inline="always") def arange(start, stop, step): return np.arange( - numba_basic.to_scalar(start), - numba_basic.to_scalar(stop), - numba_basic.to_scalar(step), + start.item(), + stop.item(), + step.item(), dtype=dtype, ) @@ -182,9 +186,9 @@ def numba_funcify_Eye(op, **kwargs): @numba_basic.numba_njit(inline="always") def eye(N, M, k): return np.eye( - numba_basic.to_scalar(N), - numba_basic.to_scalar(M), - numba_basic.to_scalar(k), + N.item(), + M.item(), + k.item(), dtype=dtype, ) @@ -195,16 +199,16 @@ def eye(N, M, k): def numba_funcify_MakeVector(op, node, **kwargs): dtype = np.dtype(op.dtype) - global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype} + global_env = {"np": np, "dtype": dtype} unique_names = unique_name_generator( - ["np", "to_scalar"], + ["np"], suffix_sep="_", ) input_names = [unique_names(v, force_unique=True) for v in node.inputs] def create_list_string(x): - args = ", ".join([f"to_scalar({i})" for i in x] + ([""] if len(x) == 1 else [])) + args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else [])) return f"[{args}]" makevector_def_src = f""" @@ -232,6 +236,18 @@ def tensor_from_scalar(x): def numba_funcify_ScalarFromTensor(op, **kwargs): @numba_basic.numba_njit(inline="always") def scalar_from_tensor(x): - return numba_basic.to_scalar(x) + return x.item() return scalar_from_tensor + + +@numba_funcify.register(Nonzero) +def numba_funcify_Nonzero(op, node, **kwargs): + @numba_njit + def nonzero(a): + result_tuple = np.nonzero(a) + if a.ndim == 1: + return result_tuple[0] + return list(result_tuple) + + return nonzero diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index fd9a48111f..d706f8a4fd 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -1,5 +1,4 @@ import contextlib -import inspect from collections.abc import Callable, Iterable from typing import TYPE_CHECKING, Any from unittest import mock @@ -15,7 +14,6 @@ import pytensor.scalar as ps import pytensor.tensor as pt -import pytensor.tensor.math as ptm from pytensor import config, shared from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function @@ -30,10 +28,7 @@ from pytensor.link.numba.linker import NumbaLinker from pytensor.raise_op import assert_op from pytensor.scalar.basic import ScalarOp, as_scalar -from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape -from pytensor.tensor.sort import ArgSortOp, SortOp if TYPE_CHECKING: @@ -139,65 +134,21 @@ def py_tuple_setitem(t, i, v): ll[i] = v return tuple(ll) - def py_to_scalar(x): - if isinstance(x, np.ndarray): - return x.item() - else: - return x - def njit_noop(*args, **kwargs): if len(args) == 1 and callable(args[0]): return args[0] else: return lambda x: x - def vectorize_noop(*args, **kwargs): - def wrap(fn): - # `numba.vectorize` allows an `out` positional argument. We need - # to account for that - sig = inspect.signature(fn) - nparams = len(sig.parameters) - - def inner_vec(*args): - if len(args) > nparams: - # An `out` argument has been specified for an in-place - # operation - out = args[-1] - out[...] = np.vectorize(fn)(*args[:nparams]) - return out - else: - return np.vectorize(fn)(*args) - - return inner_vec - - if len(args) == 1 and callable(args[0]): - return wrap(args[0], **kwargs) - else: - return wrap - - def py_global_numba_func(func): - if hasattr(func, "py_func"): - return func.py_func - return func - mocks = [ mock.patch("numba.njit", njit_noop), - mock.patch("numba.vectorize", vectorize_noop), - mock.patch( - "pytensor.link.numba.dispatch.basic.global_numba_func", - py_global_numba_func, - ), mock.patch( "pytensor.link.numba.dispatch.basic.tuple_setitem", py_tuple_setitem ), mock.patch("pytensor.link.numba.dispatch.basic.numba_njit", njit_noop), - mock.patch( - "pytensor.link.numba.dispatch.basic.numba_vectorize", vectorize_noop - ), mock.patch( "pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x ), - mock.patch("pytensor.link.numba.dispatch.basic.to_scalar", py_to_scalar), mock.patch( "pytensor.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", lambda dtype: dtype, @@ -370,161 +321,6 @@ def test_create_numba_signature(v, expected, force_scalar): assert res == expected -@pytest.mark.parametrize( - "x, i", - [ - (np.zeros((20, 3)), 1), - ], -) -def test_Shape(x, i): - g = Shape()(pt.as_tensor_variable(x)) - - compare_numba_and_py([], [g], []) - - g = Shape_i(i)(pt.as_tensor_variable(x)) - - compare_numba_and_py([], [g], []) - - -@pytest.mark.parametrize( - "x", - [ - [], # Empty list - [3, 2, 1], # Simple list - np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array - ], -) -@pytest.mark.parametrize("axis", [0, -1, None]) -@pytest.mark.parametrize( - ("kind", "exc"), - [ - ["quicksort", None], - ["mergesort", UserWarning], - ["heapsort", UserWarning], - ["stable", UserWarning], - ], -) -def test_Sort(x, axis, kind, exc): - if axis: - g = SortOp(kind)(pt.as_tensor_variable(x), axis) - else: - g = SortOp(kind)(pt.as_tensor_variable(x)) - - cm = contextlib.suppress() if not exc else pytest.warns(exc) - - with cm: - compare_numba_and_py([], [g], []) - - -@pytest.mark.parametrize( - "x", - [ - [], # Empty list - [3, 2, 1], # Simple list - None, # Multi-dimensional array (see below) - ], -) -@pytest.mark.parametrize("axis", [0, -1, None]) -@pytest.mark.parametrize( - ("kind", "exc"), - [ - ["quicksort", None], - ["heapsort", None], - ["stable", UserWarning], - ], -) -def test_ArgSort(x, axis, kind, exc): - if x is None: - x = np.arange(5 * 5 * 5 * 5) - np.random.shuffle(x) - x = np.reshape(x, (5, 5, 5, 5)) - - if axis: - g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis) - else: - g = ArgSortOp(kind)(pt.as_tensor_variable(x)) - - cm = contextlib.suppress() if not exc else pytest.warns(exc) - - with cm: - compare_numba_and_py([], [g], []) - - -@pytest.mark.parametrize( - "v, shape, ndim", - [ - ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0), - ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2), - ( - (pt.vector(), np.arange(4, dtype=config.floatX)), - (pt.lvector(), np.array([2, 2], dtype="int64")), - 2, - ), - ], -) -def test_Reshape(v, shape, ndim): - v, v_test_value = v - shape, shape_test_value = shape - - g = Reshape(ndim)(v, shape) - inputs = [v] if not isinstance(shape, Variable) else [v, shape] - test_values = ( - [v_test_value] - if not isinstance(shape, Variable) - else [v_test_value, shape_test_value] - ) - compare_numba_and_py( - inputs, - [g], - test_values, - ) - - -def test_Reshape_scalar(): - v = pt.vector() - v_test_value = np.array([1.0], dtype=config.floatX) - g = Reshape(1)(v[0], (1,)) - - compare_numba_and_py( - [v], - g, - [v_test_value], - ) - - -@pytest.mark.parametrize( - "v, shape, fails", - [ - ( - (pt.matrix(), np.array([[1.0]], dtype=config.floatX)), - (1, 1), - False, - ), - ( - (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - (1, 1), - True, - ), - ( - (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), - (1, None), - False, - ), - ], -) -def test_SpecifyShape(v, shape, fails): - v, v_test_value = v - g = SpecifyShape()(v, *shape) - cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) - - with cm: - compare_numba_and_py( - [v], - [g], - [v_test_value], - ) - - def test_ViewOp(): v = pt.vector() v_test_value = np.arange(4, dtype=config.floatX) @@ -602,86 +398,6 @@ def test_perform_type_convert(): compare_numba_and_py([x], out, [x_test_value]) -@pytest.mark.parametrize( - "x, y", - [ - ( - (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), - (pt.vector(), rng.random(size=(2,)).astype(config.floatX)), - ), - ( - (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")), - (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")), - ), - ( - (pt.lmatrix(), rng.poisson(size=(3, 2))), - (pt.fvector(), rng.random(size=(2,)).astype("float32")), - ), - ( - (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), - (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), - ), - ( - (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)), - (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)), - ), - ], -) -def test_Dot(x, y): - x, x_test_value = x - y, y_test_value = y - - g = ptm.dot(x, y) - - compare_numba_and_py( - [x, y], - [g], - [x_test_value, y_test_value], - ) - - -@pytest.mark.parametrize( - "x, y, exc", - [ - ( - ( - pt.dtensor3(), - rng.random(size=(2, 3, 3)).astype("float64"), - ), - ( - pt.dtensor3(), - rng.random(size=(2, 3, 3)).astype("float64"), - ), - None, - ), - ( - ( - pt.dtensor3(), - rng.random(size=(2, 3, 3)).astype("float64"), - ), - ( - pt.ltensor3(), - rng.poisson(size=(2, 3, 3)).astype("int64"), - ), - None, - ), - ], -) -def test_BatchedDot(x, y, exc): - x, x_test_value = x - y, y_test_value = y - - g = blas.BatchedDot()(x, y) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - [x, y], - g, - [x_test_value, y_test_value], - ) - - def test_shared(): a = shared(np.array([1, 2, 3], dtype=config.floatX)) @@ -798,16 +514,6 @@ def test_IfElse(inputs, cond_fn, true_vals, false_vals): compare_numba_and_py(inputs, out, test_values) -@pytest.mark.xfail(reason="https://github.com/numba/numba/issues/7409") -def test_config_options_parallel(): - x = pt.dvector() - - with config.change_flags(numba__vectorize_target="parallel"): - pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) - numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] - assert numba_mul_fn.targetoptions["parallel"] is True - - def test_config_options_fastmath(): x = pt.dvector() @@ -921,32 +627,3 @@ def test_function_overhead(mode, benchmark): assert np.sum(fn(test_x)) == 1000 benchmark(fn, test_x) - - -@pytest.mark.parametrize( - "input_data", - [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])], -) -def test_Nonzero(input_data): - a = pt.tensor("a", shape=(None,) * input_data.ndim) - - graph_outputs = pt.nonzero(a) - - compare_numba_and_py( - graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data] - ) - - -@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed")) -def test_mat_vec_dot_performance(dtype, benchmark): - A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype) - x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype) - out = ptm.dot(A, x) - - fn = function([A, x], out, mode="NUMBA", trust_input=True) - - rng = np.random.default_rng(948) - A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype) - x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype) - np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4) - benchmark(fn, A_test, x_test) diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 84875dac97..954656cebe 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -13,6 +13,7 @@ from pytensor.compile.ops import deep_copy_op from pytensor.gradient import grad from pytensor.scalar import Composite, float64 +from pytensor.tensor import blas, tensor from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -670,3 +671,98 @@ def test_numba_careduce_benchmark(self, axis, c_contiguous, benchmark): @pytest.mark.parametrize("c_contiguous", (True, False)) def test_dimshuffle(self, c_contiguous, benchmark): dimshuffle_benchmark("NUMBA", c_contiguous, benchmark) + + +@pytest.mark.parametrize( + "x, y", + [ + ( + (pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)), + (pt.vector(), rng.random(size=(2,)).astype(config.floatX)), + ), + ( + (pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")), + (pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")), + ), + ( + (pt.lmatrix(), rng.poisson(size=(3, 2))), + (pt.fvector(), rng.random(size=(2,)).astype("float32")), + ), + ( + (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), + (pt.lvector(), rng.random(size=(2,)).astype(np.int64)), + ), + ( + (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)), + (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)), + ), + ], +) +def test_Dot(x, y): + x, x_test_value = x + y, y_test_value = y + + g = ptm.dot(x, y) + + compare_numba_and_py( + [x, y], + [g], + [x_test_value, y_test_value], + ) + + +@pytest.mark.parametrize( + "x, y, exc", + [ + ( + ( + pt.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), + ), + ( + pt.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), + ), + None, + ), + ( + ( + pt.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), + ), + ( + pt.ltensor3(), + rng.poisson(size=(2, 3, 3)).astype("int64"), + ), + None, + ), + ], +) +def test_BatchedDot(x, y, exc): + x, x_test_value = x + y, y_test_value = y + + g = blas.BatchedDot()(x, y) + + cm = contextlib.suppress() if exc is None else pytest.warns(exc) + with cm: + compare_numba_and_py( + [x, y], + g, + [x_test_value, y_test_value], + ) + + +@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed")) +def test_mat_vec_dot_performance(dtype, benchmark): + A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype) + x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype) + out = ptm.dot(A, x) + + fn = function([A, x], out, mode="NUMBA", trust_input=True) + + rng = np.random.default_rng(948) + A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype) + x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype) + np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4) + benchmark(fn, A_test, x_test) diff --git a/tests/link/numba/test_shape.py b/tests/link/numba/test_shape.py new file mode 100644 index 0000000000..1412186cf2 --- /dev/null +++ b/tests/link/numba/test_shape.py @@ -0,0 +1,100 @@ +import contextlib + +import numpy as np +import pytest + +from pytensor import Variable, config +from pytensor import tensor as pt +from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from tests.link.numba.test_basic import compare_numba_and_py + + +@pytest.mark.parametrize( + "x, i", + [ + (np.zeros((20, 3)), 1), + ], +) +def test_Shape(x, i): + g = Shape()(pt.as_tensor_variable(x)) + + compare_numba_and_py([], [g], []) + + g = Shape_i(i)(pt.as_tensor_variable(x)) + + compare_numba_and_py([], [g], []) + + +@pytest.mark.parametrize( + "v, shape, ndim", + [ + ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0), + ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2), + ( + (pt.vector(), np.arange(4, dtype=config.floatX)), + (pt.lvector(), np.array([2, 2], dtype="int64")), + 2, + ), + ], +) +def test_Reshape(v, shape, ndim): + v, v_test_value = v + shape, shape_test_value = shape + + g = Reshape(ndim)(v, shape) + inputs = [v] if not isinstance(shape, Variable) else [v, shape] + test_values = ( + [v_test_value] + if not isinstance(shape, Variable) + else [v_test_value, shape_test_value] + ) + compare_numba_and_py( + inputs, + [g], + test_values, + ) + + +def test_Reshape_scalar(): + v = pt.vector() + v_test_value = np.array([1.0], dtype=config.floatX) + g = Reshape(1)(v[0], (1,)) + + compare_numba_and_py( + [v], + g, + [v_test_value], + ) + + +@pytest.mark.parametrize( + "v, shape, fails", + [ + ( + (pt.matrix(), np.array([[1.0]], dtype=config.floatX)), + (1, 1), + False, + ), + ( + (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (1, 1), + True, + ), + ( + (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)), + (1, None), + False, + ), + ], +) +def test_SpecifyShape(v, shape, fails): + v, v_test_value = v + g = SpecifyShape()(v, *shape) + cm = contextlib.suppress() if not fails else pytest.raises(AssertionError) + + with cm: + compare_numba_and_py( + [v], + [g], + [v_test_value], + ) diff --git a/tests/link/numba/test_sort.py b/tests/link/numba/test_sort.py new file mode 100644 index 0000000000..d6c6072530 --- /dev/null +++ b/tests/link/numba/test_sort.py @@ -0,0 +1,72 @@ +import contextlib + +import numpy as np +import pytest + +from pytensor import tensor as pt +from pytensor.tensor.sort import ArgSortOp, SortOp +from tests.link.numba.test_basic import compare_numba_and_py + + +@pytest.mark.parametrize( + "x", + [ + [], # Empty list + [3, 2, 1], # Simple list + np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), + [ + ["quicksort", None], + ["mergesort", UserWarning], + ["heapsort", UserWarning], + ["stable", UserWarning], + ], +) +def test_Sort(x, axis, kind, exc): + if axis: + g = SortOp(kind)(pt.as_tensor_variable(x), axis) + else: + g = SortOp(kind)(pt.as_tensor_variable(x)) + + cm = contextlib.suppress() if not exc else pytest.warns(exc) + + with cm: + compare_numba_and_py([], [g], []) + + +@pytest.mark.parametrize( + "x", + [ + [], # Empty list + [3, 2, 1], # Simple list + None, # Multi-dimensional array (see below) + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), + [ + ["quicksort", None], + ["heapsort", None], + ["stable", UserWarning], + ], +) +def test_ArgSort(x, axis, kind, exc): + if x is None: + x = np.arange(5 * 5 * 5 * 5) + np.random.shuffle(x) + x = np.reshape(x, (5, 5, 5, 5)) + + if axis: + g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis) + else: + g = ArgSortOp(kind)(pt.as_tensor_variable(x)) + + cm = contextlib.suppress() if not exc else pytest.warns(exc) + + with cm: + compare_numba_and_py([], [g], []) diff --git a/tests/link/numba/test_subtensor.py b/tests/link/numba/test_subtensor.py index c9578657f2..17adb892cd 100644 --- a/tests/link/numba/test_subtensor.py +++ b/tests/link/numba/test_subtensor.py @@ -3,7 +3,9 @@ import numpy as np import pytest +import pytensor.scalar as ps import pytensor.tensor as pt +from pytensor import Mode, as_symbolic from pytensor.tensor import as_tensor from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -24,6 +26,45 @@ rng = np.random.default_rng(sum(map(ord, "Numba subtensors"))) +@pytest.mark.parametrize("step", [None, 1, 2, -2, "x"], ids=lambda x: f"step={x}") +@pytest.mark.parametrize("stop", [None, 10, "x"], ids=lambda x: f"stop={x}") +@pytest.mark.parametrize("start", [None, 0, 3, "x"], ids=lambda x: f"start={x}") +def test_slice(start, stop, step): + x = ps.int64("x") + + sym_slice = as_symbolic( + slice( + x if start == "x" else start, + x if stop == "x" else stop, + x if step == "x" else step, + ) + ) + + no_opt_mode = Mode(linker="numba", optimizer=None) + evaled_slice = sym_slice.eval({x: -5}, on_unused_input="ignore", mode=no_opt_mode) + assert isinstance(evaled_slice, slice) + if start == "x": + assert evaled_slice.start == -5 + elif start is None and (evaled_slice.step is None or evaled_slice.step > 0): + # Numba can convert to 0 (and sometimes does) in this case + assert evaled_slice.start in (None, 0) + else: + assert evaled_slice.start == start + + if stop == "x": + assert evaled_slice.stop == -5 + else: + assert evaled_slice.stop == stop + + if step == "x": + assert evaled_slice.step == -5 + elif step is None: + # Numba can convert to 1 (and sometimes does) in this case + assert evaled_slice.step in (None, 1) + else: + assert evaled_slice.step == step + + @pytest.mark.parametrize( "x, indices", [ diff --git a/tests/link/numba/test_tensor_basic.py b/tests/link/numba/test_tensor_basic.py index 625246e340..233b7bcb19 100644 --- a/tests/link/numba/test_tensor_basic.py +++ b/tests/link/numba/test_tensor_basic.py @@ -326,3 +326,17 @@ def test_Eye(n, m, k, dtype): g, [n_test, m_test] if m is not None else [n_test], ) + + +@pytest.mark.parametrize( + "input_data", + [np.array([1, 0, 3]), np.array([[0, 1], [2, 0]]), np.array([[0, 0], [0, 0]])], +) +def test_Nonzero(input_data): + a = pt.tensor("a", shape=(None,) * input_data.ndim) + + graph_outputs = pt.nonzero(a) + + compare_numba_and_py( + graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data] + )