Skip to content

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Apr 3, 2025

This makes Split (which shows up in the gradient of Join) much faster as it doesn't do useless copies.

I see a speedup of ~10x, obviously the comparison would scale with the ammount of copying that is now avoided

import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.matrix("x", shape=(100, 200))
ys = pt.split(x, [10]*10, 10)
profile = None
fn = pytensor.function(
    [pytensor.In(x, borrow=True)], 
    [pytensor.Out(y, borrow=True) for y in ys],
    trust_input=True,
    profile=profile,
)
fn.dprint()
x_test = np.zeros((100, 200))
%timeit fn(x_test)
%timeit fn(x_test)
%timeit fn(x_test)
if profile:
    fn.profile.summary()

Also added static output shape and cleanup other methods of Split


📚 Documentation preview 📚: https://pytensor--1343.org.readthedocs.build/en/1343/

@ricardoV94 ricardoV94 force-pushed the faster_split branch 2 times, most recently from aa9b281 to 4331934 Compare April 3, 2025 11:59
@ricardoV94 ricardoV94 mentioned this pull request Apr 3, 2025
@ricardoV94 ricardoV94 changed the title Make Split C-impl return a view Propagate static output shapes in Split and avoid copy in C-impl Apr 3, 2025
Copy link

codecov bot commented Apr 3, 2025

Codecov Report

Attention: Patch coverage is 88.57143% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.01%. Comparing base (0f5da80) to head (f1d6ba0).
Report is 162 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/basic.py 88.23% 3 Missing and 1 partial ⚠️
Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1343   +/-   ##
=======================================
  Coverage   82.01%   82.01%           
=======================================
  Files         203      203           
  Lines       48798    48813   +15     
  Branches     8685     8688    +3     
=======================================
+ Hits        40022    40035   +13     
- Misses       6625     6627    +2     
  Partials     2151     2151           
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/tensor_basic.py 88.59% <100.00%> (-0.20%) ⬇️
pytensor/tensor/basic.py 91.13% <88.23%> (-0.04%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some rinky-dink feedback. I'm not qualified to comment on the C code, but I tried my best.

@ricardoV94 ricardoV94 merged commit 4e59f21 into pymc-devs:main Apr 8, 2025
72 of 73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants