diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index e0b086e89c..f2405e3542 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -86,6 +86,7 @@ def impl(*inputs_and_core_shapes): output_bc_patterns, output_dtypes, inplace_pattern, + False, # allow_core_scalar (), # constant_inputs inputs, tuple_core_shapes, @@ -98,6 +99,7 @@ def impl(*inputs_and_core_shapes): # If the core op cannot be cached, the Blockwise wrapper cannot be cached either blockwise_key = None else: + blockwise_cache_version = 1 blockwise_key = "_".join( map( str, @@ -108,6 +110,7 @@ def impl(*inputs_and_core_shapes): blockwise_op.signature, input_bc_patterns, core_op_key, + blockwise_cache_version, ), ) ) diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index b91e4aa95a..a41de98573 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -68,7 +68,7 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs): output_specs = [Out(o, borrow=False) for o in fgraph.outputs] insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key( - fgraph, squeeze_output=True, **kwargs + fgraph, squeeze_output=True, fgraph_name="numba_ofg", **kwargs ) if fgraph_cache_key is None: diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 228ed0979a..2308ddffdc 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -365,6 +365,7 @@ def impl(*inputs): output_bc_patterns_enc, output_dtypes_enc, inplace_pattern_enc, + True, # allow_core_scalar (), # constant_inputs inputs, core_output_shapes, # core_shapes diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 14247eb747..0ac6a301cc 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -470,6 +470,7 @@ def impl(core_shape, rng, size, *dist_params): output_bc_patterns, output_dtypes, inplace_pattern, + True, # allow_core_scalar (rng,), dist_params, (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),), diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index b43914c75f..777b4d5a6c 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -236,7 +236,7 @@ def numba_funcify_Composite(op, node, **kwargs): _ = kwargs.pop("storage_map", None) composite_fn, fgraph_key = numba_funcify_and_cache_key( - op.fgraph, squeeze_output=True, **kwargs + op.fgraph, squeeze_output=True, fgraph_name="numba_composite", **kwargs ) if fgraph_key is None: composite_key = None diff --git a/pytensor/link/numba/dispatch/scan.py b/pytensor/link/numba/dispatch/scan.py index 0196a14a3e..3664e4b68a 100644 --- a/pytensor/link/numba/dispatch/scan.py +++ b/pytensor/link/numba/dispatch/scan.py @@ -98,7 +98,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): output_specs = [Out(x, borrow=False) for x in fgraph.outputs] insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs) - scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key(op.fgraph) + scan_inner_func, inner_func_cache_key = numba_funcify_and_cache_key( + op.fgraph, fgraph_name="numba_scan" + ) outer_in_names_to_vars = { (f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs) diff --git a/pytensor/link/numba/dispatch/vectorize_codegen.py b/pytensor/link/numba/dispatch/vectorize_codegen.py index a4db022c47..f804c0c04c 100644 --- a/pytensor/link/numba/dispatch/vectorize_codegen.py +++ b/pytensor/link/numba/dispatch/vectorize_codegen.py @@ -82,6 +82,7 @@ def _vectorized( output_bc_patterns, output_dtypes, inplace_pattern, + allow_core_scalar, constant_inputs_types, input_types, output_core_shape_types, @@ -93,6 +94,7 @@ def _vectorized( output_bc_patterns, output_dtypes, inplace_pattern, + allow_core_scalar, constant_inputs_types, input_types, output_core_shape_types, @@ -119,6 +121,10 @@ def _vectorized( inplace_pattern = inplace_pattern.literal_value inplace_pattern = pickle.loads(base64.decodebytes(inplace_pattern.encode())) + if not isinstance(allow_core_scalar, types.Literal): + raise TypeError("allow_core_scalar must be literal.") + allow_core_scalar = allow_core_scalar.literal_value + batch_ndim = len(input_bc_patterns[0]) nin = len(constant_inputs_types) + len(input_types) nout = len(output_bc_patterns) @@ -142,8 +148,7 @@ def _vectorized( core_input_types = [] for input_type, bc_pattern in zip(input_types, input_bc_patterns, strict=True): core_ndim = input_type.ndim - len(bc_pattern) - # TODO: Reconsider this - if core_ndim == 0: + if allow_core_scalar and core_ndim == 0: core_input_type = input_type.dtype else: core_input_type = types.Array( @@ -196,7 +201,7 @@ def codegen( sig, args, ): - [_, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args + [_, _, _, _, _, _, constant_inputs, inputs, output_core_shapes, size] = args constant_inputs = cgutils.unpack_tuple(builder, constant_inputs) inputs = cgutils.unpack_tuple(builder, inputs) @@ -256,6 +261,7 @@ def codegen( output_bc_patterns_val, input_types, output_types, + core_scalar=allow_core_scalar, ) if len(outputs) == 1: @@ -429,6 +435,7 @@ def make_loop_call( output_bc: tuple[tuple[bool, ...], ...], input_types: tuple[Any, ...], output_types: tuple[Any, ...], + core_scalar: bool = True, ): safe = (False, False) @@ -486,7 +493,7 @@ def make_loop_call( idxs_bc, *safe, ) - if core_ndim == 0: + if core_scalar and core_ndim == 0: # Retrive scalar item at index val = builder.load(ptr) # val.set_metadata("alias.scope", input_scope_set) @@ -499,15 +506,19 @@ def make_loop_call( dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout ) core_array = context.make_array(core_arry_type)(context, builder) - core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:] - core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:] + core_shape = cgutils.unpack_tuple(builder, input.shape)[ + input_type.ndim - core_ndim : + ] + core_strides = cgutils.unpack_tuple(builder, input.strides)[ + input_type.ndim - core_ndim : + ] itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype)) context.populate_array( core_array, # TODO whey do we need to bitcast? data=builder.bitcast(ptr, core_array.data.type), - shape=cgutils.pack_array(builder, core_shape), - strides=cgutils.pack_array(builder, core_strides), + shape=core_shape, + strides=core_strides, itemsize=context.get_constant(types.intp, itemsize), # TODO what is meminfo about? meminfo=None, diff --git a/tests/link/numba/test_blockwise.py b/tests/link/numba/test_blockwise.py index af6f8e3ac9..d862ceeab1 100644 --- a/tests/link/numba/test_blockwise.py +++ b/tests/link/numba/test_blockwise.py @@ -2,10 +2,12 @@ import pytest from pytensor import function -from pytensor.tensor import lvector, tensor, tensor3 +from pytensor.graph import Apply +from pytensor.scalar import ScalarOp +from pytensor.tensor import TensorVariable, lvector, tensor, tensor3, vector from pytensor.tensor.basic import Alloc, ARange, constant from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape -from pytensor.tensor.elemwise import DimShuffle +from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.nlinalg import SVD, Det from pytensor.tensor.slinalg import Cholesky, cholesky from tests.link.numba.test_basic import compare_numba_and_py, numba_mode @@ -90,3 +92,39 @@ def test_blockwise_scalar_dimshuffle(): ) out = blockwise_scalar_ds(x) compare_numba_and_py([x], [out], [np.arange(9)], eval_obj_mode=False) + + +def test_blockwise_vs_elemwise_scalar_op(): + # Regression test for https://github.com/pymc-devs/pytensor/issues/1760 + + class TestScalarOp(ScalarOp): + def make_node(self, x): + return Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, outputs): + [x] = inputs + if isinstance(node.inputs[0], TensorVariable): + assert isinstance(x, np.ndarray) + else: + assert isinstance(x, np.number | float) + out = x + 1 + if isinstance(node.outputs[0], TensorVariable): + out = np.asarray(out) + outputs[0][0] = out + + x = vector("x") + y = Elemwise(TestScalarOp())(x) + with pytest.warns( + UserWarning, + match="Numba will use object mode to run TestScalarOp's perform method", + ): + fn = function([x], y, mode="NUMBA") + np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1]) + + z = Blockwise(TestScalarOp(), signature="()->()")(x) + with pytest.warns( + UserWarning, + match="Numba will use object mode to run TestScalarOp's perform method", + ): + fn = function([x], z, mode="NUMBA") + np.testing.assert_allclose(fn(np.zeros((3,))), [1, 1, 1])