Skip to content

Commit de4e763

Browse files
Add tests
1 parent c41cdb4 commit de4e763

File tree

1 file changed

+203
-0
lines changed

1 file changed

+203
-0
lines changed

tests/tensor/linalg/test_rewriting.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import pytest
33

44
from pytensor import config, function, scan
5+
from pytensor import tensor as pt
56
from pytensor.compile.mode import get_default_mode
67
from pytensor.gradient import grad
8+
from pytensor.graph import rewrite_graph
79
from pytensor.scan.op import Scan
810
from pytensor.tensor._linalg.solve.rewriting import (
911
reuse_decomposition_multiple_solves,
@@ -14,7 +16,9 @@
1416
SolveLUFactorTridiagonal,
1517
)
1618
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
19+
from pytensor.tensor.elemwise import Elemwise
1720
from pytensor.tensor.linalg import solve
21+
from pytensor.tensor.nlinalg import Det, det
1822
from pytensor.tensor.slinalg import (
1923
Cholesky,
2024
CholeskySolve,
@@ -23,6 +27,7 @@
2327
SolveTriangular,
2428
)
2529
from pytensor.tensor.type import tensor
30+
from tests.unittest_tools import assert_equal_computations
2631

2732

2833
class DecompSolveOpCounter:
@@ -257,3 +262,201 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter):
257262
assert fn_opt(A_valid, b1_valid * np.nan, b2_valid)
258263
with pytest.raises((ValueError, np.linalg.LinAlgError), match=err_msg):
259264
assert fn_opt(A_valid * np.nan, b1_valid, b2_valid)
265+
266+
267+
@pytest.mark.parametrize(
268+
"original_fn, expected_fn",
269+
[
270+
pytest.param(
271+
lambda x: pt.log(pt.prod(pt.abs(x))),
272+
lambda x: pt.sum(pt.log(pt.abs(x))),
273+
id="log_prod_abs",
274+
),
275+
pytest.param(
276+
lambda x: pt.log(pt.prod(pt.exp(x))), lambda x: pt.sum(x), id="log_prod_exp"
277+
),
278+
pytest.param(
279+
lambda x: pt.log(pt.prod(x**2)),
280+
lambda x: pt.sum(pt.log(pt.sqr(x))),
281+
id="log_prod_sqr",
282+
),
283+
pytest.param(
284+
lambda x: pt.log(pt.abs(pt.prod(x))),
285+
lambda x: pt.sum(pt.log(pt.abs(x))),
286+
id="log_abs_prod",
287+
),
288+
pytest.param(
289+
lambda x: pt.log(pt.prod(pt.abs(x), axis=0)),
290+
lambda x: pt.sum(pt.log(pt.abs(x)), axis=0),
291+
id="log_prod_abs_axis0",
292+
),
293+
pytest.param(
294+
lambda x: pt.log(pt.prod(pt.exp(x), axis=-1)),
295+
lambda x: pt.sum(x, axis=-1),
296+
id="log_prod_exp_axis-1",
297+
),
298+
],
299+
)
300+
def test_local_log_prod_to_sum_log(original_fn, expected_fn):
301+
x = pt.tensor("x", shape=(3, 4))
302+
out = original_fn(x)
303+
expected = expected_fn(x)
304+
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
305+
assert_equal_computations([rewritten], [expected])
306+
307+
308+
def test_local_log_prod_to_sum_log_positive_tag():
309+
x = pt.tensor("x", shape=(3, 4))
310+
x.tag.positive = True
311+
out = pt.log(pt.prod(x))
312+
expected = pt.sum(pt.log(x))
313+
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
314+
assert_equal_computations([rewritten], [expected])
315+
316+
317+
def test_local_log_prod_to_sum_log_no_rewrite():
318+
x = pt.tensor("x", shape=(3, 4))
319+
out = pt.log(pt.prod(x))
320+
rewritten = rewrite_graph(out)
321+
from pytensor.scalar.basic import Log
322+
323+
assert rewritten.owner is not None
324+
assert isinstance(rewritten.owner.op.scalar_op, Log)
325+
326+
327+
@pytest.mark.parametrize(
328+
"decomp_fn, decomp_output_idx",
329+
[
330+
pytest.param(lambda x: pt.linalg.cholesky(x), 0, id="cholesky"),
331+
pytest.param(lambda x: pt.linalg.lu(x), -1, id="lu"),
332+
pytest.param(lambda x: pt.linalg.lu_factor(x), 0, id="lu_factor"),
333+
],
334+
)
335+
def test_det_of_matrix_factorized_elsewhere(decomp_fn, decomp_output_idx):
336+
x = pt.tensor("x", shape=(3, 3))
337+
338+
decomp_out = decomp_fn(x)
339+
if isinstance(decomp_out, list):
340+
decomp_var = decomp_out[decomp_output_idx]
341+
else:
342+
decomp_var = decomp_out
343+
344+
d = det(x)
345+
346+
outputs = [decomp_var, d]
347+
fn = function([x], outputs, mode=get_default_mode())
348+
349+
det_nodes = [
350+
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Det)
351+
]
352+
assert len(det_nodes) == 0
353+
354+
355+
@pytest.mark.parametrize(
356+
"decomp_fn, abs_needed",
357+
[
358+
pytest.param(lambda x: pt.linalg.svd(x, compute_uv=True), True, id="svd"),
359+
pytest.param(
360+
lambda x: pt.linalg.svd(x, compute_uv=False), True, id="svd_no_uv"
361+
),
362+
pytest.param(lambda x: pt.linalg.qr(x), True, id="qr"),
363+
],
364+
)
365+
def test_det_of_matrix_factorized_elsewhere_abs(decomp_fn, abs_needed):
366+
x = pt.tensor("x", shape=(3, 3))
367+
368+
decomp_out = decomp_fn(x)
369+
if isinstance(decomp_out, list):
370+
decomp_var = decomp_out[0]
371+
else:
372+
decomp_var = decomp_out
373+
374+
d = pt.abs(det(x))
375+
376+
outputs = [decomp_var, d]
377+
fn = function([x], outputs, mode=get_default_mode())
378+
379+
det_nodes = [
380+
node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Det)
381+
]
382+
assert len(det_nodes) == 0
383+
384+
385+
@pytest.mark.parametrize(
386+
"original_fn, expected_fn",
387+
[
388+
pytest.param(
389+
lambda x: det(pt.linalg.cholesky(x)),
390+
lambda x: pt.prod(
391+
pt.diagonal(pt.linalg.cholesky(x), axis1=-2, axis2=-1), axis=-1
392+
),
393+
id="det_cholesky",
394+
),
395+
pytest.param(
396+
lambda x: det(pt.linalg.lu(x)[-1]),
397+
lambda x: pt.prod(
398+
pt.diagonal(pt.linalg.lu(x)[-1], axis1=-2, axis2=-1), axis=-1
399+
),
400+
id="det_lu_U",
401+
),
402+
pytest.param(
403+
lambda x: det(pt.linalg.lu(x)[-2]),
404+
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
405+
id="det_lu_L",
406+
),
407+
],
408+
)
409+
def test_det_of_factorized_matrix(original_fn, expected_fn):
410+
x = pt.tensor("x", shape=(3, 3))
411+
out = original_fn(x)
412+
expected = expected_fn(x)
413+
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
414+
assert_equal_computations([rewritten], [expected])
415+
416+
417+
@pytest.mark.parametrize(
418+
"original_fn, expected_fn",
419+
[
420+
pytest.param(
421+
lambda x: pt.abs(det(pt.linalg.svd(x, compute_uv=True)[0])),
422+
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
423+
id="abs_det_svd_U",
424+
),
425+
pytest.param(
426+
lambda x: pt.abs(det(pt.linalg.svd(x, compute_uv=True)[2])),
427+
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
428+
id="abs_det_svd_Vt",
429+
),
430+
pytest.param(
431+
lambda x: pt.abs(det(pt.linalg.qr(x)[0])),
432+
lambda x: pt.as_tensor(1.0, dtype=x.dtype),
433+
id="abs_det_qr_Q",
434+
),
435+
pytest.param(
436+
lambda x: det(pt.linalg.qr(x)[1]),
437+
lambda x: pt.prod(
438+
pt.diagonal(pt.linalg.qr(x)[1], axis1=-2, axis2=-1), axis=-1
439+
),
440+
id="det_qr_R",
441+
),
442+
],
443+
)
444+
def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn):
445+
x = pt.tensor("x", shape=(3, 3))
446+
out = original_fn(x)
447+
expected = expected_fn(x)
448+
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
449+
assert_equal_computations([rewritten], [expected])
450+
451+
452+
def test_det_of_factorized_matrix_no_rewrite_without_abs():
453+
x = pt.tensor("x", shape=(3, 3))
454+
Q = pt.linalg.qr(x)[0]
455+
out = det(Q)
456+
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
457+
458+
assert not (
459+
rewritten.owner is not None
460+
and isinstance(rewritten.owner.op, Elemwise)
461+
and len(rewritten.owner.inputs) == 0
462+
), "det(Q) should not be rewritten to a constant without abs()"

0 commit comments

Comments
 (0)