
Conversation

@ricardoV94 (Member) commented Mar 31, 2025

Split shows up in the gradient graph of Join. It's just a fancy sequence of slice operations; np.split is just a thin wrapper that does this:

https://github.com/numpy/numpy/blob/e7a123b2d3eca9897843791dd698c1803d9a39c2/numpy/lib/_shape_base_impl.py#L789-L796
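
In spirit (a minimal sketch, not the actual NumPy source; split_via_slices is a made-up helper name):

import numpy as np

def split_via_slices(ary, indices):
    # Slice at each division point along axis 0; every piece is a view
    points = [0, *indices, len(ary)]
    return [ary[start:stop] for start, stop in zip(points[:-1], points[1:])]

x = np.arange(10)
expected = np.split(x, [3, 7])  # x[:3], x[3:7], x[7:]
assert all(np.array_equal(a, b) for a, b in zip(split_via_slices(x, [3, 7]), expected))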

If it were not for the possibility of a dynamic axis (which Join also supports), we wouldn't need Split at all.

We may still want it for second-order derivatives of Join graphs. The gradient of Split is cleaner than the eager gradient of multiple subtensors: it's just the reverse Join of the output gradients.
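
As a rough illustration (the exact Ops shown by dprint are an assumption on my part), differentiating through a split yields a graph that joins the output gradients back together:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(6,))
a, b = pt.split(x, (2, 4), 2)
# Weight the pieces differently so the gradient isn't trivially all ones
cost = a.sum() + 2 * b.sum()
g = pytensor.grad(cost, x)
pytensor.dprint(g)  # expect the output gradients joined back along axis 0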

I added a specialization rewrite that converts Split into the equivalent Subtensor graph. I see speedups in all backends.
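
Schematically (a sketch of the idea in plain NumPy, not the rewrite's actual implementation), the split sizes become cumulative offsets, one slice per output:

import numpy as np

sizes = (250, 250, 500)
offsets = np.cumsum((0, *sizes))  # array([0, 250, 500, 1000])
x = np.arange(1000)
ys = [x[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
assert all(y.base is x for y in ys)  # every piece is a view, no copies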

The C backend wasn't able to return views of the input, so this rewrite avoids as many copies as there are splits:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_mode

x = pt.vector("x", shape=(1000,), dtype=int)
# Split the length-1000 vector into chunks of sizes 250, 250, and 500
ys = pt.split(x, (250, 250, 500), 3)
fn = pytensor.function(
    # Avoid deepcopies of inputs and outputs
    [pytensor.In(x, borrow=True)],
    [pytensor.Out(y, borrow=True) for y in ys],
    mode=get_mode("NUMBA"),  # .excluding("split_to_subtensor") to compare against the unrewritten graph
    trust_input=True,
)
fn.dprint(print_view_map=True)

x_test = np.arange(1000)
fn(x_test)  # warm-up call so compilation isn't timed
%timeit fn(x_test)
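
As a quick sanity check (written in plain NumPy; not part of the benchmark), the compiled outputs should match direct slices of the input:

res = fn(x_test)
for y, (start, stop) in zip(res, ((0, 250), (250, 500), (500, 1000))):
    np.testing.assert_array_equal(y, x_test[start:stop])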

📚 Documentation preview 📚: https://pytensor--1334.org.readthedocs.build/en/1334/

@ricardoV94 changed the title from "Split tweaks" to "Specialize away Split" on Mar 31, 2025
@ricardoV94 marked this pull request as a draft on April 2, 2025 17:15
@ricardoV94 (Member, Author) commented:

Closed in favor of #1343

@ricardoV94 closed this on Apr 3, 2025