-
Notifications
You must be signed in to change notification settings - Fork 155
Determinant of factorized matrices #1785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| match core_op: | ||
| case Cholesky(): | ||
| L = client.outputs[0] | ||
| new_det = matrix_diagonal_product(L) ** 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: Add the positive tag here.
Possibly also rewrite for log(x ** 2) -> log(x) * 2, when we know x is positive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This all seemed out of scope so I didn't address it. Positive tagging isn't used systematically and I'd rather we have a plan for doing that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed
|
Still missing tests and a couple other things. I wanted to first get #1786 in (which I think is still not ready either) |
|
I approved to poke you towards finishing it :) |
a9d058c to
de4e763
Compare
|
I used claude to add tests. I think it did a reasonable job. |
|
I also rebased, so if you want to keep working on it be aware of that. |
| assert_equal_computations([rewritten], [expected]) | ||
|
|
||
|
|
||
| def test_local_log_prod_to_sum_log_no_rewrite(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: this can be part of the first test above, as another parametrization where the expected graph is the same as the original one
| ] | ||
| assert len(det_nodes_no_opt) == 1 | ||
|
|
||
| fn_opt = function( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where these tests hard to pattern match with the rewrite_graph -> assert_equal_computations mold?
| assert_equal_computations([rewritten], [expected]) | ||
|
|
||
|
|
||
| def test_det_of_factorized_matrix_no_rewrite_without_abs(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same nit about being a parametrization above. I am not sure myself I prefer that way, just raising if you have a clear preference
|
|
||
| # Test graph that only has det_X | ||
| f = function([X], [det_X]) | ||
| f.dprint() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably added by me
| f.dprint() |
The old
local_det_cholrewrite is extended to cover more cases of a matrix that is factorized elsewhere, not just with Cholesky, but also LU, LUFactor, or SVD, QR (the latter two only if the sign isn't needed)A new rewrite is added for the determinant of a factorization itself. The logic is slightly different, for instance det(LUFactor) is non-sensical, and the determinant for some outputs of SVD/ QR can always be computed even if the determinant of the whole factorization cannot.
Also extended the rewrite of log(prod(x)) to sum(log(x)), which should increase the stability of many of these when we want the log determinant (or log(abs(determintant))).
Still missing tests
Closes #1679
Related to #573