Skip to content

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jun 21, 2025

Description

Adds a JAX dispatch for the CholeskySolve Op. Nobody ever uses this function (although it's quite nice), so nobody cared that we didn't have this. Now it matters because of the rewrites introduced in #1461. Graphs that benefit from this rewrite (basically any PyMC model with an MvNormal....) will error in JAX mode because the Op is missing.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1491.org.readthedocs.build/en/1491/

Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR adds a new JAX dispatch implementation for the CholeskySolve Op to support models that use Cholesky-based solutions in JAX mode. The key changes include:

  • Adding tests for the JAX implementation of cho_solve in tests/link/jax/test_slinalg.py.
  • Updating the cho_solve function signature and docstring in pytensor/tensor/slinalg.py.
  • Implementing and registering a new JAX dispatch function for CholeskySolve in pytensor/link/jax/dispatch/slinalg.py.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
tests/link/jax/test_slinalg.py Added new tests for the CholeskySolve dispatch functionality.
pytensor/tensor/slinalg.py Updated cho_solve function signature and documentation.
pytensor/link/jax/dispatch/slinalg.py Added JAX dispatch for the CholeskySolve Op.
Comments suppressed due to low confidence (1)

tests/link/jax/test_slinalg.py:338

  • [nitpick] The test function is named 'test_jax_chosolve', but the operator and primary function name is 'cho_solve'. For clarity and consistency, consider renaming the test to 'test_jax_cho_solve'.
@pytest.mark.parametrize("b_shape, lower", [((5,), True), ((5, 5), False)])

@ricardoV94 ricardoV94 added the enhancement New feature or request label Jun 21, 2025
Copy link

codecov bot commented Jun 21, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (main@d3bbc20). Learn more about missing BASE report.
⚠️ Report is 121 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1491   +/-   ##
=======================================
  Coverage        ?   82.01%           
=======================================
  Files           ?      214           
  Lines           ?    50439           
  Branches        ?     8907           
=======================================
  Hits            ?    41370           
  Misses          ?     6861           
  Partials        ?     2208           
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/slinalg.py 94.50% <100.00%> (ø)
pytensor/tensor/slinalg.py 93.18% <100.00%> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@jessegrabowski jessegrabowski merged commit f72d7e5 into pymc-devs:main Jun 21, 2025
73 checks passed
@jessegrabowski jessegrabowski deleted the jax-chosolve branch September 30, 2025 02:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request jax linalg Linear algebra Op implementation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add jax dispatch for ChoSolve

2 participants