Skip to content

Commit 398f4ad

Browse files
committed
fixed documentation
1 parent 912eab4 commit 398f4ad

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

pytensor/tensor/nlinalg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pytensor.gradient import DisconnectedType
1212
from pytensor.graph.basic import Apply
1313
from pytensor.graph.op import Op
14+
from pytensor.tensor import TensorLike
1415
from pytensor.tensor import basic as ptb
1516
from pytensor.tensor import math as ptm
1617
from pytensor.tensor.basic import as_tensor_variable, diagonal
@@ -266,7 +267,7 @@ def __str__(self):
266267
return "SLogDet"
267268

268269

269-
def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
270+
def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
270271
"""
271272
Compute the sign and (natural) logarithm of the determinant of an array.
272273
@@ -279,12 +280,11 @@ def slogdet(x: ptb.TensorVariable) -> tuple[ptb.TensorVariable, ptb.TensorVariab
279280
280281
Returns
281282
-------
282-
A namedtuple with the following attributes:
283+
A tuple with the following attributes:
283284
284285
sign : (...) tensor_like
285286
A number representing the sign of the determinant. For a real matrix,
286-
this is 1, 0, or -1. For a complex matrix, this is a complex number
287-
with absolute value 1 (i.e., it is on the unit circle), or else 0.
287+
this is 1, 0, or -1.
288288
logabsdet : (...) tensor_like
289289
The natural log of the absolute value of the determinant.
290290

tests/tensor/rewriting/test_linalg.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,7 @@ def test_slogdet_specialization():
939939
log_det_x, log_det_a = pt.log(det_x), np.log(det_a)
940940
sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a)
941941
exp_det_x = pt.exp(det_x)
942+
942943
# REWRITE TESTS
943944
# sign(det(x))
944945
f = function([x], [sign_det_x], mode="FAST_RUN")
@@ -952,6 +953,7 @@ def test_slogdet_specialization():
952953
atol=1e-3 if config.floatX == "float32" else 1e-8,
953954
rtol=1e-3 if config.floatX == "float32" else 1e-8,
954955
)
956+
955957
# log(abs(det(x)))
956958
f = function([x], [log_abs_det_x], mode="FAST_RUN")
957959
nodes = f.maker.fgraph.apply_nodes
@@ -964,6 +966,7 @@ def test_slogdet_specialization():
964966
atol=1e-3 if config.floatX == "float32" else 1e-8,
965967
rtol=1e-3 if config.floatX == "float32" else 1e-8,
966968
)
969+
967970
# log(det(x))
968971
f = function([x], [log_det_x], mode="FAST_RUN")
969972
nodes = f.maker.fgraph.apply_nodes
@@ -976,17 +979,20 @@ def test_slogdet_specialization():
976979
atol=1e-3 if config.floatX == "float32" else 1e-8,
977980
rtol=1e-3 if config.floatX == "float32" else 1e-8,
978981
)
979-
# more than 1 valid function
982+
983+
# More than 1 valid function
980984
f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN")
981985
nodes = f.maker.fgraph.apply_nodes
982986
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
983987
assert not any(isinstance(node.op, Det) for node in nodes)
984-
# other functions (rewrite shouldnt be applied to these)
985-
# only invalid functions
988+
989+
# Other functions (rewrite shouldnt be applied to these)
990+
# Only invalid functions
986991
f = function([x], [exp_det_x], mode="FAST_RUN")
987992
nodes = f.maker.fgraph.apply_nodes
988993
assert not any(isinstance(node.op, SLogDet) for node in nodes)
989-
# invalid + valid function
994+
995+
# Invalid + Valid function
990996
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
991997
nodes = f.maker.fgraph.apply_nodes
992998
assert not any(isinstance(node.op, SLogDet) for node in nodes)

0 commit comments

Comments
 (0)