Rewrite solves involving kron to eliminate kron #1559
@@ -588,9 +588,11 @@ def svd_uv_merge(fgraph, node):
@node_rewriter([Blockwise])
def rewrite_inv_inv(fgraph, node):
    """
    This rewrite takes advantage of the fact that if there are two consecutive inverse operations (inv(inv(input))),
    we get back our original input without having to compute the inverse even once.

    Here, we check for direct inverse operations (inv/pinv) and allow for any combination of these "inverse" nodes to
    be simply rewritten.

    Parameters
    ----------
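
For context, a minimal usage sketch (not part of the diff) of the behaviour this docstring describes:

```python
# Minimal sketch of the inv(inv(x)) elimination described above; assumes the
# rewrite is registered in the default rewrite mode.
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.linalg.inv(pt.linalg.inv(x))

fn = pytensor.function([x], y)
# After rewriting, the compiled graph should return x essentially unchanged,
# with no matrix-inverse ops left in fn.maker.fgraph.
```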
@@ -855,6 +857,83 @@ def rewrite_det_kronecker(fgraph, node):
    return [det_final]

@register_canonicalize("shape_unsafe")
@register_stabilize("shape_unsafe")
@node_rewriter([Blockwise])
def rewrite_solve_kron_to_solve(fgraph, node):
    """
    Given a linear system of the form:

    .. math::

        (A \\otimes B) x = y

    Define :math:`\\text{vec}(x)` as a column-wise raveling operation (``x.reshape(-1, order='F')`` in code). Further,
    define :math:`y = \\text{vec}(Y)`. Then the above expression can be rewritten as:

    .. math::

        x = \\text{vec}(B^{-1} Y A^{-T})

    This eliminates the Kronecker product from the expression.
    """

    if not isinstance(node.op.core_op, SolveBase):
        return

    solve_op = node.op
    props_dict = solve_op.core_op._props_dict()
    b_ndim = props_dict["b_ndim"]

    A, b = node.inputs

    if not A.owner or not (
        isinstance(A.owner.op, KroneckerProduct)
        or isinstance(A.owner.op, Blockwise)
        and isinstance(A.owner.op.core_op, KroneckerProduct)
    ):

Review comment (on lines +890 to +894): Probably broke the parenthesis, but you get the idea. Negate the whole condition that is required and group the …
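The suggested change itself did not survive in the source, so the following is only a hypothetical sketch of the negated, explicitly grouped guard the reviewer seems to be asking for (the name `A_is_kron` is mine):

```python
# Hypothetical restructuring: name the condition, parenthesize the and/or
# grouping explicitly, and bail out early if A is not a Kronecker product.
A_is_kron = A.owner is not None and (
    isinstance(A.owner.op, KroneckerProduct)
    or (
        isinstance(A.owner.op, Blockwise)
        and isinstance(A.owner.op.core_op, KroneckerProduct)
    )
)
if not A_is_kron:
    return None
```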

        return

    x1, x2 = A.owner.inputs

    # If x1 and x2 have statically known core shapes, check that they are square. If not, the rewrite will be invalid.
    # We will proceed if they are unknown, but this makes the rewrite shape unsafe.
    x1_core_shapes = x1.type.shape[-2:]
    x2_core_shapes = x2.type.shape[-2:]

    if (
        all(shape is not None for shape in x1_core_shapes)
        and x1_core_shapes[-1] != x1_core_shapes[-2]
    ) or (
        all(shape is not None for shape in x2_core_shapes)
        and x2_core_shapes[-1] != x2_core_shapes[-2]
    ):
        return None

    m, n = x1.shape[-2], x2.shape[-2]
    batch_shapes = x1.shape[:-2]

Review comment: x1/x2 batch shapes could broadcast in blockwise
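A hedged sketch of what that comment might imply, assuming `pt.broadcast_shape` (with `arrays_are_shapes=True`) is available for broadcasting symbolic shapes:

```python
# Hypothetical fix: under Blockwise, the batch dimensions of x1 and x2 may
# broadcast against each other, so take the broadcast of both batch shapes
# rather than x1's batch shape alone.
batch_shapes = pt.broadcast_shape(
    x1.shape[:-2], x2.shape[:-2], arrays_are_shapes=True
)
```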

    if b_ndim == 1:
        # The rewritten expression will reshape B to be 2d. The easiest way to handle this is to just make a new
        # solve node with b_ndim = 2
        props_dict["b_ndim"] = 2
        new_solve_op = Blockwise(type(solve_op.core_op)(**props_dict))

        B = b.reshape((*batch_shapes, m, n))
        res = new_solve_op(x1, new_solve_op(x2, B.mT).mT).reshape((*batch_shapes, -1))

    else:
        # If b_ndim is 2, we need to keep track of the original right-most dimension of b as an additional
        # batch dimension
        b_batch = b.shape[-1]
        B = pt.moveaxis(b, -1, 0).reshape((b_batch, *batch_shapes, m, n))

        res = pt.moveaxis(solve_op(x1, solve_op(x2, B.mT).mT), 0, -1).reshape(
            (*batch_shapes, -1, b_batch)
        )

    return [res]

Review comment: missing copy_stack_trace
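A possible way to address that comment, sketched under the assumption that the standard rewriting helper is in scope:

```python
# Hypothetical addition just before returning: copy the stack trace of the
# original solve output onto the rewritten expression so user-facing debug
# information survives the rewrite.
from pytensor.graph.rewriting.basic import copy_stack_trace

copy_stack_trace(node.outputs[0], res)
return [res]
```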

@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
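
For reference, a standalone numpy sketch (not part of the diff) verifying the vec-trick identity that the new rewrite's docstring relies on:

```python
# Standalone numpy check of the identity behind the rewrite:
#   (A kron B) vec(X) = vec(B X A^T),  with vec = column-major ravel,
# so the solution of (A kron B) x = vec(Y) is x = vec(B^{-1} Y A^{-T}).
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))   # m x m
B = rng.normal(size=(2, 2))   # n x n
Y = rng.normal(size=(2, 3))   # n x m, so vec(Y) has length m * n
y = Y.reshape(-1, order="F")  # column-major ravel, i.e. vec(Y)

x_direct = np.linalg.solve(np.kron(A, B), y)
X = np.linalg.solve(B, Y) @ np.linalg.inv(A).T  # B^{-1} Y A^{-T}
np.testing.assert_allclose(x_direct, X.reshape(-1, order="F"), rtol=1e-8, atol=1e-8)
```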
@@ -828,6 +828,156 @@ def test_slogdet_kronecker_rewrite():
    )

def count_kron_ops(fgraph):
    return sum(
        [
            isinstance(node.op, KroneckerProduct)
            or (
                isinstance(node.op, Blockwise)
                and isinstance(node.op.core_op, KroneckerProduct)
            )
            for node in fgraph.apply_nodes
        ]
    )

@pytest.mark.parametrize("add_batch", [True, False], ids=["batched", "not_batched"]) | ||
@pytest.mark.parametrize("b_ndim", [1, 2], ids=["b_ndim_1", "b_ndim_2"]) | ||
@pytest.mark.parametrize( | ||
"solve_op, solve_kwargs", | ||
[ | ||
(pt.linalg.solve, {"assume_a": "gen"}), | ||
(pt.linalg.solve, {"assume_a": "pos"}), | ||
(pt.linalg.solve, {"assume_a": "upper triangular"}), | ||
], | ||
ids=["general", "positive definite", "triangular"], | ||
) | ||
def test_rewrite_solve_kron_to_solve(add_batch, b_ndim, solve_op, solve_kwargs): | ||
# A and B have different shapes to make the test more interesting, but both need to be square matrices, otherwise | ||
# the rewrite is invalid. | ||
a_shape = (3, 3) if not add_batch else (2, 3, 3) | ||
b_shape = (2, 2) if not add_batch else (2, 2, 2) | ||
A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape) | ||
|
||
    m, n = a_shape[-2], b_shape[-2]
    y_shape = (m * n,)
    if b_ndim == 2:
        y_shape = (m * n, 3)
    if add_batch:
        y_shape = (2, *y_shape)

    y = pt.tensor("y", shape=y_shape)
    C = pt.vectorize(pt.linalg.kron, "(i,j),(k,l)->(m,n)")(A, B)

    x = solve_op(C, y, **solve_kwargs, b_ndim=b_ndim)

    fn_expected = pytensor.function(
        [A, B, y], x, mode=get_default_mode().excluding("rewrite_solve_kron_to_solve")
    )
    assert count_kron_ops(fn_expected.maker.fgraph) == 1

    fn = pytensor.function([A, B, y], x)
    assert count_kron_ops(fn.maker.fgraph) == 0

    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    a_val = rng.normal(size=a_shape)
    b_val = rng.normal(size=b_shape)
    y_val = rng.normal(size=y_shape)

    if solve_kwargs["assume_a"] == "pos":
        a_val = a_val @ np.moveaxis(a_val, -2, -1)
        b_val = b_val @ np.moveaxis(b_val, -2, -1)
    elif solve_kwargs["assume_a"] == "upper triangular":
        a_idx = np.tril_indices(n=a_shape[-2], m=a_shape[-1], k=-1)
        b_idx = np.tril_indices(n=b_shape[-2], m=b_shape[-1], k=-1)

        if len(a_shape) > 2:
            a_idx = (slice(None, None), *a_idx)
        if len(b_shape) > 2:
            b_idx = (slice(None, None), *b_idx)

        a_val[a_idx] = 0
        b_val[b_idx] = 0

    a_val = a_val.astype(config.floatX)
    b_val = b_val.astype(config.floatX)
    y_val = y_val.astype(config.floatX)

    expected = fn_expected(a_val, b_val, y_val)
    result = fn(a_val, b_val, y_val)

    if config.floatX == "float64":
        tol = 1e-8
    elif config.floatX == "float32" and solve_kwargs["assume_a"] != "pos":
        tol = 1e-4
    else:
        # The tolerance needs to be quite loose for the assume_a = "pos" case to pass in float32 mode. I don't have a
        # good theory of why. Skipping this case would also be an option.
        tol = 1e-2

    np.testing.assert_allclose(
        expected,
        result,
        atol=tol,
        rtol=tol,
    )


def test_rewrite_solve_kron_to_solve_not_applied():
    # Check that the rewrite is not applied when the component matrices of the kron have static shapes and are
    # not square
    A = pt.tensor("A", shape=(3, 2))
    B = pt.tensor("B", shape=(2, 3))
    C = pt.linalg.kron(A, B)

    y = pt.vector("y", shape=(6,))
    x = pt.linalg.solve(C, y)

    fn = pytensor.function([A, B, y], x)

    assert count_kron_ops(fn.maker.fgraph) == 1

    # If the core shapes are not statically known, the rewrite should always be applied (it is shape_unsafe)
    A = pt.tensor("A", shape=(3, None, None))
    B = pt.tensor("B", shape=(3, None, None))

Review thread (on lines +939 to +941):

Reviewer: Back to the previous comment, is the previous …

Reply: C is valid, because C is square. The "problem" is that we can kron together two non-square matrices and end up with a square one (e.g. …). This is another case where we really really wish we had a tag for "square matrix", without having to commit to shapes.

Reply: Wiki seems to suggest Kron(A, B) is only invertible if both A and B are invertible, so you couldn't solve C in the first place if this wasn't the case? Is that correct? In that case it's fine to have the rewrite when the shapes are unknown (perhaps add a comment?). Otherwise it's not.

Reply: The theory looks right. The only issue I guess is that currently, you won't get an error if you have an "invalid" graph like:

```python
A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))
A_pt, B_pt = pt.dmatrices('A', 'B')
y_pt = pt.dvector('y')
C = pt.linalg.kron(A_pt, B_pt)
x = pt.linalg.solve(C, y_pt)
fn = pytensor.function([A_pt, B_pt, y_pt], x)
```

You get a warning about numerical instability, but it gives you some numbers. Obviously these numbers are just nonsense, but it doesn't error. After the rewrite, you will get a shape error, which might be very surprising for someone who isn't providing a valid graph in the first place?

Reply: Solve of C doesn't raise for "singular matrix"?
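A quick numpy check (mine, not part of the thread) of the rank argument above: a square kron built from two non-square factors is always singular, since rank(kron(A, B)) = rank(A) · rank(B).

```python
# Numpy check of the invertibility point discussed above: kron of a 4x3 and a
# 3x4 matrix is 12x12, but its rank is at most rank(A) * rank(B) = 9 < 12,
# so the square result is always singular and the solve is never well posed.
import numpy as np

rng = np.random.default_rng(42)
A = rng.normal(size=(4, 3))
B = rng.normal(size=(3, 4))
C = np.kron(A, B)

print(C.shape)                   # (12, 12)
print(np.linalg.matrix_rank(C))  # 9
```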
||
C = pt.linalg.kron(A, B) | ||
y = pt.tensor("y", shape=(None,)) | ||
x = pt.linalg.solve(C, y) | ||
fn = pytensor.function([A, B, y], x) | ||
|
||
assert count_kron_ops(fn.maker.fgraph) == 0 | ||
|
||
|
||
@pytest.mark.parametrize(
    "a_shape, b_shape",
    [((5, 5), (5, 5)), ((50, 50), (50, 50)), ((100, 100), (100, 100))],
    ids=["small", "medium", "large"],
)
@pytest.mark.parametrize("rewrite", [True, False], ids=["rewrite", "no_rewrite"])
def test_rewrite_solve_kron_to_solve_benchmark(a_shape, b_shape, rewrite, benchmark):
    A, B = pt.tensor("A", shape=a_shape), pt.tensor("B", shape=b_shape)
    C = pt.linalg.kron(A, B)

    m, n = a_shape[-2], b_shape[-2]
    has_batch = len(a_shape) == 3
    y_shape = (a_shape[0], m * n) if has_batch else (m * n,)
    y = pt.tensor("y", shape=y_shape)
    x = pt.linalg.solve(C, y, b_ndim=1)

    rng = np.random.default_rng(sum(map(ord, "Go away Kron!")))
    a_val = rng.normal(size=a_shape).astype(config.floatX)
    b_val = rng.normal(size=b_shape).astype(config.floatX)
    y_val = rng.normal(size=y_shape).astype(config.floatX)

    mode = (
        get_default_mode()
        if rewrite
        else get_default_mode().excluding("rewrite_solve_kron_to_solve")
    )

    fn = pytensor.function([A, B, y], x, mode=mode)
    benchmark(fn, a_val, b_val, y_val)


def test_cholesky_eye_rewrite():
    x = pt.eye(10)
    L = pt.linalg.cholesky(x)
There was an error while loading. Please reload this page.