Skip to content

Handles axis=None symbolically instead of within CumOp #1574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

Copilot
Copy link

@Copilot Copilot AI commented Aug 9, 2025

Fixes #1549

Summary

This PR refactors the cumulative sum/product operations to handle the special ravelling behavior of axis=None symbolically, making the code cleaner and more maintainable.

Problem

Previously, when cumsum(x, axis=None) or cumprod(x, axis=None) was called on a matrix, the CumOp class internally handled the flattening with special logic scattered throughout:

  • Special cases in make_node(), grad(), infer_shape(), and c_code()
  • Backend-specific axis=None handling in PyTorch, JAX, and Numba dispatchers
  • Complex conditional logic that made the code harder to maintain

Solution

The refactoring separates concerns by handling the ravelling symbolically:

Before:

cumsum(x, axis=None) → CumOp(axis=None)(x)

After:

cumsum(x, axis=None) → CumOp(axis=0)(flatten(x))

Key Changes

  1. Modified cumsum/cumprod functions: When axis=None, explicitly flatten the input first, then apply cumsum/cumprod with axis=0

  2. Simplified CumOp class:

    • Constructor now requires an integer axis (no more axis=None)
    • Removed all special axis=None handling from make_node, grad, infer_shape, c_code
    • Simplified vectorization logic
  3. Updated backend dispatchers: Removed axis=None handling from PyTorch, JAX, and Numba backends since CumOp now always receives integer axes

  4. Updated tests: Modified test cases to use the new simplified API while maintaining coverage of axis=None behavior through the public functions

Benefits

  • Separation of concerns: Flattening is explicit in the symbolic graph
  • Simpler implementation: CumOp only handles specific integer axes
  • Better maintainability: No more special cases scattered throughout the codebase
  • Clearer symbolic representation: Graph explicitly shows flatten → cumsum chain
  • Optimization friendly: Graph rewrites can now optimize flatten and cumsum operations independently

Backward Compatibility

✅ All existing user code continues to work unchanged
✅ Same computational results for all operations
✅ Same gradient behavior
✅ Same output shapes and types

The refactoring only changes the internal implementation - the public API remains identical.

Example

import pytensor.tensor as pt
from pytensor.tensor.extra_ops import cumsum

x = pt.matrix('x')

# This still works exactly the same from user perspective
result = cumsum(x, axis=None)

# But internally now creates: cumsum(flatten(x), axis=0) 
# instead of the old: CumOp(axis=None)(x)

This makes the code more maintainable while preserving all existing functionality.


💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.

@Copilot Copilot AI changed the title [WIP] Fix https://github.com/pymc-devs/pytensor/issues/1549 Handle special ravelling behavior of CumOp symbolically Aug 9, 2025
Copilot finished work on behalf of ricardoV94 August 9, 2025 03:50
@Copilot Copilot AI requested a review from ricardoV94 August 9, 2025 03:50
@ricardoV94
Copy link
Member

@copilot address review comments

Copy link
Author

Copilot AI commented Aug 9, 2025

@copilot address review comments

All review comments have been addressed in commit d3bfb15:

  1. Updated cumsum/cumprod to use .ravel() and simplified to one line
  2. Added axis normalization in helper functions and simplified negative axis handling
  3. Updated test to use self.op_class

@Copilot Copilot AI requested a review from ricardoV94 August 9, 2025 04:11
Copilot finished work on behalf of ricardoV94 August 9, 2025 04:11
… eliminate code duplication

Co-authored-by: ricardoV94 <[email protected]>
@Copilot Copilot AI requested a review from ricardoV94 August 9, 2025 04:33
Copilot finished work on behalf of ricardoV94 August 9, 2025 04:33
@ricardoV94 ricardoV94 changed the title Handle special ravelling behavior of CumOp symbolically CumOp no longer accepts axis=None Aug 9, 2025
@ricardoV94 ricardoV94 changed the title CumOp no longer accepts axis=None Handles axis=None symbolically instead of within CumOp Aug 9, 2025
@ricardoV94 ricardoV94 marked this pull request as ready for review August 9, 2025 14:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Handle special ravelling behavior of CumOp symbolically
2 participants