Skip to content

Commit 7f179af

Browse files
committed
fixed mypy error and pytorch linking test
1 parent 7a05cb1 commit 7f179af

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,29 @@ def __str__(self):
268268

269269
def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
270270
"""
271-
This function simplfies the slogdet operation into 2 separate operations using directly the det op : sign(det_val) and log(abs(det_val))
271+
Compute the sign and (natural) logarithm of the determinant of an array.
272+
273+
Returns a naive graph which is optimized later using rewrites with the det operation.
274+
275+
Parameters
276+
----------
277+
x : (..., M, M) tensor or tensor_like
278+
Input tensor, has to be square.
279+
280+
Returns
281+
-------
282+
A namedtuple with the following attributes:
283+
284+
sign : (...) tensor_like
285+
A number representing the sign of the determinant. For a real matrix,
286+
this is 1, 0, or -1. For a complex matrix, this is a complex number
287+
with absolute value 1 (i.e., it is on the unit circle), or else 0.
288+
logabsdet : (...) tensor_like
289+
The natural log of the absolute value of the determinant.
290+
291+
If the determinant is zero, then `sign` will be 0 and `logabsdet`
292+
will be -inf. In all cases, the determinant is equal to
293+
``sign * exp(logabsdet)``.
272294
"""
273295
det_val = det(x)
274296
return [ptm.sign(det_val), ptm.log(ptm.abs(det_val))]

tests/link/pytorch/test_nlinalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections.abc import Sequence
2+
13
import numpy as np
24
import pytest
35

@@ -22,13 +24,13 @@ def matrix_test():
2224

2325
@pytest.mark.parametrize(
2426
"func",
25-
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
27+
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
2628
)
2729
def test_lin_alg_no_params(func, matrix_test):
2830
x, test_value = matrix_test
2931

3032
out = func(x)
31-
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
33+
out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out])
3234

3335
def assert_fn(x, y):
3436
np.testing.assert_allclose(x, y, rtol=1e-3)

0 commit comments

Comments
 (0)