Skip to content

Conversation

AllenDowney
Copy link
Contributor

@AllenDowney AllenDowney commented Jun 6, 2025

Add dot operation to xtensor module

This PR adds support for the dot product operation in the xtensor module. The implementation includes:

New dot method

  • Added a .dot() method to XTensorVariable in pytensor/xtensor/type.py to provide a consistent interface for dot operations, similar to other math functions.

Rewrite rule for dot

  • Implemented a rewrite rule in pytensor/xtensor/rewriting/math.py that converts the XDot operation to a tensor-based dot operation using tensordot. This rule handles dimension alignment and contraction correctly.

Import of math rewriting module

  • Updated pytensor/xtensor/rewriting/__init__.py to import the math rewriting module, ensuring that the dot rewrite rule is registered and available during the rewrite pass.

Unit tests

  • Added a new test function test_dot() in tests/xtensor/test_math.py to verify the basic functionality of the dot operation, including matrix-matrix and matrix-vector dot products, proper dimension handling, and shape validation.

These changes ensure that the xtensor module now supports dot operations, maintaining consistency with other math functions and enabling proper dimension handling for tensor contractions.


📚 Documentation preview 📚: https://pytensor--1448.org.readthedocs.build/en/1448/

Adding rewrite

Lint
@AllenDowney
Copy link
Contributor Author

Closing because it was based on the wrong branch

@AllenDowney AllenDowney closed this Jun 6, 2025
@AllenDowney AllenDowney deleted the add_xtensor_dot branch June 6, 2025 18:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants