-
Notifications
You must be signed in to change notification settings - Fork 146
Refactor and update QR Op #1518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor and update QR Op #1518
Conversation
5bc044c to
be949cd
Compare
71639ef to
a6d6c11
Compare
a6d6c11 to
112f6fd
Compare
Codecov Report❌ Patch coverage is Additional details and impacted files@@ Coverage Diff @@
## main #1518 +/- ##
==========================================
- Coverage 81.85% 81.45% -0.41%
==========================================
Files 230 232 +2
Lines 52522 53027 +505
Branches 9345 9422 +77
==========================================
+ Hits 42992 43192 +200
- Misses 7095 7390 +295
- Partials 2435 2445 +10
🚀 New features to boost your workflow:
|
| in_dtype = config.floatX if integer_input else dtype | ||
|
|
||
| @numba_njit(cache=False) | ||
| def qr(a): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to worry about a not being F-contiguous like other lapack/blas stuff?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, it's done in the lower level functions
Description
This PR updates the QRFull Op, adding static shape checking, infer_shape, and destroy_map. It also optimizes the perform method for the C backend, and tries to improve the gradient graph by checking static shapes (to avoid an ifelse).
I renamed it to QR, because I don't know what was Full about the old one. I also moved it from the numpy implementation to scipy, which gives us all the usual benefits (inplace, etc). I also went ahead and unpacked the scipy wrapper and used the LAPACK functions directly. This will give us better error handling (that is to say, none -- it should eventually return a matrix of NaN on failure) and some performance boost by caching workspace requirements.
Still a WIP, because it breaks everything by moving QR from nlinalg to slinalg. I thought about using this as an opportunity to finally eliminate this distinction and go to a more logical organization (linalg/decomposition/qr.py), but then decided against it for now. Needs discussion.
Related Issue
infer_shapemethod toQRFull#1511Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1518.org.readthedocs.build/en/1518/