-
Notifications
You must be signed in to change notification settings - Fork 149
Add LAPACK overloads for all variants of solve in Numba backend #1146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1146 +/- ##
==========================================
- Coverage 82.27% 82.01% -0.27%
==========================================
Files 186 187 +1
Lines 48066 48467 +401
Branches 8633 8669 +36
==========================================
+ Hits 39546 39749 +203
- Misses 6360 6554 +194
- Partials 2160 2164 +4
🚀 New features to boost your workflow:
|
fe97e5d to
86dd9cb
Compare
|
This is pretty close. I just need some help with the destroy_map stuff on |
ricardoV94
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean we were not doing inplace? It was already implemented for the default backend for Solve and Cholesky: ed6ca16#diff-f5e942a523e3c0402aa63824184c381f5a963867125a21f119e2dab97a110d95
Just not in Numba yet. You shouldn't change anything in how the Ops are created, it's handled by the inplace rewrites. You may need to explicitly trigger them in the numba tests (the default numba mode in compare_py_and_numba does not include them), but I am not sure that utility actually works with inplace stuff
edd803d to
86c5539
Compare
ee3e337 to
17e7247
Compare
pt.linalg.solve |
Missing dispatch for JAX (and possibly torch) for the LU Op? |
17e7247 to
c4a0dd2
Compare
|
Squash + Merge I assume? |
|
Also |
f200b88 to
08e3b97
Compare
Description
Goal of this PR is to give numba mode full coverage of
scipy.linalg.solveoptions. We currently only supportassume_a = "gen". If users select a different solver, they get incorrect results (see #422 ). This PR should fix that bug, plus add:overwrite_ain numba modeoverwrite_bin numba modetransposedargument (all modes)lu_factorandlu_solveOps(all modes)assume_a = "sym"andassume_a = "pos"in numba modecho_solvein numba modeWe get the
lu_factorandlu_solveOps kind of "for free" because I'm adding overloads fordgetrsanddgetrf. We just have to write the Ops and do the JAX dispatch. JVP forlu_factoris here. Help wanted. I might decide that these Ops are out of scope for this PR and open another one.Related Issue
pt.linalg.solvereturns incorrect results whenmode = "NUMBA"#422Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1146.org.readthedocs.build/en/1146/