-
Notifications
You must be signed in to change notification settings - Fork 146
Add LAPACK overloads for all variants of solve in Numba backend #1146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1146 +/- ##
==========================================
- Coverage 82.27% 82.01% -0.27%
==========================================
Files 186 187 +1
Lines 48066 48467 +401
Branches 8633 8669 +36
==========================================
+ Hits 39546 39749 +203
- Misses 6360 6554 +194
- Partials 2160 2164 +4
🚀 New features to boost your workflow:
|
fe97e5d
to
86dd9cb
Compare
This is pretty close. I just need some help with the destroy_map stuff on |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you mean we were not doing inplace? It was already implemented for the default backend for Solve and Cholesky: ed6ca16#diff-f5e942a523e3c0402aa63824184c381f5a963867125a21f119e2dab97a110d95
Just not in Numba yet. You shouldn't change anything in how the Ops are created, it's handled by the inplace rewrites. You may need to explicitly trigger them in the numba tests (the default numba mode in compare_py_and_numba
does not include them), but I am not sure that utility actually works with inplace stuff
edd803d
to
86c5539
Compare
ee3e337
to
17e7247
Compare
pt.linalg.solve
Missing dispatch for JAX (and possibly torch) for the LU Op? |
17e7247
to
c4a0dd2
Compare
Squash + Merge I assume? |
Also |
f200b88
to
08e3b97
Compare
Description
Goal of this PR is to give numba mode full coverage of
scipy.linalg.solve
options. We currently only supportassume_a = "gen"
. If users select a different solver, they get incorrect results (see #422 ). This PR should fix that bug, plus add:overwrite_a
in numba modeoverwrite_b
in numba modetransposed
argument (all modes)lu_factor
andlu_solve
Ops
(all modes)assume_a = "sym"
andassume_a = "pos"
in numba modecho_solve
in numba modeWe get the
lu_factor
andlu_solve
Ops kind of "for free" because I'm adding overloads fordgetrs
anddgetrf
. We just have to write the Ops and do the JAX dispatch. JVP forlu_factor
is here. Help wanted. I might decide that these Ops are out of scope for this PR and open another one.Related Issue
pt.linalg.solve
returns incorrect results whenmode = "NUMBA"
#422Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1146.org.readthedocs.build/en/1146/