diff --git a/pytensor/link/numba/dispatch/tensor_basic.py b/pytensor/link/numba/dispatch/tensor_basic.py
index 8f5972c058..7daa625794 100644
--- a/pytensor/link/numba/dispatch/tensor_basic.py
+++ b/pytensor/link/numba/dispatch/tensor_basic.py
@@ -136,10 +136,7 @@ def join(axis, *tensors):
 def numba_funcify_Split(op, **kwargs):
     @numba_basic.numba_njit
     def split(tensor, axis, indices):
-        # Work around for https://github.com/numba/numba/issues/8257
-        axis = axis % tensor.ndim
-        axis = numba_basic.to_scalar(axis)
-        return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
+        return np.split(tensor, np.cumsum(indices)[:-1], axis=axis.item())
 
     return split
 
diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py
index 6bcb084f4e..17694c0f1e 100644
--- a/pytensor/tensor/basic.py
+++ b/pytensor/tensor/basic.py
@@ -2201,8 +2201,26 @@ def make_node(self, x, axis, splits):
             raise TypeError("`axis` parameter must be an integer scalar")
 
         inputs = [x, axis, splits]
-        out_type = TensorType(dtype=x.dtype, shape=(None,) * x.type.ndim)
-        outputs = [out_type() for i in range(self.len_splits)]
+
+        x_dtype = x.type.dtype
+        if isinstance(axis, Constant):
+            # In this case we can preserve more static shape info
+            static_axis = axis.data.item()
+            outputs = []
+            x_static_shape = list(x.type.shape)
+            for i in range(self.len_splits):
+                try:
+                    static_split_size = int(get_scalar_constant_value(splits[i]))
+                except NotScalarConstantError:
+                    static_split_size = None
+                static_out_shape = x_static_shape.copy()
+                static_out_shape[static_axis] = static_split_size
+                outputs.append(tensor(shape=tuple(static_out_shape), dtype=x_dtype))
+        else:
+            outputs = [
+                tensor(shape=(None,) * x.type.ndim, dtype=x_dtype)
+                for i in range(self.len_splits)
+            ]
 
         return Apply(self, inputs, outputs)
 
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 1af10e52b4..c38e08f607 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -21,6 +21,7 @@
     Join,
     MakeVector,
     ScalarFromTensor,
+    Split,
     TensorFromScalar,
     alloc,
     as_tensor,
@@ -616,6 +617,39 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
         return [node.inputs[0].dimshuffle(tuple(remain_dim))]
 
 
+@register_specialize("shape_unsafe")
+@node_rewriter(tracks=[Split])
+def split_to_subtensor(fgraph, node):
+    """Rewrite split(2)(x, 0) -> (x[:split_sizes[0]], x[split_sizes[0]:]).
+
+    This allows lifting the underlying split closer to the inputs, and increases fusion opportunities.
+    It automatically handles unused split outputs as well.
+
+    It only works for a constant axis.
+    """
+    x, axis, split_sizes = node.inputs
+
+    n_splits = node.op.len_splits
+    if n_splits <= 1:
+        return [x]
+
+    if not isinstance(axis, Constant):
+        return None
+
+    empty_slices = (slice(None),) * int(axis.data)
+    ys = []
+
+    end = split_sizes[0]
+    ys.append(x[(*empty_slices, slice(None, end))])
+    prev_start = end
+    for i in range(1, n_splits - 1):
+        end = prev_start + split_sizes[i]
+        ys.append(x[(*empty_slices, slice(prev_start, end))])
+        prev_start = end
+    ys.append(x[(*empty_slices, slice(prev_start, None))])
+    return ys
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
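Illustrative check (not part of the diff; the `pt.tensor`/`pt.split` calls below are only an assumed usage sketch): with a constant axis and constant split sizes, the new `make_node` branch should give the Split outputs static shapes rather than all-`None` shapes.

import pytensor.tensor as pt

x = pt.tensor("x", shape=(10, 5))
# Constant axis and constant split sizes, so the static-shape branch applies.
a, b = pt.split(x, [3, 7], 2, axis=0)
print(a.type.shape, b.type.shape)  # expected: (3, 5) (7, 5)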
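Similarly, a hedged sketch of how the new `split_to_subtensor` rewrite could be verified (the compilation mode and graph introspection here are assumptions, not part of the diff): since the rewrite is registered under "specialize" with the "shape_unsafe" tag, compiling in the default mode should replace the Split node with plain Subtensor slices when the axis is constant.

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Split

x = pt.tensor("x", shape=(10, 5))
a, b = pt.split(x, [3, 7], 2, axis=0)
fn = pytensor.function([x], [a, b])
# The compiled graph should no longer contain a Split node.
assert not any(isinstance(node.op, Split) for node in fn.maker.fgraph.toposort())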