Skip to content

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Dec 31, 2024

Description

Goal of this PR is to give numba mode full coverage of scipy.linalg.solve options. We currently only support assume_a = "gen". If users select a different solver, they get incorrect results (see #422 ). This PR should fix that bug, plus add:

  • support for overwrite_a in numba mode
  • support for overwrite_b in numba mode
  • support for transposed argument (all modes)
  • lu_factor and lu_solve Ops (all modes)
  • support for assume_a = "sym" and assume_a = "pos" in numba mode
  • support for cho_solve in numba mode

We get the lu_factor and lu_solve Ops kind of "for free" because I'm adding overloads for dgetrs and dgetrf. We just have to write the Ops and do the JAX dispatch. JVP for lu_factor is here. Help wanted. I might decide that these Ops are out of scope for this PR and open another one.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1146.org.readthedocs.build/en/1146/

@jessegrabowski jessegrabowski added bug Something isn't working enhancement New feature or request numba SciPy compatibility linalg Linear algebra labels Dec 31, 2024
@jessegrabowski jessegrabowski marked this pull request as ready for review December 31, 2024 15:24
@jessegrabowski jessegrabowski requested review from aseyboldt and ricardoV94 and removed request for aseyboldt December 31, 2024 15:24
Copy link

codecov bot commented Dec 31, 2024

Codecov Report

Attention: Patch coverage is 53.52349% with 277 lines in your changes missing coverage. Please review.

Project coverage is 82.01%. Comparing base (4fa9bb8) to head (08e3b97).
Report is 185 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/link/numba/dispatch/slinalg.py 42.65% 233 Missing and 9 partials ⚠️
pytensor/link/numba/dispatch/_LAPACK.py 77.27% 29 Missing and 6 partials ⚠️
Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1146      +/-   ##
==========================================
- Coverage   82.27%   82.01%   -0.27%     
==========================================
  Files         186      187       +1     
  Lines       48066    48467     +401     
  Branches     8633     8669      +36     
==========================================
+ Hits        39546    39749     +203     
- Misses       6360     6554     +194     
- Partials     2160     2164       +4     
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/basic.py 79.18% <ø> (-1.90%) ⬇️
pytensor/tensor/slinalg.py 93.52% <100.00%> (ø)
pytensor/link/numba/dispatch/_LAPACK.py 77.27% <77.27%> (ø)
pytensor/link/numba/dispatch/slinalg.py 44.77% <42.65%> (-8.03%) ⬇️

... and 20 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@jessegrabowski
Copy link
Member Author

This is pretty close. I just need some help with the destroy_map stuff on Solve. I guess this wasn't being before? The code is a bit hard to follow with all the subclassing.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean we were not doing inplace? It was already implemented for the default backend for Solve and Cholesky: ed6ca16#diff-f5e942a523e3c0402aa63824184c381f5a963867125a21f119e2dab97a110d95

Just not in Numba yet. You shouldn't change anything in how the Ops are created, it's handled by the inplace rewrites. You may need to explicitly trigger them in the numba tests (the default numba mode in compare_py_and_numba does not include them), but I am not sure that utility actually works with inplace stuff

@jessegrabowski jessegrabowski force-pushed the numba-solve branch 3 times, most recently from edd803d to 86c5539 Compare February 11, 2025 15:54
@jessegrabowski jessegrabowski force-pushed the numba-solve branch 2 times, most recently from ee3e337 to 17e7247 Compare February 16, 2025 16:32
@ricardoV94 ricardoV94 changed the title Add LAPACK overloads for all variants of pt.linalg.solve Add LAPACK overloads for all variants of solve in Numba backend Feb 16, 2025
@ricardoV94
Copy link
Member

Missing dispatch for JAX (and possibly torch) for the LU Op?

@ricardoV94
Copy link
Member

Squash + Merge I assume?

@ricardoV94
Copy link
Member

Also test_solve_correctness failing on float32

@ricardoV94 ricardoV94 merged commit bbe663d into pymc-devs:main Feb 17, 2025
63 of 64 checks passed
@jessegrabowski jessegrabowski deleted the numba-solve branch September 30, 2025 02:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working enhancement New feature or request linalg Linear algebra numba SciPy compatibility

Projects

None yet

Development

Successfully merging this pull request may close these issues.

BUG: pt.linalg.solve returns incorrect results when mode = "NUMBA"

2 participants