From 65d01e1a45c01aecb841a348f9dc7f7a8998aee5 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 7 May 2025 21:47:27 +0200 Subject: [PATCH 1/3] Don't specify zip strict kwarg in hot loops It seems to add a non-trivial 100ns --- pyproject.toml | 2 +- pytensor/compile/builders.py | 5 ++--- pytensor/compile/function/types.py | 6 ++++-- pytensor/ifelse.py | 8 ++++---- pytensor/link/basic.py | 12 ++++++------ pytensor/link/c/basic.py | 16 ++++++---------- pytensor/link/numba/dispatch/basic.py | 4 ++-- pytensor/link/numba/dispatch/cython_support.py | 5 +---- pytensor/link/numba/dispatch/extra_ops.py | 2 +- pytensor/link/numba/dispatch/slinalg.py | 2 +- pytensor/link/numba/dispatch/subtensor.py | 10 +++++----- pytensor/link/utils.py | 4 ++-- pytensor/scalar/basic.py | 4 ++-- pytensor/scalar/loop.py | 4 ++-- pytensor/tensor/basic.py | 4 ++-- pytensor/tensor/elemwise.py | 10 +++++----- pytensor/tensor/random/basic.py | 4 ++-- pytensor/tensor/random/utils.py | 13 ++++++------- pytensor/tensor/shape.py | 6 ++---- pytensor/tensor/type.py | 11 ++++------- 20 files changed, 60 insertions(+), 72 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bbb64549e5..9a7827d83e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"] docstring-code-format = true [tool.ruff.lint] -select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"] +select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"] ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"] unfixable = [ # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead diff --git a/pytensor/compile/builders.py b/pytensor/compile/builders.py index a4a3d1840a..8a53ee3192 100644 --- a/pytensor/compile/builders.py +++ b/pytensor/compile/builders.py @@ -873,7 +873,6 @@ def clone(self): def perform(self, node, inputs, outputs): variables = self.fn(*inputs) - assert len(variables) == len(outputs) - # strict=False because asserted above - for output, variable in zip(outputs, variables, strict=False): + # zip strict not specified because we are in a hot loop + for output, variable in zip(outputs, variables): output[0] = variable diff --git a/pytensor/compile/function/types.py b/pytensor/compile/function/types.py index 9cc85f3d24..246354de0f 100644 --- a/pytensor/compile/function/types.py +++ b/pytensor/compile/function/types.py @@ -924,7 +924,8 @@ def __call__(self, *args, output_subset=None, **kwargs): # Reinitialize each container's 'provided' counter if trust_input: - for arg_container, arg in zip(input_storage, args, strict=False): + # zip strict not specified because we are in a hot loop + for arg_container, arg in zip(input_storage, args): arg_container.storage[0] = arg else: for arg_container in input_storage: @@ -934,7 +935,8 @@ def __call__(self, *args, output_subset=None, **kwargs): raise TypeError("Too many parameter passed to pytensor function") # Set positional arguments - for arg_container, arg in zip(input_storage, args, strict=False): + # zip strict not specified because we are in a hot loop + for arg_container, arg in zip(input_storage, args): # See discussion about None as input # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5 if arg is None: diff --git a/pytensor/ifelse.py b/pytensor/ifelse.py index 8c07a99280..970b1bec1c 100644 --- a/pytensor/ifelse.py +++ b/pytensor/ifelse.py @@ -305,8 +305,8 @@ def thunk(): if len(ls) > 
0: return ls else: - # strict=False because we are in a hot loop - for out, t in zip(outputs, input_true_branch, strict=False): + # zip strict not specified because we are in a hot loop + for out, t in zip(outputs, input_true_branch): compute_map[out][0] = 1 val = storage_map[t][0] if self.as_view: @@ -326,8 +326,8 @@ def thunk(): if len(ls) > 0: return ls else: - # strict=False because we are in a hot loop - for out, f in zip(outputs, inputs_false_branch, strict=False): + # zip strict not specified because we are in a hot loop + for out, f in zip(outputs, inputs_false_branch): compute_map[out][0] = 1 # can't view both outputs unless destroyhandler # improves diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index 9cf34983f2..9d9c8c2ae4 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -539,14 +539,14 @@ def make_thunk(self, **kwargs): def f(): for inputs in input_lists[1:]: - # strict=False because we are in a hot loop - for input1, input2 in zip(inputs0, inputs, strict=False): + # zip strict not specified because we are in a hot loop + for input1, input2 in zip(inputs0, inputs): input2.storage[0] = copy(input1.storage[0]) for x in to_reset: x[0] = None pre(self, [input.data for input in input_lists[0]], order, thunk_groups) - # strict=False because we are in a hot loop - for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)): + # zip strict not specified because we are in a hot loop + for i, (thunks, node) in enumerate(zip(thunk_groups, order)): try: wrapper(self.fgraph, i, node, *thunks) except Exception: @@ -668,8 +668,8 @@ def thunk( # since the error may come from any of them? raise_with_op(self.fgraph, output_nodes[0], thunk) - # strict=False because we are in a hot loop - for o_storage, o_val in zip(thunk_outputs, outputs, strict=False): + # zip strict not specified because we are in a hot loop + for o_storage, o_val in zip(thunk_outputs, outputs): o_storage[0] = o_val thunk.inputs = thunk_inputs diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py index d509bd1d76..8d2a35b9ac 100644 --- a/pytensor/link/c/basic.py +++ b/pytensor/link/c/basic.py @@ -1988,27 +1988,23 @@ def make_thunk(self, **kwargs): ) def f(): - # strict=False because we are in a hot loop - for input1, input2 in zip(i1, i2, strict=False): + # zip strict not specified because we are in a hot loop + for input1, input2 in zip(i1, i2): # Set the inputs to be the same in both branches. # The copy is necessary in order for inplace ops not to # interfere. 
input2.storage[0] = copy(input1.storage[0]) - for thunk1, thunk2, node1, node2 in zip( - thunks1, thunks2, order1, order2, strict=False - ): - for output, storage in zip(node1.outputs, thunk1.outputs, strict=False): + for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2): + for output, storage in zip(node1.outputs, thunk1.outputs): if output in no_recycling: storage[0] = None - for output, storage in zip(node2.outputs, thunk2.outputs, strict=False): + for output, storage in zip(node2.outputs, thunk2.outputs): if output in no_recycling: storage[0] = None try: thunk1() thunk2() - for output1, output2 in zip( - thunk1.outputs, thunk2.outputs, strict=False - ): + for output1, output2 in zip(thunk1.outputs, thunk2.outputs): self.checker(output1, output2) except Exception: raise_with_op(fgraph, node1) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index a6a82ceebe..4938ecc42f 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -312,10 +312,10 @@ def py_perform_return(inputs): else: def py_perform_return(inputs): - # strict=False because we are in a hot loop + # zip strict not specified because we are in a hot loop return tuple( out_type.filter(out[0]) - for out_type, out in zip(output_types, py_perform(inputs), strict=False) + for out_type, out in zip(output_types, py_perform(inputs)) ) @numba_njit diff --git a/pytensor/link/numba/dispatch/cython_support.py b/pytensor/link/numba/dispatch/cython_support.py index 8dccf98836..422e4be406 100644 --- a/pytensor/link/numba/dispatch/cython_support.py +++ b/pytensor/link/numba/dispatch/cython_support.py @@ -166,10 +166,7 @@ def __wrapper_address__(self): def __call__(self, *args, **kwargs): # no strict argument because of the JIT # TODO: check - args = [ - dtype(arg) - for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905 - ] + args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)] if self.has_pyx_skip_dispatch(): output = self._pyfunc(*args[:-1], **kwargs) else: diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 1f0a33e595..f7700acf47 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -186,7 +186,7 @@ def ravelmultiindex(*inp): new_arr = arr.T.astype(np.float64).copy() for i, b in enumerate(new_arr): # no strict argument to this zip because numba doesn't support it - for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905 + for j, (d, v) in enumerate(zip(shape, b)): if v < 0 or v >= d: mode_fn(new_arr, i, j, v, d) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index a0361738ae..4630224f02 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -183,7 +183,7 @@ def block_diag(*arrs): r, c = 0, 0 # no strict argument because it is incompatible with numba - for arr, shape in zip(arrs, shapes): # noqa: B905 + for arr, shape in zip(arrs, shapes): rr, cc = shape out[r : r + rr, c : c + cc] = arr r += rr diff --git a/pytensor/link/numba/dispatch/subtensor.py b/pytensor/link/numba/dispatch/subtensor.py index 3328ea349c..fe0eda153e 100644 --- a/pytensor/link/numba/dispatch/subtensor.py +++ b/pytensor/link/numba/dispatch/subtensor.py @@ -219,7 +219,7 @@ def advanced_subtensor_multiple_vector(x, *idxs): shape_aft = x_shape[after_last_axis:] out_shape = (*shape_bef, *idx_shape, *shape_aft) out_buffer = np.empty(out_shape, 
dtype=x.dtype) - for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + for i, scalar_idxs in enumerate(zip(*vec_idxs)): out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)] return out_buffer @@ -253,7 +253,7 @@ def advanced_set_subtensor_multiple_vector(x, y, *idxs): y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) for outer in np.ndindex(x_shape[:first_axis]): - for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + for i, scalar_idxs in enumerate(zip(*vec_idxs)): out[(*outer, *scalar_idxs)] = y[(*outer, i)] return out @@ -275,7 +275,7 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs): y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:]) for outer in np.ndindex(x_shape[:first_axis]): - for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905 + for i, scalar_idxs in enumerate(zip(*vec_idxs)): out[(*outer, *scalar_idxs)] += y[(*outer, i)] return out @@ -314,7 +314,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs): if not len(idxs) == len(vals): raise ValueError("The number of indices and values must match.") # no strict argument because incompatible with numba - for idx, val in zip(idxs, vals): # noqa: B905 + for idx, val in zip(idxs, vals): x[idx] = val return x else: @@ -342,7 +342,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs): raise ValueError("The number of indices and values must match.") # no strict argument because unsupported by numba # TODO: this doesn't come up in tests - for idx, val in zip(idxs, vals): # noqa: B905 + for idx, val in zip(idxs, vals): x[idx] += val return x diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 9cbc3838dd..03c4f4eddc 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -207,8 +207,8 @@ def streamline_nice_errors_f(): for x in no_recycling: x[0] = None try: - # strict=False because we are in a hot loop - for thunk, node in zip(thunks, order, strict=False): + # zip strict not specified because we are in a hot loop + for thunk, node in zip(thunks, order): thunk() except Exception: raise_with_op(fgraph, node, thunk) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 909fc47c27..f71c7512bd 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4416,8 +4416,8 @@ def make_node(self, *inputs): def perform(self, node, inputs, output_storage): outputs = self.py_perform_fn(*inputs) - # strict=False because we are in a hot loop - for storage, out_val in zip(output_storage, outputs, strict=False): + # zip strict not specified because we are in a hot loop + for storage, out_val in zip(output_storage, outputs): storage[0] = out_val def grad(self, inputs, output_grads): diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 1023e6a127..80168fd122 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -196,8 +196,8 @@ def perform(self, node, inputs, output_storage): for i in range(n_steps): carry = inner_fn(*carry, *constant) - # strict=False because we are in a hot loop - for storage, out_val in zip(output_storage, carry, strict=False): + # zip strict not specified because we are in a hot loop + for storage, out_val in zip(output_storage, carry): storage[0] = out_val @property diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 4786b71778..bf2b66709a 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -3589,8 +3589,8 @@ def perform(self, node, inp, out): # Make sure the output is big enough out_s = [] - # strict=False because we are in a hot loop 
- for xdim, ydim in zip(x_s, y_s, strict=False): + # zip strict not specified because we are in a hot loop + for xdim, ydim in zip(x_s, y_s): if xdim == ydim: outdim = xdim elif xdim == 1: diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a6a2f2ce4b..68c535bd0b 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -712,9 +712,9 @@ def perform(self, node, inputs, output_storage): if nout == 1: variables = [variables] - # strict=False because we are in a hot loop + # zip strict not specified because we are in a hot loop for i, (variable, storage, nout) in enumerate( - zip(variables, output_storage, node.outputs, strict=False) + zip(variables, output_storage, node.outputs) ): storage[0] = variable = np.asarray(variable, dtype=nout.dtype) @@ -729,11 +729,11 @@ def perform(self, node, inputs, output_storage): @staticmethod def _check_runtime_broadcast(node, inputs): - # strict=False because we are in a hot loop + # zip strict not specified because we are in a hot loop for dims_and_bcast in zip( *[ - zip(input.shape, sinput.type.broadcastable, strict=False) - for input, sinput in zip(inputs, node.inputs, strict=False) + zip(input.shape, sinput.type.broadcastable) + for input, sinput in zip(inputs, node.inputs) ], strict=False, ): diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 214a7bdd3d..199637f244 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -1865,8 +1865,8 @@ def rng_fn(cls, rng, p, size): # to `p.shape[:-1]` in the call to `vsearchsorted` below. if len(size) < (p.ndim - 1): raise ValueError("`size` is incompatible with the shape of `p`") - # strict=False because we are in a hot loop - for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False): + # zip strict not specified because we are in a hot loop + for s, ps in zip(reversed(size), reversed(p.shape[:-1])): if s == 1 and ps != 1: raise ValueError("`size` is incompatible with the shape of `p`") diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 23b4b50265..3635c67cba 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -44,8 +44,8 @@ def params_broadcast_shapes( max_fn = maximum if use_pytensor else max rev_extra_dims: list[int] = [] - # strict=False because we are in a hot loop - for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False): + # zip strict not specified because we are in a hot loop + for ndim_param, param_shape in zip(ndims_params, param_shapes): # We need this in order to use `len` param_shape = tuple(param_shape) extras = tuple(param_shape[: (len(param_shape) - ndim_param)]) @@ -64,12 +64,12 @@ def max_bcast(x, y): extra_dims = tuple(reversed(rev_extra_dims)) - # strict=False because we are in a hot loop + # zip strict not specified because we are in a hot loop bcast_shapes = [ (extra_dims + tuple(param_shape)[-ndim_param:]) if ndim_param > 0 else extra_dims - for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False) + for ndim_param, param_shape in zip(ndims_params, param_shapes) ] return bcast_shapes @@ -127,10 +127,9 @@ def broadcast_params( ) broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to - # strict=False because we are in a hot loop + # zip strict not specified because we are in a hot loop bcast_params = [ - broadcast_to_fn(param, shape) - for shape, param in zip(shapes, params, strict=False) + broadcast_to_fn(param, shape) for shape, param in zip(shapes, 
params) ] return bcast_params diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 5a4cfdc52a..348d356f98 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -447,10 +447,8 @@ def perform(self, node, inp, out_): raise AssertionError( f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}." ) - # strict=False because we are in a hot loop - if not all( - xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None - ): + # zip strict not specified because we are in a hot loop + if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None): raise AssertionError( f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}." ) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index d0b6b5fe0a..0474aad77b 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -261,10 +261,10 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: " PyTensor C code does not support that.", ) - # strict=False because we are in a hot loop + # zip strict not specified because we are in a hot loop if not all( ds == ts if ts is not None else True - for ds, ts in zip(data.shape, self.shape, strict=False) + for ds, ts in zip(data.shape, self.shape) ): raise TypeError( f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})" @@ -333,17 +333,14 @@ def in_same_class(self, otype): return False def is_super(self, otype): - # strict=False because we are in a hot loop + # zip strict not specified because we are in a hot loop if ( isinstance(otype, type(self)) and otype.dtype == self.dtype and otype.ndim == self.ndim # `otype` is allowed to be as or more shape-specific than `self`, # but not less - and all( - sb == ob or sb is None - for sb, ob in zip(self.shape, otype.shape, strict=False) - ) + and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape)) ): return True From dd4ecdaa1aa9eef9f9aa2640a62c8291d10e5763 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 7 May 2025 18:50:19 +0200 Subject: [PATCH 2/3] Avoid numpy broadcast_to and ndindex in hot loops --- pytensor/tensor/random/basic.py | 29 ++++++++++++----------------- pytensor/tensor/random/utils.py | 3 ++- pytensor/tensor/utils.py | 23 +++++++++++++++++++++++ tests/tensor/random/test_basic.py | 5 ++--- 4 files changed, 39 insertions(+), 21 deletions(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 199637f244..ba6ffa8eaa 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -18,6 +18,7 @@ broadcast_params, normalize_size_param, ) +from pytensor.tensor.utils import faster_broadcast_to, faster_ndindex # Scipy.stats is considerably slow to import @@ -976,19 +977,13 @@ def __call__(self, alphas, size=None, **kwargs): @classmethod def rng_fn(cls, rng, alphas, size): if alphas.ndim > 1: - if size is None: - size = () - - size = tuple(np.atleast_1d(size)) - - if size: - alphas = np.broadcast_to(alphas, size + alphas.shape[-1:]) + if size is not None: + alphas = faster_broadcast_to(alphas, size + alphas.shape[-1:]) samples_shape = alphas.shape samples = np.empty(samples_shape) - for index in np.ndindex(*samples_shape[:-1]): + for index in faster_ndindex(samples_shape[:-1]): samples[index] = rng.dirichlet(alphas[index]) - return samples else: return rng.dirichlet(alphas, size=size) @@ -1800,11 +1795,11 @@ def rng_fn(cls, rng, n, p, size): if size is None: 
n, p = broadcast_params([n, p], [0, 1]) else: - n = np.broadcast_to(n, size) - p = np.broadcast_to(p, size + p.shape[-1:]) + n = faster_broadcast_to(n, size) + p = faster_broadcast_to(p, size + p.shape[-1:]) res = np.empty(p.shape, dtype=cls.dtype) - for idx in np.ndindex(p.shape[:-1]): + for idx in faster_ndindex(p.shape[:-1]): res[idx] = rng.multinomial(n[idx], p[idx]) return res else: @@ -1978,13 +1973,13 @@ def rng_fn(self, *params): p.shape[:batch_ndim], ) - a = np.broadcast_to(a, size + a.shape[batch_ndim:]) + a = faster_broadcast_to(a, size + a.shape[batch_ndim:]) if p is not None: - p = np.broadcast_to(p, size + p.shape[batch_ndim:]) + p = faster_broadcast_to(p, size + p.shape[batch_ndim:]) a_indexed_shape = a.shape[len(size) + 1 :] out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype) - for idx in np.ndindex(size): + for idx in faster_ndindex(size): out[idx] = rng.choice( a[idx], p=None if p is None else p[idx], size=core_shape, replace=False ) @@ -2097,10 +2092,10 @@ def rng_fn(self, rng, x, size): if size is None: size = x.shape[:batch_ndim] else: - x = np.broadcast_to(x, size + x.shape[batch_ndim:]) + x = faster_broadcast_to(x, size + x.shape[batch_ndim:]) out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype) - for idx in np.ndindex(size): + for idx in faster_ndindex(size): out[idx] = rng.permutation(x[idx]) return out diff --git a/pytensor/tensor/random/utils.py b/pytensor/tensor/random/utils.py index 3635c67cba..86628a81cb 100644 --- a/pytensor/tensor/random/utils.py +++ b/pytensor/tensor/random/utils.py @@ -15,6 +15,7 @@ from pytensor.tensor.math import maximum from pytensor.tensor.shape import shape_padleft, specify_shape from pytensor.tensor.type import int_dtypes +from pytensor.tensor.utils import faster_broadcast_to from pytensor.tensor.variable import TensorVariable @@ -125,7 +126,7 @@ def broadcast_params( shapes = params_broadcast_shapes( param_shapes, ndims_params, use_pytensor=use_pytensor ) - broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to + broadcast_to_fn = broadcast_to if use_pytensor else faster_broadcast_to # zip strict not specified because we are in a hot loop bcast_params = [ diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index 0ebb2e5434..3c730a3179 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -1,8 +1,10 @@ import re from collections.abc import Sequence +from itertools import product from typing import cast import numpy as np +from numpy import nditer import pytensor from pytensor.graph import FunctionGraph, Variable @@ -233,3 +235,24 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None: # TODO: If axis tuple is equivalent to None, return None for more canonicalization? return cast(tuple, axis) + + +def faster_broadcast_to(x, shape): + # Stripped down core logic of `np.broadcast_to` + return nditer( + (x,), + flags=["multi_index", "zerosize_ok"], + op_flags=["readonly"], + itershape=shape, + order="C", + ).itviews[0] + + +def faster_ndindex(shape: Sequence[int]): + """Equivalent to `np.ndindex` but usually 10x faster. 
+ + Unlike `np.ndindex`, this function expects a single sequence of integers + + https://github.com/numpy/numpy/issues/28921 + """ + return product(*(range(s) for s in shape)) diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index d7167b6a61..06af82ddf7 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -746,9 +746,8 @@ def test_mvnormal_cov_decomposition_method(method, psd): ], ) def test_dirichlet_samples(alphas, size): - def dirichlet_test_fn(mean=None, cov=None, size=None, random_state=None): - if size is None: - size = () + # FIXME: Is this just testing itself against itself? + def dirichlet_test_fn(alphas, size, random_state): return dirichlet.rng_fn(random_state, alphas, size) compare_sample_values(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn) From e5e090c7202ba41d28d069a24db2fc8d0d93dc66 Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Wed, 7 May 2025 14:28:31 +0200 Subject: [PATCH 3/3] Speedup python implementation of Blockwise --- pytensor/graph/op.py | 41 +++--- pytensor/link/c/op.py | 18 ++- pytensor/tensor/blockwise.py | 219 +++++++++++++++++++++++---------- tests/tensor/test_blockwise.py | 23 +++- 4 files changed, 212 insertions(+), 89 deletions(-) diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index b2d70d5828..3a00922c87 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -502,7 +502,7 @@ def make_py_thunk( self, node: Apply, storage_map: StorageMapType, - compute_map: ComputeMapType, + compute_map: ComputeMapType | None, no_recycling: list[Variable], debug: bool = False, ) -> ThunkType: @@ -513,25 +513,38 @@ def make_py_thunk( """ node_input_storage = [storage_map[r] for r in node.inputs] node_output_storage = [storage_map[r] for r in node.outputs] - node_compute_map = [compute_map[r] for r in node.outputs] if debug and hasattr(self, "debug_perform"): p = node.op.debug_perform else: p = node.op.perform - @is_thunk_type - def rval( - p=p, - i=node_input_storage, - o=node_output_storage, - n=node, - cm=node_compute_map, - ): - r = p(n, [x[0] for x in i], o) - for entry in cm: - entry[0] = True - return r + if compute_map is None: + + @is_thunk_type + def rval( + p=p, + i=node_input_storage, + o=node_output_storage, + n=node, + ): + return p(n, [x[0] for x in i], o) + + else: + node_compute_map = [compute_map[r] for r in node.outputs] + + @is_thunk_type + def rval( + p=p, + i=node_input_storage, + o=node_output_storage, + n=node, + cm=node_compute_map, + ): + r = p(n, [x[0] for x in i], o) + for entry in cm: + entry[0] = True + return r rval.inputs = node_input_storage rval.outputs = node_output_storage diff --git a/pytensor/link/c/op.py b/pytensor/link/c/op.py index b668f242e1..8ccfa2a9a3 100644 --- a/pytensor/link/c/op.py +++ b/pytensor/link/c/op.py @@ -39,7 +39,7 @@ def make_c_thunk( self, node: Apply, storage_map: StorageMapType, - compute_map: ComputeMapType, + compute_map: ComputeMapType | None, no_recycling: Collection[Variable], ) -> CThunkWrapperType: """Create a thunk for a C implementation. 
@@ -86,11 +86,17 @@ def is_f16(t): ) thunk, node_input_filters, node_output_filters = outputs - @is_cthunk_wrapper_type - def rval(): - thunk() - for o in node.outputs: - compute_map[o][0] = True + if compute_map is None: + rval = is_cthunk_wrapper_type(thunk) + + else: + cm_entries = [compute_map[o] for o in node.outputs] + + @is_cthunk_wrapper_type + def rval(thunk=thunk, cm_entries=cm_entries): + thunk() + for entry in cm_entries: + entry[0] = True rval.thunk = thunk rval.cthunk = thunk.cthunk diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index aa650cfa8e..1c2a069922 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -1,7 +1,8 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Any, cast import numpy as np +from numpy import broadcast_shapes, empty from pytensor import config from pytensor.compile.builders import OpFromGraph @@ -22,12 +23,111 @@ from pytensor.tensor.utils import ( _parse_gufunc_signature, broadcast_static_dim_lengths, + faster_broadcast_to, + faster_ndindex, import_func_from_string, safe_signature, ) from pytensor.tensor.variable import TensorVariable +def _vectorize_node_perform( + core_node: Apply, + batch_bcast_patterns: Sequence[tuple[bool, ...]], + batch_ndim: int, + impl: str | None, +) -> Callable: + """Creates a vectorized `perform` function for a given core node. + + Similar behavior of np.vectorize, but specialized for PyTensor Blockwise Op. + """ + + storage_map = {var: [None] for var in core_node.inputs + core_node.outputs} + core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl) + single_in = len(core_node.inputs) == 1 + core_input_storage = [storage_map[inp] for inp in core_node.inputs] + core_output_storage = [storage_map[out] for out in core_node.outputs] + core_storage = core_input_storage + core_output_storage + + def vectorized_perform( + *args, + batch_bcast_patterns=batch_bcast_patterns, + batch_ndim=batch_ndim, + single_in=single_in, + core_thunk=core_thunk, + core_input_storage=core_input_storage, + core_output_storage=core_output_storage, + core_storage=core_storage, + ): + if single_in: + batch_shape = args[0].shape[:batch_ndim] + else: + _check_runtime_broadcast_core(args, batch_bcast_patterns, batch_ndim) + batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args)) + args = list(args) + for i, arg in enumerate(args): + if arg.shape[:batch_ndim] != batch_shape: + args[i] = faster_broadcast_to( + arg, batch_shape + arg.shape[batch_ndim:] + ) + + ndindex_iterator = faster_ndindex(batch_shape) + # Call once to get the output shapes + try: + # TODO: Pass core shape as input like BlockwiseWithCoreShape does? 
+ index0 = next(ndindex_iterator) + except StopIteration: + raise NotImplementedError("vectorize with zero size not implemented") + else: + for core_input, arg in zip(core_input_storage, args): + core_input[0] = np.asarray(arg[index0]) + core_thunk() + outputs = tuple( + empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype) + for core_output in core_output_storage + ) + for output, core_output in zip(outputs, core_output_storage): + output[index0] = core_output[0] + + for index in ndindex_iterator: + for core_input, arg in zip(core_input_storage, args): + core_input[0] = np.asarray(arg[index]) + core_thunk() + for output, core_output in zip(outputs, core_output_storage): + output[index] = core_output[0] + + # Clear storage + for core_val in core_storage: + core_val[0] = None + return outputs + + return vectorized_perform + + +def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ndim): + # strict=None because we are in a hot loop + # We zip together the dimension lengths of each input and their broadcast patterns + for dim_lengths_and_bcast in zip( + *[ + zip(input.shape[:batch_ndim], batch_bcast_pattern) + for input, batch_bcast_pattern in zip( + numerical_inputs, batch_bcast_patterns + ) + ], + ): + # If for any dimension where an entry has dim_length != 1, + # and another a dim_length of 1 and broadcastable=False, we have runtime broadcasting. + if ( + any(d != 1 for d, _ in dim_lengths_and_bcast) + and (1, False) in dim_lengths_and_bcast + ): + raise ValueError( + "Runtime broadcasting not allowed. " + "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n" + "If broadcasting was intended, use `specify_broadcastable` on the relevant input." + ) + + class Blockwise(Op): """Generalizes a core `Op` to work with batched dimensions. @@ -308,7 +408,7 @@ def L_op(self, inputs, outs, ograds): return rval - def _create_node_gufunc(self, node) -> None: + def _create_node_gufunc(self, node: Apply, impl) -> Callable: """Define (or retrieve) the node gufunc used in `perform`. If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly. @@ -316,83 +416,66 @@ def _create_node_gufunc(self, node) -> None: The gufunc is stored in the tag of the node. 
""" - gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None) - - if gufunc_spec is not None: - gufunc = import_func_from_string(gufunc_spec[0]) - if gufunc is None: + batch_ndim = self.batch_ndim(node) + batch_bcast_patterns = [ + inp.type.broadcastable[:batch_ndim] for inp in node.inputs + ] + if ( + gufunc_spec := self.gufunc_spec + or getattr(self.core_op, "gufunc_spec", None) + ) is not None: + core_func = import_func_from_string(gufunc_spec[0]) + if core_func is None: raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}") - else: - # Wrap core_op perform method in numpy vectorize - n_outs = len(self.outputs_sig) - core_node = self._create_dummy_core_node(node.inputs) - inner_outputs_storage = [[None] for _ in range(n_outs)] - - def core_func( - *inner_inputs, - core_node=core_node, - inner_outputs_storage=inner_outputs_storage, - ): - self.core_op.perform( - core_node, - [np.asarray(inp) for inp in inner_inputs], - inner_outputs_storage, - ) - - if n_outs == 1: - return inner_outputs_storage[0][0] - else: - return tuple(r[0] for r in inner_outputs_storage) + if len(node.outputs) == 1: + + def gufunc( + *inputs, + batch_bcast_patterns=batch_bcast_patterns, + batch_ndim=batch_ndim, + ): + _check_runtime_broadcast_core( + inputs, batch_bcast_patterns, batch_ndim + ) + return (core_func(*inputs),) + else: - gufunc = np.vectorize(core_func, signature=self.signature) + def gufunc( + *inputs, + batch_bcast_patterns=batch_bcast_patterns, + batch_ndim=batch_ndim, + ): + _check_runtime_broadcast_core( + inputs, batch_bcast_patterns, batch_ndim + ) + return core_func(*inputs) + else: + core_node = self._create_dummy_core_node(node.inputs) # type: ignore + gufunc = _vectorize_node_perform( + core_node, + batch_bcast_patterns=batch_bcast_patterns, + batch_ndim=self.batch_ndim(node), + impl=impl, + ) - node.tag.gufunc = gufunc + return gufunc def _check_runtime_broadcast(self, node, inputs): batch_ndim = self.batch_ndim(node) + batch_bcast = [pt_inp.type.broadcastable[:batch_ndim] for pt_inp in node.inputs] + _check_runtime_broadcast_core(inputs, batch_bcast, batch_ndim) - # strict=False because we are in a hot loop - for dims_and_bcast in zip( - *[ - zip( - input.shape[:batch_ndim], - sinput.type.broadcastable[:batch_ndim], - strict=False, - ) - for input, sinput in zip(inputs, node.inputs, strict=False) - ], - strict=False, - ): - if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast: - raise ValueError( - "Runtime broadcasting not allowed. " - "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n" - "If broadcasting was intended, use `specify_broadcastable` on the relevant input." 
- ) + def prepare_node(self, node, storage_map, compute_map, impl=None): + node.tag.gufunc = self._create_node_gufunc(node, impl=impl) def perform(self, node, inputs, output_storage): - gufunc = getattr(node.tag, "gufunc", None) - - if gufunc is None: - # Cache it once per node - self._create_node_gufunc(node) + try: gufunc = node.tag.gufunc - - self._check_runtime_broadcast(node, inputs) - - res = gufunc(*inputs) - if not isinstance(res, tuple): - res = (res,) - - # strict=False because we are in a hot loop - for node_out, out_storage, r in zip( - node.outputs, output_storage, res, strict=False - ): - out_dtype = getattr(node_out, "dtype", None) - if out_dtype and out_dtype != r.dtype: - r = np.asarray(r, dtype=out_dtype) - out_storage[0] = r + except AttributeError: + gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None) + for out_storage, result in zip(output_storage, gufunc(*inputs)): + out_storage[0] = result def __str__(self): if self.name is None: diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index dc0f6b6e4e..a140e07846 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -12,10 +12,11 @@ from pytensor.graph import Apply, Op from pytensor.graph.replace import vectorize_node from pytensor.raise_op import assert_op -from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector +from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot +from pytensor.tensor.signal import convolve1d from pytensor.tensor.slinalg import ( Cholesky, Solve, @@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchm benchmark(fn, *test_values) +def test_small_blockwise_performance(benchmark): + a = dmatrix(shape=(7, 128)) + b = dmatrix(shape=(7, 20)) + out = convolve1d(a, b, mode="valid") + fn = pytensor.function([a, b], out, trust_input=True) + assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) + + rng = np.random.default_rng(495) + a_test = rng.normal(size=a.type.shape) + b_test = rng.normal(size=b.type.shape) + np.testing.assert_allclose( + fn(a_test, b_test), + [ + np.convolve(a_test[i], b_test[i], mode="valid") + for i in range(a_test.shape[0]) + ], + ) + benchmark(fn, a_test, b_test) + + def test_cop_with_params(): matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
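The performance claims behind commits 1 and 2 — that `zip(..., strict=...)` adds a non-trivial per-call cost (~100ns) in hot loops, and that `faster_ndindex` (built on `itertools.product`) is roughly 10x faster than `np.ndindex` — can be sanity-checked with a standalone micro-benchmark along the following lines. This is an illustrative sketch, not part of the patch: the helper names (`zip_plain`, `zip_strict`) and the chosen sizes are made up for the comparison, and absolute numbers will vary with the Python/NumPy versions and hardware.

# Rough sanity check of the claims in commits 1 and 2:
# (a) per-call overhead of zip(..., strict=True) in a tight loop,
# (b) np.ndindex vs the itertools.product approach used by faster_ndindex.
from itertools import product
from timeit import timeit

import numpy as np


def zip_plain(a, b):
    # Baseline: plain zip, as used in the patched hot loops.
    return [x + y for x, y in zip(a, b)]


def zip_strict(a, b):
    # Same loop, but with the strict kwarg that the patch removes.
    return [x + y for x, y in zip(a, b, strict=True)]


a = list(range(4))
b = list(range(4))
n = 1_000_000
t_plain = timeit(lambda: zip_plain(a, b), number=n) / n
t_strict = timeit(lambda: zip_strict(a, b), number=n) / n
print(f"zip: {t_plain * 1e9:.0f} ns/call, zip(strict=True): {t_strict * 1e9:.0f} ns/call")

# Iterating over a small batch shape, as Blockwise.perform does per output element.
shape = (32, 32)
t_ndindex = timeit(lambda: list(np.ndindex(*shape)), number=1_000) / 1_000
t_product = timeit(
    lambda: list(product(*(range(s) for s in shape))), number=1_000
) / 1_000
print(
    f"np.ndindex: {t_ndindex * 1e6:.1f} us/iter, "
    f"itertools.product: {t_product * 1e6:.1f} us/iter"
)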