Skip to content

Conversation

AllenDowney
Copy link
Contributor

@AllenDowney AllenDowney commented Jun 6, 2025

Add expand_dims operation for labeled tensors

This PR adds support for the expand_dims operation in PyTensor's labeled tensor system, allowing users to add new dimensions to labeled tensors with explicit dimension names.

Key Features

  • New ExpandDims operation that adds a new dimension to an XTensorVariable
  • Support for both static and symbolic dimension sizes
  • Automatic broadcasting when size > 1
  • Integration with existing tensor operations
  • Full compatibility with xarray's expand_dims behavior

Implementation Details

The implementation includes:

  1. New ExpandDims class in pytensor/xtensor/shape.py that handles:

    • Adding new dimensions with specified names
    • Support for both static and symbolic sizes
    • Shape inference and validation
  2. Rewriting rule in pytensor/xtensor/rewriting/shape.py that:

    • Converts labeled tensor operations to standard tensor operations
    • Handles broadcasting when needed
    • Validates symbolic sizes
  3. Comprehensive test suite in tests/xtensor/test_shape.py covering:

    • Basic dimension expansion
    • Static and symbolic sizes
    • Error cases and edge cases
    • Compatibility with xarray operations
    • Integration with other labeled tensor operations

Usage Example

import pytensor.tensor as pt
from pytensor.xtensor import xtensor

# Create a labeled tensor
x = xtensor("x", dims=("city",), shape=(3,))

# Add a new dimension
y = expand_dims(x, "country")  # Adds a new dimension of size 1
z = expand_dims(x, "country", size=4)  # Adds a new dimension of size 4

Testing

The implementation includes extensive tests that verify:

  • Correct behavior with various input shapes
  • Proper handling of symbolic sizes
  • Error cases (invalid dimensions, sizes, etc.)
  • Compatibility with xarray's expand_dims
  • Integration with other labeled tensor operations

📚 Documentation preview 📚: https://pytensor--1447.org.readthedocs.build/en/1447/

@AllenDowney
Copy link
Contributor Author

@ricardoV94 I think this is a ready for a look.

@ricardoV94 ricardoV94 force-pushed the labeled_tensors branch 4 times, most recently from 7da9935 to 7b8877b Compare June 6, 2025 16:09
@AllenDowney
Copy link
Contributor Author

Closing because it was based on the wrong branch

@AllenDowney AllenDowney closed this Jun 6, 2025
@AllenDowney AllenDowney deleted the add_expand_dims_new branch June 6, 2025 18:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants